SS3M committed
Commit 4eecfd0 · verified · 1 Parent(s): df31aca

Upload 4.2_2_clone_4.6's state dict

.gitattributes CHANGED
@@ -94,3 +94,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
94
  4.1_entities_no_ce_4.2/logs/4.1_entities_no_ce_4.2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
95
  4.2_add_span_rerank_branch_4.3/logs/4.2_add_span_rerank_branch_4.3_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
96
  4.2_add_span_rerank_branch_4.5/logs/4.2_add_span_rerank_branch_4.5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
97
+ 4.2_2_clone_4.6/logs/4.2_2_clone_4.6_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
4.2_2_clone_4.6/4.2_2_clone_4.6.py ADDED
@@ -0,0 +1,2283 @@
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch] pytorch-crf')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+ from torchcrf import CRF
19
+
20
+ from sklearn.metrics import f1_score
21
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
22
+ from scipy.spatial.transform import Rotation as R
23
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
24
+ from sklearn.metrics import precision_recall_fscore_support
25
+ from timm.utils import ModelEmaV3
26
+ import timm
27
+
28
+ import os
29
+ import gc
30
+ import json
31
+ from pathlib import Path
32
+ import pickle
33
+ from tqdm.auto import tqdm
34
+ import copy
35
+ import numpy as np
36
+ import pandas as pd
37
+ import polars as pl
38
+ from PIL import Image
39
+ import time
40
+ from tqdm import tqdm
41
+ from matplotlib import pyplot as plt
42
+ import seaborn as sns
43
+ from multiprocessing import Manager as MemoryManager
44
+ from functools import lru_cache
45
+ import shutil
46
+ import glob
47
+ import cv2
48
+ import random
49
+ import re
50
+ import joblib
51
+ import math
52
+ from huggingface_hub import HfApi, snapshot_download
53
+ import evaluate
54
+ from underthesea import word_tokenize as vi_tokenize_tool
55
+ import spacy
56
+ en_tokenize_tool = spacy.load("en_core_web_sm")
57
+ from collections import defaultdict, Counter
58
+
59
+ # %% [code]
60
+ # Global config
61
+ SEEDS = [26092004]
62
+ topk = 1
63
+ nfolds = 5
64
+ only_fold_idx = 0
65
+ test_only = 0
66
+ debug_only = 0
67
+
68
+ # Directory config
69
+ dataset = 'kltn/only_entities' # conll003, ontonotes, phoner, vietbio, vietmed, vimed, kltn/only_entities, kltn/raw
70
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Directory containing the train, val, test files
71
+ train_dir = f'{root_dir}'
72
+ # val_dir = f'{root_dir}/val'
73
+ test_dir = f'{root_dir}'
74
+
75
+ # Config checkpoints
76
+
77
+ # Config training
78
+ epochs = 18 if not debug_only else 2
79
+ batch_size = 32
80
+ device = "cuda" if torch.cuda.is_available() else "cpu"
81
+ # # Add any other global variables here
82
+ repo_name = 'SS3M/kltn-experiments'
83
+ state_dict_save_name = "4.2_2_clone_4.6"
84
+ checkpoints_dir = state_dict_save_name
85
+ pretrained_dir = "/kaggle/working"
86
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
87
+
88
+ backbone_model_name = "bert-base-uncased" if dataset in ["conll003", "ontonotes"] else "vinai/phobert-base"
89
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset in ["conll003", "ontonotes"] else vi_tokenize_tool(text)
90
+ max_len_dict = {
91
+ 'kltn/raw': 256,
92
+ 'kltn/only_entities': 68,
93
+ 'conll003': 46,
94
+ 'ontonotes': 61,
95
+ 'phoner': 68,
96
+ 'vietbio': 125,
97
+ 'vietmed': 36,
98
+ 'vimed': 100,
99
+ }
100
+ zero_entities_rate_dict = {
101
+ 'kltn/raw': 1000,
102
+ 'kltn/only_entities': 0.2,
103
+ 'conll003': 1000, # 1000 means: keep all zero-entity samples
104
+ 'ontonotes': 1000,
105
+ 'phoner': 1000,
106
+ 'vietbio': 1000,
107
+ 'vietmed': 1000,
108
+ 'vimed': 1000,
109
+ }
110
+
111
+ max_len = max_len_dict[dataset]
112
+ max_n_parts = 1
113
+ max_span_len = 10
114
+ zero_entities_rate = zero_entities_rate_dict[dataset]
115
+
116
+ # Trainer
117
+ trainer_params = {
118
+ "training_time": "00:11:30:00",
119
+ "eval_mode": "max",
120
+ "topk": topk,
121
+ "save_name": state_dict_save_name,
122
+ "save_best": True,
123
+ "save_last": True,
124
+ "device": device,
125
+ "logging": True,
126
+ "logging_file": True,
127
+ "checkpoints_dir": checkpoints_dir,
128
+ "early_stopping": 30,
129
+ "eval_from_ratio": 0.4,
130
+ "eval_every": 1,
131
+ "schedule_in_step": False,
132
+ "use_ema": True,
133
+ "ema_from_ratio": 0.3,
134
+ "ema_decay": 0.9995,
135
+ "max_grad_norm": 200.0,
136
+ "return_best": True,
137
+ "return_last": True,
138
+ }
139
+
140
+ # Memory
141
+ train_memory_params = {
142
+ 'max_len': max_len,
143
+ 'max_n_parts': max_n_parts,
144
+ }
145
+ val_memory_params = {
146
+ 'max_len': max_len,
147
+ 'max_n_parts': max_n_parts,
148
+ }
149
+
150
+ # Data Loader
151
+ def seed_worker(worker_id):
152
+ worker_seed = torch.initial_seed() % 2**32
153
+ np.random.seed(worker_seed)
154
+ random.seed(worker_seed)
155
+
156
+ train_loader_params = {
157
+ 'batch_size': batch_size,
158
+ 'shuffle': True,
159
+ 'pin_memory':True,
160
+ 'num_workers': 2,
161
+ 'drop_last': False,
162
+ 'worker_init_fn': seed_worker,
163
+ 'persistent_workers': False,
164
+ }
165
+ val_loader_params = {
166
+ 'batch_size': batch_size,
167
+ 'shuffle': False,
168
+ 'pin_memory':True,
169
+ 'num_workers': 1,
170
+ 'drop_last': False,
171
+ 'worker_init_fn': seed_worker,
172
+ 'persistent_workers': False,
173
+ }
174
+
175
+ # Model
176
+ model_params = {
177
+ 'backbone_model_name': backbone_model_name,
178
+ 'max_r': 6,
179
+ }
180
+
181
+ # Loss Func
182
+ loss_func_params = {
183
+ 'lambda_span': 1.0,
184
+ 'lambda_margin': 1.0,
185
+ 'margin': 0.5,
186
+ }
187
+ eval_func_params = {}
188
+
189
+ # Optim
190
+ optim_params = {
191
+ 'name': 'AdamW',
192
+ 'lr': 1e-4,
193
+ 'weight_decay': 1e-4,
194
+ }
195
+ scheduler_params = {
196
+ 'name': 'CosineAnnealingLR',
197
+ 'T_max': 20, # Number of epochs to complete one LR decay cycle
198
+ 'eta_min': 1e-6 # Minimum learning rate within the cycle
199
+ }
200
+
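The scheduler settings above (`CosineAnnealingLR`, `T_max = 20`, `eta_min = 1e-6`) combine with `lr = 1e-4` and `epochs = 18`. As a rough illustration of the resulting curve, here is a minimal sketch of the closed-form cosine annealing formula, assuming the scheduler is stepped once per epoch (which `schedule_in_step: False` suggests). The helper name `cosine_lr` is hypothetical and not part of the script.

```python
import math

def cosine_lr(epoch, base_lr=1e-4, eta_min=1e-6, t_max=20):
    """Closed-form cosine annealing:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2
    """
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

# Learning rate at the start of each of the 18 training epochs (assumed per-epoch stepping).
schedule = [cosine_lr(e) for e in range(18)]
```

Because `T_max` (20) is larger than the number of epochs (18), this sketch never reaches `eta_min`; the last epoch still trains with a learning rate somewhat above it.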
201
+ # %% [code]
202
+ def set_seed(seed=42):
203
+ random.seed(seed)
204
+ np.random.seed(seed)
205
+ torch.manual_seed(seed)
206
+ torch.cuda.manual_seed(seed)
207
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
208
+ torch.use_deterministic_algorithms(False)
209
+ torch.backends.cudnn.deterministic = True
210
+ torch.backends.cudnn.benchmark = False
211
+ os.environ['PYTHONHASHSEED'] = str(seed)
212
+
213
+ # %% [code]
214
+ class CustomLoss(nn.Module):
215
+ def __init__(
216
+ self,
217
+ lambda_margin=1.0,
218
+ lambda_span=1.0,
219
+ margin=1.0
220
+ ):
221
+ super().__init__()
222
+
223
+ self.lambda_margin = lambda_margin
224
+ self.lambda_span = lambda_span
225
+ self.margin = margin
226
+
227
+ def margin_loss(self, logits, labels):
228
+ """
229
+ logits: (N, C)
230
+ labels: (N,)
231
+ """
232
+
233
+ valid_mask = labels != -100
234
+
235
+ logits = logits[valid_mask]
236
+ labels = labels[valid_mask]
237
+
238
+ if len(labels) == 0:
239
+ return logits.new_tensor(0.0)
240
+
241
+ # positive logit
242
+ pos_logit = logits.gather(dim=1, index=labels.unsqueeze(-1)).squeeze(-1)
243
+
244
+ # hardest negative
245
+ neg_logits = logits.clone()
246
+ neg_logits.scatter_(1, labels.unsqueeze(-1), float("-inf"))
247
+ hardest_neg = neg_logits.max(dim=-1).values
248
+
249
+ loss = F.relu(self.margin - pos_logit + hardest_neg)
250
+ return loss.mean()
251
+
252
+ def span_loss(self, span_scores, labels):
253
+ """
254
+ span_scores: (B, N, M)
255
+ labels: (B, N, M)
256
+
257
+ labels:
258
+ -100 -> ignore
259
+ 0 -> negative
260
+ >0 -> positive
261
+ """
262
+
263
+ device = span_scores.device
264
+
265
+ B, N, M = span_scores.shape
266
+ if N == 0 or M == 0:
267
+ return span_scores.new_tensor(0.0)
268
+
269
+ valid_mask = labels != -100
270
+ pos_mask = labels > 0
271
+ neg_mask = labels == 0
272
+
273
+ pos_scores = span_scores.masked_fill(~pos_mask, float("-inf"))
274
+ best_pos = pos_scores.max(dim=-1).values # (B, N)
275
+
276
+ neg_scores = span_scores.masked_fill(~neg_mask, float("-inf"))
277
+ best_neg = neg_scores.max(dim=-1).values # (B, N)
278
+ no_neg = ~neg_mask.any(dim=-1)
279
+ best_neg = torch.where(no_neg, torch.zeros_like(best_neg), best_neg)
280
+
281
+ has_pos = pos_mask.any(dim=-1)
282
+ loss = F.relu(self.margin - best_pos + best_neg)
283
+ loss = loss[has_pos]
284
+ if loss.numel() == 0:
285
+ return span_scores.new_tensor(0.0)
286
+
287
+ return loss.mean()
288
+
289
+ def forward(
290
+ self,
291
+ start_logits, start_labels,
292
+ end_logits, end_labels,
293
+ span_scores, labels
294
+ ):
295
+
296
+ B, L, C = start_logits.shape
297
+
298
+ start_logits_flat = start_logits.view(B * L, C)
299
+ start_labels_flat = start_labels.view(-1)
300
+
301
+ end_logits_flat = end_logits.view(B * L, C)
302
+ end_labels_flat = end_labels.view(-1)
303
+
304
+ start_margin = self.margin_loss(start_logits_flat, start_labels_flat)
305
+ end_margin = self.margin_loss(end_logits_flat, end_labels_flat)
306
+ token_margin_loss = start_margin + end_margin
307
+
308
+ span_loss = self.span_loss(span_scores, labels)
309
+
310
+ total_loss = (
311
+ self.lambda_margin * token_margin_loss +
312
+ self.lambda_span * span_loss
313
+ )
314
+
315
+ return {
316
+ "total": total_loss,
317
+ "token_margin_loss": token_margin_loss,
318
+ "span_loss": span_loss,
319
+ "start_margin": start_margin,
320
+ "end_margin": end_margin,
321
+ }
322
+
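The `margin_loss` method above is a hinge loss between the gold-class logit and the hardest (highest-scoring) wrong class. The following dependency-free sketch reproduces that idea on plain Python lists, so the arithmetic can be checked by hand. The function names are hypothetical, and the sketch is not a replacement for the tensor version.

```python
def hardest_negative_margin(logits, label, margin=0.5):
    """Hinge loss for one token: max(0, margin - gold_logit + best_wrong_logit)."""
    pos = logits[label]
    hardest_neg = max(v for i, v in enumerate(logits) if i != label)
    return max(0.0, margin - pos + hardest_neg)

def mean_margin_loss(batch_logits, labels, margin=0.5, ignore_index=-100):
    """Average the per-token hinge loss, skipping tokens labelled ignore_index."""
    losses = [
        hardest_negative_margin(row, y, margin)
        for row, y in zip(batch_logits, labels)
        if y != ignore_index
    ]
    return sum(losses) / len(losses) if losses else 0.0

# Gold class 0 scores 2.0 and the best wrong class scores 1.8, so the margin of 0.5 is violated by 0.3.
example = hardest_negative_margin([2.0, 0.1, 1.8], label=0)
```

A token contributes zero loss once its gold logit exceeds every other logit by at least `margin`, so only tokens near the decision boundary produce gradient.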
323
+ # %% [code]
324
+ ## Write eval_fn here
325
+
326
+ # Put all eval_fn logic and weights here
327
+ class CustomEvalFn(nn.Module):
328
+ def __init__(self):
329
+ super().__init__()
330
+
331
+ def compute_f1(self, tp, fp, fn):
332
+ precision = tp / (tp + fp + 1e-8)
333
+ recall = tp / (tp + fn + 1e-8)
334
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
335
+ return precision, recall, f1
336
+
337
+ def forward(self, pred, gold):
338
+ pred_set = set(pred)
339
+ gold_set = set(gold)
340
+
341
+ tp = len(pred_set & gold_set)
342
+ fp = len(pred_set - gold_set)
343
+ fn = len(gold_set - pred_set)
344
+
345
+ precision, recall, f1 = self.compute_f1(tp, fp, fn)
346
+
347
+ return {
348
+ "precision": precision,
349
+ "recall": recall,
350
+ "f1": f1,
351
+ }
352
+
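`CustomEvalFn` scores predictions by exact set membership. A small worked example may make the counting concrete. The tuple layout `(sample_index, start, end, label)` below is an assumption chosen for illustration, since the class itself only requires hashable items.

```python
def span_prf(pred, gold, eps=1e-8):
    """Exact-match precision, recall and F1 over sets of hashable span records."""
    pred_set, gold_set = set(pred), set(gold)
    tp = len(pred_set & gold_set)
    fp = len(pred_set - gold_set)
    fn = len(gold_set - pred_set)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1

# Hypothetical records: (sample_index, start, end, label).
pred = [(0, 2, 4, "PER"), (0, 7, 8, "LOC"), (1, 1, 1, "ORG")]
gold = [(0, 2, 4, "PER"), (0, 7, 9, "LOC")]

# One exact hit, two spurious predictions, one missed gold span.
precision, recall, f1 = span_prf(pred, gold)
```

Since both inputs are converted to sets, duplicated predictions are counted once, and a span that is off by a single boundary token counts as both a false positive and a false negative.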
353
+ class SpanErrorAnalyzer:
354
+ def __init__(self, pad_token_id=0):
355
+ self.pad_token_id = pad_token_id
356
+
357
+ # ===== helper =====
358
+ def _to_set(self, data):
359
+ """
360
+ data: list of (b, tuple(ids))
361
+ -> dict[b] = set(tuple(ids))
362
+ """
363
+ res = defaultdict(set)
364
+ for b, ids in data:
365
+ ids = tuple([i for i in ids if i != self.pad_token_id])
366
+ if len(ids) > 0:
367
+ res[b].add(ids)
368
+ return res
369
+
370
+ def _iou(self, a, b):
371
+ """
372
+ a, b: tuple(ids)
373
+ """
374
+ set_a, set_b = set(a), set(b)
375
+ inter = len(set_a & set_b)
376
+ union = len(set_a | set_b)
377
+ if union == 0:
378
+ return 0.0
379
+ return inter / union
380
+
381
+ def _boundary_error(self, pred, gold):
382
+ """
383
+ measure the boundary offset based on prefix/suffix overlap
384
+ """
385
+ # left match
386
+ left = 0
387
+ for i in range(min(len(pred), len(gold))):
388
+ if pred[i] == gold[i]:
389
+ left += 1
390
+ else:
391
+ break
392
+
393
+ # right match
394
+ right = 0
395
+ for i in range(1, min(len(pred), len(gold)) + 1):
396
+ if pred[-i] == gold[-i]:
397
+ right += 1
398
+ else:
399
+ break
400
+
401
+ return {
402
+ "left_match": left,
403
+ "right_match": right,
404
+ "pred_len": len(pred),
405
+ "gold_len": len(gold),
406
+ }
407
+
408
+ # ===== main =====
409
+ def analyze(self, preds, golds):
410
+ pred_map = self._to_set(preds)
411
+ gold_map = self._to_set(golds)
412
+
413
+ all_batches = set(pred_map.keys()) | set(gold_map.keys())
414
+
415
+ stats = Counter()
416
+
417
+ detailed_errors = []
418
+
419
+ for b in all_batches:
420
+ pset = pred_map.get(b, set())
421
+ gset = gold_map.get(b, set())
422
+
423
+ matched_gold = set()
424
+
425
+ # ===== check predictions =====
426
+ for p in pset:
427
+ if p in gset:
428
+ stats["exact_match"] += 1
429
+ matched_gold.add(p)
430
+ else:
431
+ # find the nearest gold span
432
+ best_iou = 0
433
+ best_g = None
434
+
435
+ for g in gset:
436
+ iou = self._iou(p, g)
437
+ if iou > best_iou:
438
+ best_iou = iou
439
+ best_g = g
440
+
441
+ if best_iou > 0:
442
+ stats["partial_match"] += 1
443
+
444
+ boundary = self._boundary_error(p, best_g)
445
+
446
+ detailed_errors.append({
447
+ "type": "boundary_error",
448
+ "batch": b,
449
+ "pred": p,
450
+ "gold": best_g,
451
+ "iou": best_iou,
452
+ **boundary
453
+ })
454
+ else:
455
+ if b not in gold_map:
456
+ stats["no_event_sample"] += 1
457
+ err_type = "no_event_sample"
458
+ else:
459
+ stats["completely_wrong"] += 1
460
+ err_type = "completely_wrong"
461
+
462
+ detailed_errors.append({
463
+ "type": err_type,
464
+ "batch": b,
465
+ "pred": p
466
+ })
467
+
468
+ # ===== check missing =====
469
+ for g in gset:
470
+ if g not in matched_gold:
471
+ # check if any pred overlaps
472
+ overlap = any(self._iou(p, g) > 0 for p in pset)
473
+
474
+ if overlap:
475
+ stats["miss_with_overlap"] += 1
476
+ else:
477
+ stats["miss"] += 1
478
+
479
+ detailed_errors.append({
480
+ "type": "miss",
481
+ "batch": b,
482
+ "gold": g
483
+ })
484
+
485
+ return {
486
+ "summary": {
487
+ "exact_match": (stats["exact_match"], stats["exact_match"] / max(len(preds), 1)),
488
+ "partial_match": (stats["partial_match"], stats["partial_match"] / max(len(preds), 1)),
489
+ "no_event_sample": (stats["no_event_sample"], stats["no_event_sample"] / max(len(preds), 1)),
490
+ "completely_wrong": (stats["completely_wrong"], stats["completely_wrong"] / max(len(preds), 1)),
491
+ "miss": (stats["miss"], stats["miss"] / max(len(golds), 1)),
492
+ "miss_with_overlap": (stats["miss_with_overlap"], stats["miss_with_overlap"] / max(len(golds), 1)),
493
+ },
494
+ "details": detailed_errors
495
+ }
496
+
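`SpanErrorAnalyzer` relies on two small measures: a token-level IoU between a predicted and a gold span, and a count of matching tokens at the left and right boundaries. The sketch below restates both on plain tuples of token ids, with hypothetical function names, to show how a partially correct prediction is characterised.

```python
def token_iou(a, b):
    """Intersection over union of the *sets* of token ids in two spans."""
    set_a, set_b = set(a), set(b)
    union = len(set_a | set_b)
    return len(set_a & set_b) / union if union else 0.0

def boundary_match(pred, gold):
    """Count how many tokens agree from the left edge and from the right edge."""
    n = min(len(pred), len(gold))
    left = 0
    while left < n and pred[left] == gold[left]:
        left += 1
    right = 0
    while right < n and pred[-1 - right] == gold[-1 - right]:
        right += 1
    return left, right

# The prediction drops the first gold token: the right edge agrees, the left edge does not.
pred_span = (12, 13)
gold_span = (11, 12, 13)
iou = token_iou(pred_span, gold_span)
edges = boundary_match(pred_span, gold_span)
```

One consequence of the set-based IoU is that token order and repeated tokens are ignored, so two different spans made of the same token ids score 1.0.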
497
+ # %% [code]
498
+ class DataParallelProxy(nn.DataParallel):
499
+ def __getattr__(self, name):
500
+ try:
501
+ return super().__getattr__(name)
502
+
503
+ except AttributeError:
504
+
505
+ attr = getattr(self.module, name)
506
+
507
+ if callable(attr):
508
+
509
+ def wrapper(*args, **kwargs):
510
+ return self._parallel_apply_method(
511
+ name,
512
+ *args,
513
+ **kwargs
514
+ )
515
+
516
+ return wrapper
517
+
518
+ return attr
519
+
520
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
521
+ if not self.device_ids:
522
+ return getattr(self.module, method_name)(*inputs, **kwargs)
523
+
524
+ inputs_scattered, kwargs_scattered = self.scatter(
525
+ inputs,
526
+ kwargs,
527
+ self.device_ids
528
+ )
529
+
530
+ replicas = self.replicate(
531
+ self.module,
532
+ self.device_ids[:len(inputs_scattered)]
533
+ )
534
+
535
+ outputs = self.parallel_apply(
536
+ [getattr(replica, method_name) for replica in replicas],
537
+ inputs_scattered,
538
+ kwargs_scattered
539
+ )
540
+
541
+ return self._custom_gather(outputs, self.output_device)
542
+
543
+ def gather(self, outputs, output_device):
544
+ return self._custom_gather(outputs, output_device)
545
+
546
+ def _custom_gather(self, outputs, output_device):
547
+ first = outputs[0]
548
+
549
+ if torch.is_tensor(first):
550
+ return self._gather_tensor(outputs, output_device)
551
+
552
+ if isinstance(first, tuple):
553
+ return tuple(
554
+ self._custom_gather(
555
+ list(items),
556
+ output_device
557
+ )
558
+ for items in zip(*outputs)
559
+ )
560
+
561
+ if isinstance(first, list):
562
+ if len(first) > 0 and torch.is_tensor(first[0]):
563
+ return self._gather_tensor_list(outputs, output_device)
564
+
565
+ merged = []
566
+ for out in outputs:
567
+ merged.extend(out)
568
+ return merged
569
+
570
+ if isinstance(first, dict):
571
+ return {
572
+ k: self._custom_gather(
573
+ [o[k] for o in outputs],
574
+ output_device
575
+ )
576
+ for k in first.keys()
577
+ }
578
+ return outputs
579
+
580
+ def _gather_tensor(self, tensors, output_device):
581
+ tensors = [
582
+ t.to(output_device)
583
+ for t in tensors
584
+ ]
585
+
586
+ try:
587
+ return torch.cat(tensors, dim=0)
588
+ except RuntimeError:
589
+ pass
590
+
591
+ max_shape = list(tensors[0].shape)
592
+ for t in tensors[1:]:
593
+ for d in range(len(max_shape)):
594
+ max_shape[d] = max(max_shape[d], t.shape[d])
595
+
596
+ padded = []
597
+ for t in tensors:
598
+ pad = []
599
+
600
+ for d in reversed(range(len(max_shape))):
601
+ if d == 0:
602
+ pad.extend([0, 0])
603
+ continue
604
+
605
+ diff = max_shape[d] - t.shape[d]
606
+ pad.extend([0, diff])
607
+
608
+ t = F.pad(t, pad)
609
+ padded.append(t)
610
+ return torch.cat(padded, dim=0)
611
+
612
+ def _gather_tensor_list(self, outputs, output_device):
613
+ merged = []
614
+
615
+ for out in outputs:
616
+ merged.extend(out)
617
+
618
+ return self._gather_tensor(merged, output_device)
619
+
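When `torch.cat` fails in `_gather_tensor`, the replica outputs are padded to a common shape first. The padding list follows the `torch.nn.functional.pad` convention, where `(left, right)` pairs are given starting from the last dimension. The sketch below reproduces only the shape bookkeeping in plain Python; `max_shape` and `build_pad` are hypothetical helper names.

```python
def max_shape(shapes):
    """Element-wise maximum over a list of equally long shape tuples."""
    return [max(dims) for dims in zip(*shapes)]

def build_pad(shape, target):
    """Flat pad list in F.pad order: (left, right) pairs from the last dim backwards."""
    pad = []
    for d in reversed(range(len(target))):
        if d == 0:
            pad.extend([0, 0])  # dim 0 is the concatenation axis, so it is never padded
        else:
            pad.extend([0, target[d] - shape[d]])  # pad only on the right
    return pad

# Two replica outputs with different N and M sizes.
shapes = [(4, 3, 5), (2, 7, 9)]
target = max_shape(shapes)
pads = [build_pad(s, target) for s in shapes]
```

Padding is filled with zeros in the original method, which matches the `(0, 0)` padding convention used for spans elsewhere in the script.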
620
+ # %% [code]
621
+ ## Write the model architecture here
622
+ def extract_spans_and_labels(start_logits, end_logits):
623
+ """
624
+ Args:
625
+ start_logits: Tensor (B, L, C)
626
+ end_logits: Tensor (B, L, C)
627
+
628
+ Returns:
629
+ spans: Tensor (B, N, 2)
630
+ padding = (0, 0)
631
+
632
+ labels: Tensor (B, N)
633
+ padding = -100
634
+
635
+ If no span can be extracted:
636
+ spans.shape = (B, 0, 2)
637
+ labels.shape = (B, 0)
638
+ """
639
+
640
+ start_labels = start_logits.argmax(dim=-1) # (B, L)
641
+ end_labels = end_logits.argmax(dim=-1) # (B, L)
642
+
643
+ B, L = start_labels.shape
644
+
645
+ batch_spans = []
646
+ batch_labels = []
647
+ max_n = 0
648
+ for bidx in range(B):
649
+ used_start = set()
650
+ used_end = set()
651
+
652
+ spans = []
653
+ labels = []
654
+
655
+ for s in range(L):
656
+ s_label = start_labels[bidx, s].item()
657
+ if s_label == 0:
658
+ continue
659
+ if s in used_start:
660
+ continue
661
+
662
+ nearest_e = None
663
+
664
+ # find the nearest end with the same class
665
+ for e in range(s, L):
666
+ if e in used_end:
667
+ continue
668
+
669
+ e_label = end_labels[bidx, e].item()
670
+
671
+ if e_label == s_label:
672
+ nearest_e = e
673
+ break
674
+
675
+ if nearest_e is None:
676
+ continue
677
+
678
+ used_start.add(s)
679
+ used_end.add(nearest_e)
680
+
681
+ spans.append([s, nearest_e])
682
+ labels.append(s_label)
683
+
684
+ batch_spans.append(spans)
685
+ batch_labels.append(labels)
686
+
687
+ max_n = max(max_n, len(spans))
688
+
689
+ if max_n == 0:
690
+ spans = torch.empty((B, 0, 2), dtype=torch.long, device=start_logits.device)
691
+ labels = torch.empty((B, 0), dtype=torch.long, device=start_logits.device)
692
+ return spans, labels
693
+
694
+ padded_spans = []
695
+ padded_labels = []
696
+ for spans, labels in zip(batch_spans, batch_labels):
697
+ pad_n = max_n - len(spans)
698
+
699
+ spans = spans + [[0, 0]] * pad_n
700
+ labels = labels + [-100] * pad_n
701
+
702
+ padded_spans.append(spans)
703
+ padded_labels.append(labels)
704
+
705
+ spans = torch.tensor(padded_spans, dtype=torch.long, device=start_logits.device) # (B, N, 2)
706
+ labels = torch.tensor(padded_labels, dtype=torch.long, device=start_logits.device) # (B, N)
707
+ return spans, labels
708
+
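The decoding rule in `extract_spans_and_labels` is greedy: each non-zero start label is paired with the nearest unused end position of the same class at or after it. A minimal list-based sketch of that pairing is shown below, assuming the argmax labels are already available as Python lists. The name `pair_spans` is hypothetical.

```python
def pair_spans(start_labels, end_labels):
    """Greedily pair each start with the nearest free end of the same class."""
    used_end = set()
    spans = []
    for s, s_label in enumerate(start_labels):
        if s_label == 0:  # class 0 means "not a start"
            continue
        for e in range(s, len(end_labels)):
            if e not in used_end and end_labels[e] == s_label:
                used_end.add(e)
                spans.append((s, e, s_label))
                break  # a start without any free matching end is simply dropped
    return spans

# Two entities of different classes.
spans = pair_spans([0, 1, 0, 2, 0, 0], [0, 0, 1, 0, 0, 2])

# Two starts compete for a single end: only the first start gets a span.
contested = pair_spans([1, 1, 0], [0, 0, 1])
```

Because each end position can be consumed only once, nested or overlapping entities of the same class that share an end token cannot both be recovered by this rule.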
709
+ def expand_spans(spans, r, attention_mask):
710
+ """
711
+ Args:
712
+ spans: Tensor (B, N, 2)
713
+ r: Tensor (B, N)
714
+ attention_mask: Tensor (B, L)
715
+
716
+ Returns:
717
+ expanded_spans: Tensor (B, N, M, 2)
718
+
719
+ Where:
720
+ M = (2*max_r + 1)^2
721
+
722
+ padding = (0, 0)
723
+
724
+ If N = 0:
725
+ return shape (B, 0, 0, 2)
726
+ """
727
+
728
+ device = spans.device
729
+ B, N, _ = spans.shape
730
+
731
+ if N == 0:
732
+ return torch.empty((B, 0, 0, 2), dtype=torch.long, device=device)
733
+ max_r = r.max().item()
734
+
735
+ shifts = torch.arange(-max_r, max_r + 1, device=device)
736
+ ds_grid, de_grid = torch.meshgrid(shifts, shifts, indexing='ij')
737
+ ds = ds_grid.reshape(-1) # (M,)
738
+ de = de_grid.reshape(-1) # (M,)
739
+ M = ds.size(0)
740
+
741
+ spans = spans.unsqueeze(2) # (B, N, 1, 2)
742
+ start = spans[..., 0] # (B, N, 1)
743
+ end = spans[..., 1] # (B, N, 1)
744
+
745
+ expanded_s = start + ds.view(1, 1, M)
746
+ expanded_e = end + de.view(1, 1, M)
747
+ expanded = torch.stack( [expanded_s, expanded_e], dim=-1) # (B, N, M, 2)
748
+
749
+ valid_shift = ((ds.abs().view(1, 1, M) <= r.unsqueeze(-1)) & (de.abs().view(1, 1, M) <= r.unsqueeze(-1)))
750
+ non_pad_span = (spans.squeeze(2).sum(dim=-1) != 0) # (B, N)
751
+ valid_length = attention_mask.sum(dim=-1).view(B, 1, 1)
752
+ valid_boundary = ((expanded_s > 0) & (expanded_e < valid_length) & (expanded_s <= expanded_e))
753
+ valid_mask = (valid_shift & non_pad_span.unsqueeze(-1) & valid_boundary)
754
+
755
+ expanded_flat = expanded.reshape(B, N, M, 2)
756
+ valid_flat = valid_mask.reshape(B, N, M)
757
+ valid_counts = valid_flat.sum(dim=-1) # (B, N)
758
+ M_max = valid_counts.max().item()
759
+
760
+ filtered_expanded = torch.zeros((B, N, M_max, 2), dtype=expanded.dtype, device=device)
761
+ for b in range(B):
762
+ for n in range(N):
763
+ cur_valid = valid_flat[b, n] # (M,)
764
+ valid_spans = expanded_flat[b, n][cur_valid]
765
+
766
+ k = valid_spans.size(0)
767
+ if k > 0:
768
+ filtered_expanded[b, n, :k] = valid_spans
769
+
770
+ return filtered_expanded
771
+
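`expand_spans` generates candidate spans by shifting both boundaries by up to `r` positions and keeping only the shifts that stay inside the sequence. The per-span logic can be sketched without tensors as follows; `expand_span` is a hypothetical single-span helper, and `valid_length` stands for the attention-mask sum of the sample.

```python
def expand_span(start, end, r, valid_length):
    """Enumerate (start + ds, end + de) for |ds|, |de| <= r, keeping valid candidates."""
    candidates = []
    for ds in range(-r, r + 1):
        for de in range(-r, r + 1):
            s, e = start + ds, end + de
            # Same validity test as the tensor version: s > 0, e < valid_length, s <= e.
            if s > 0 and e < valid_length and s <= e:
                candidates.append((s, e))
    return candidates

# A span in the middle of a 10-token sample: 9 shifts, one of which (4, 3) is inverted.
middle = expand_span(3, 4, r=1, valid_length=10)

# A span near the start of a 3-token sample loses most of its candidates.
edge = expand_span(1, 1, r=1, valid_length=3)
```

The condition `s > 0` excludes position 0, presumably because it holds the leading special token, although the script does not state this. The unshifted span is always among the candidates when it is itself valid, which is what allows the span scorer to rank it against its neighbours.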
772
+ def get_span_reprs(hidden, spans):
773
+ """
774
+ Args:
775
+ hidden: (B, L, H)
776
+ spans: (B, N, M, 2)
777
+
778
+ Return:
779
+ span_repr: (B, N, M, 4H)
780
+
781
+ If N = 0:
782
+ return shape (B, 0, 0, 4H)
783
+ """
784
+
785
+ B, N, M, _ = spans.shape
786
+ H = hidden.size(-1)
787
+
788
+ if N == 0:
789
+ return torch.empty( (B, 0, 0, 4 * H), dtype=hidden.dtype, device=hidden.device)
790
+
791
+ batch_idx = torch.arange(B, device=hidden.device).view(B, 1, 1)
792
+ start_idx = spans[..., 0] # (B, N, M)
793
+ end_idx = spans[..., 1] # (B, N, M)
794
+
795
+ start_h = hidden[batch_idx, start_idx]
796
+ end_h = hidden[batch_idx, end_idx]
797
+ span_repr = torch.cat([start_h, end_h, end_h - start_h, end_h * start_h], dim=-1) # (B, N, M, 4H)
798
+ return span_repr
799
+
800
+
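`get_span_reprs` describes each candidate span by concatenating four vectors built from its boundary hidden states: start, end, their difference, and their element-wise product. The sketch below shows the same construction for a single span on plain lists; `span_features` is a hypothetical name and the tiny hidden states are made up for illustration.

```python
def span_features(hidden, start, end):
    """Concatenate [h_start, h_end, h_end - h_start, h_end * h_start] for one span."""
    h_s, h_e = hidden[start], hidden[end]
    diff = [e - s for s, e in zip(h_s, h_e)]
    prod = [e * s for s, e in zip(h_s, h_e)]
    return h_s + h_e + diff + prod

# Three tokens with hidden size H = 2 (illustrative values only).
hidden = [[1.0, 2.0], [3.0, 5.0], [0.5, -1.0]]
features = span_features(hidden, start=0, end=1)
```

The result has length `4 * H`, which is why `SpanScoreBranch` builds its scorer as `MLP(4*hidden_size, hidden_size, 1)`.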
801
+ class MLP(nn.Module):
802
+ def __init__(self, in_size, hid_size, out_size):
803
+ super().__init__()
804
+ self.mlp = nn.Sequential(
805
+ nn.Linear(in_size, hid_size),
806
+ nn.ReLU(),
807
+ nn.Linear(hid_size, out_size)
808
+ )
809
+
810
+ def forward(self, x):
811
+ return self.mlp(x)
812
+
813
+
814
+ class PointerBranch(nn.Module):
815
+ def __init__(self, backbone_model_name, num_labels):
816
+ super().__init__()
817
+ self.encoder = AutoModel.from_pretrained(backbone_model_name)
818
+ hidden_size = self.encoder.config.hidden_size
819
+
820
+ self.start_classifier = MLP(hidden_size, hidden_size, num_labels)
821
+ self.end_classifier = MLP(hidden_size, hidden_size, num_labels)
822
+
823
+ def encode(self, input_ids, attention_mask):
824
+ B, n_parts, L = input_ids.shape
825
+ input_ids = input_ids.view(-1, L)
826
+ attention_mask = attention_mask.view(-1, L)
827
+
828
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
829
+ hidden_states = outputs.last_hidden_state # B * n_parts, L, H
830
+
831
+ hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1) # B, L, H
832
+ return hidden_states
833
+
834
+ def get_logits(self, hidden_states):
835
+ start_logits = self.start_classifier(hidden_states) # B, N, classes
836
+ end_logits = self.end_classifier(hidden_states) # B, N, classes
837
+ return start_logits, end_logits
838
+
839
+ def forward(self, input_ids, attention_mask):
840
+ hidden_states = self.encode(input_ids, attention_mask)
841
+ start_logits, end_logits = self.get_logits(hidden_states)
842
+ return start_logits, end_logits
843
+
844
+
845
+ class SpanScoreBranch(nn.Module):
846
+ def __init__(self, backbone_model_name, max_r):
847
+ super().__init__()
848
+ self.max_r = max_r
849
+
850
+ self.encoder = AutoModel.from_pretrained(backbone_model_name)
851
+ hidden_size = self.encoder.config.hidden_size
852
+
853
+ self.span_scorer = MLP(4*hidden_size, hidden_size, 1)
854
+
855
+ def encode(self, input_ids, attention_mask):
856
+ B, n_parts, L = input_ids.shape
857
+ input_ids = input_ids.view(-1, L)
858
+ attention_mask = attention_mask.view(-1, L)
859
+
860
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
861
+ hidden_states = outputs.last_hidden_state # B * n_parts, L, H
862
+
863
+ hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1) # B, L, H
864
+ attention_mask = attention_mask.view(B, n_parts, L).reshape(B, n_parts*L) # B, L
865
+ return hidden_states, attention_mask
866
+
867
+ def get_scores(self, expanded_span_reprs):
868
+ return self.span_scorer(expanded_span_reprs).squeeze(-1)
869
+
870
+ def forward(self, input_ids, attention_mask, spans, use_max_r=False):
871
+ hidden_states, attention_mask = self.encode(input_ids, attention_mask)
872
+
873
+ B, N, _ = spans.shape
874
+ if use_max_r:
875
+ r = torch.ones((B, N), device=spans.device, dtype=torch.long) * self.max_r
876
+ else:
877
+ r = torch.randint(1, self.max_r+1, (B, N), device=spans.device)
878
+ expanded_spans = expand_spans(spans, r, attention_mask) # (B, N, M, 2)
879
+
880
+ expanded_span_reprs = get_span_reprs(hidden_states, expanded_spans) # (B, N, M, 4H)
881
+ expanded_span_scores = self.get_scores(expanded_span_reprs) # (B, N, M)
882
+ return expanded_spans, expanded_span_scores
883
+
884
+
885
+ class IEModel(nn.Module):
886
+ def __init__(self, backbone_model_name, num_labels, max_r):
887
+ super().__init__()
888
+ self.max_r = max_r
889
+ self.pointer = PointerBranch(backbone_model_name, num_labels)
890
+ self.span_score = SpanScoreBranch(backbone_model_name, max_r)
891
+
892
+ def forward(self, input_ids, attention_mask, spans=None, use_max_r=False):
893
+ start_logits, end_logits = self.pointer(input_ids, attention_mask)
894
+
895
+ if spans is None:
896
+ spans, _ = extract_spans_and_labels(start_logits, end_logits) # (B, N, 2)
897
+ expanded_spans, expanded_span_scores = self.span_score(input_ids, attention_mask, spans, use_max_r)
898
+
899
+ return start_logits, end_logits, expanded_spans, expanded_span_scores
900
+
901
+ def test_model():
902
+ model = DataParallelProxy(IEModel(backbone_model_name, 7, 6)).to(device)
903
+ model.eval()
904
+ total_params = sum(p.numel() for p in model.parameters())
905
+ print(f"Total params: {total_params:,}")
906
+
907
+ vocab_size = 120
908
+
909
+ bz = 32
910
+ i = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
911
+ a = torch.ones(bz, 5, 10).to(device)
912
+ g = torch.ones(bz, 3, 2, dtype=torch.long).to(device)
913
+
914
+ with torch.no_grad():
915
+ r = model(i, a)
916
+
917
+ if isinstance(r, tuple):
+ print([x.shape if isinstance(x, torch.Tensor) else len(x) for x in r])
+ else:
+ print(r.shape)
921
+
922
+ test_model()
923
+
924
+ # %% [code]
925
+ def configure_optimizers(network, optim_params, scheduler_params):
926
+ try:
927
+ optim_params = copy.copy(optim_params)
928
+ scheduler_params = copy.copy(scheduler_params)
929
+
930
+ optim_name = optim_params.pop('name')
931
+ scheduler_name = scheduler_params.pop('name')
932
+
933
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
934
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
935
+
936
+ if optimizer_cls is None:
937
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
938
+
939
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
940
+
941
+ scheduler = None
942
+ if scheduler_params and scheduler_cls: # Only create the scheduler if parameters are provided
943
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
944
+
945
+ return optimizer, scheduler
946
+
947
+ except KeyError as e:
948
+ raise ValueError(f"Missing {e} in config!!")
949
+
950
+ def freeze(self, model):
951
+ model.eval()
952
+ for param in model.parameters():
953
+ param.requires_grad = False
954
+
955
+ def unfreeze(self, model):
956
+ model.train()
957
+ for param in model.parameters():
958
+ param.requires_grad = True
959
+
960
+ def reduce_batch_size(loader, ratio=0.5):
961
+ new_bs = max(1, int(loader.batch_size * ratio))
962
+
963
+ shuffle = isinstance(loader.sampler, RandomSampler)
964
+
965
+ new_loader = DataLoader(
966
+ dataset=loader.dataset,
967
+ batch_size=new_bs,
968
+ shuffle=shuffle,
969
+ sampler=None if shuffle else loader.sampler,
970
+ num_workers=loader.num_workers,
971
+ collate_fn=loader.collate_fn,
972
+ pin_memory=loader.pin_memory,
973
+ drop_last=loader.drop_last,
974
+ timeout=loader.timeout,
975
+ worker_init_fn=loader.worker_init_fn,
976
+ multiprocessing_context=loader.multiprocessing_context,
977
+ generator=loader.generator,
978
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
979
+ persistent_workers=loader.persistent_workers,
980
+ pin_memory_device=loader.pin_memory_device
981
+ )
982
+
983
+ return new_loader
984
+
985
+ def list_to_tuple(x):
986
+ if isinstance(x, (list, tuple)):
987
+ return tuple(list_to_tuple(i) for i in x)
988
+ return x
989
+
990
+ def fmt(x):
991
+ if isinstance(x, float):
992
+ return round(x, 5)
993
+ if isinstance(x, dict):
994
+ return {k: fmt(v) for k, v in x.items()}
995
+ if isinstance(x, list):
996
+ return [fmt(v) for v in x]
997
+ return x
998
+
999
+ class ModelEmaV3Proxy(ModelEmaV3):
1000
+ def __getattr__(self, name):
1001
+ try:
1002
+ return super().__getattr__(name)
1003
+ except AttributeError:
1004
+ return getattr(self.module, name)
1005
+
1006
+ def align(spans, gold_spans, gold_labels):
1007
+ """
1008
+ Args:
1009
+ spans: (B, N1, M, 2)
1010
+ gold_spans: (B, N2, 2)
1011
+ gold_labels: (B, N2)
1012
+
1013
+ Returns:
1014
+ labels: (B, N1, M)
1015
+
1016
+ matched -> gold_labels
1017
+ padding (0,0) -> -100
1018
+ unmatched -> 0
1019
+ """
1020
+
1021
+ device = spans.device
1022
+ B, N1, M, _ = spans.shape
1023
+ N2 = gold_spans.shape[1]
1024
+
1025
+ pred = spans.unsqueeze(3) # (B, N1, M, 1, 2)
1026
+ gold = gold_spans.unsqueeze(1).unsqueeze(1) # (B, 1, 1, N2, 2)
1027
+ matched = (pred == gold).all(dim=-1) # (B, N1, M, N2)
1028
+
1029
+ has_match = matched.any(dim=-1) # (B, N1, M)
1030
+ matched_idx = matched.float().argmax(dim=-1) # (B, N1, M)
1031
+
1032
+ expanded_gold_labels = gold_labels.unsqueeze(1).unsqueeze(1).expand(B, N1, M, N2)
1033
+ index = matched_idx.unsqueeze(-1)
1034
+ gathered_labels = torch.gather(expanded_gold_labels, dim=-1, index=index).squeeze(-1)
1035
+
1036
+ labels = torch.zeros((B, N1, M), dtype=gold_labels.dtype, device=device)
1037
+ labels = torch.where(has_match, gathered_labels, labels)
1038
+
1039
+ is_padding = (spans.sum(dim=-1) == 0)
1040
+ labels = torch.where(is_padding, torch.full_like(labels, -100), labels)
1041
+ return labels
1042
+
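Note on `align`: the labeling rule in its docstring (matched -> gold label, padding (0,0) -> -100, unmatched -> 0) can be restated without tensors. The sketch below is a pure-Python illustration for a single sample, not code from this commit; `align_one` is a hypothetical helper name introduced only for this note.

```python
def align_one(candidates, gold_spans, gold_labels):
    """Label every candidate span of one sample.

    candidates: list of N1 groups, each a list of M (start, end) pairs
    gold_spans: list of (start, end) pairs
    gold_labels: list of int labels, same length as gold_spans
    """
    gold = {}
    for span, label in zip(gold_spans, gold_labels):
        # Keep the first gold entry for a span, mirroring argmax over the match mask.
        gold.setdefault(tuple(span), label)

    labels = []
    for group in candidates:
        row = []
        for start, end in group:
            if start + end == 0:
                row.append(-100)                      # padding slot, ignored by the loss
            else:
                row.append(gold.get((start, end), 0))  # 0 means "no entity"
        labels.append(row)
    return labels


cands = [[(1, 2), (1, 3), (0, 0)], [(4, 4), (0, 0), (0, 0)]]
print(align_one(cands, [(1, 3), (5, 6)], [2, 4]))
# [[0, 2, -100], [0, -100, -100]]
```

Only an exact `(start, end)` match receives the gold label, so a candidate that overlaps a gold span without matching it is still treated as a negative.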
1043
+ def select_best_spans(spans, span_scores):
1044
+ """
1045
+ Args:
1046
+ spans: (B, N, M, 2)
1047
+ span_scores: (B, N, M)
1048
+
1049
+ Returns:
1050
+ best_spans: (B, N, 2)
1051
+ """
1052
+
1053
+ is_padding = (spans.sum(dim=-1) == 0) # (B, N, M)
1054
+ masked_scores = span_scores.masked_fill(is_padding, float("-inf"))
1055
+
1056
+ best_idx = masked_scores.argmax(dim=-1) # best index along the M dimension
1057
+ index = best_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, 2)
1058
+ best_spans = torch.gather(spans, dim=2, index=index).squeeze(2)
1059
+
1060
+ is_padding = is_padding.all(dim=-1) # (B, N)
1061
+ best_spans = torch.where(is_padding.unsqueeze(-1), torch.zeros_like(best_spans), best_spans)
1062
+ return best_spans
1063
+
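Note on `select_best_spans`: for one group of M expanded candidates, the selection amounts to "take the best-scoring non-padding candidate, or (0, 0) if every slot is padding". A minimal pure-Python sketch of that rule follows; `select_best_span` is a hypothetical name, and the tensor version in the diff does the same thing for all (B, N) groups at once.

```python
def select_best_span(candidates, scores):
    """Return the highest-scoring non-padding (start, end); (0, 0) if all are padding."""
    best_span, best_score = (0, 0), float("-inf")
    for span, score in zip(candidates, scores):
        if sum(span) == 0:
            continue  # padding slot; the tensor version masks its score to -inf
        if score > best_score:
            best_span, best_score = span, score
    return best_span


print(select_best_span([(2, 3), (1, 3), (0, 0)], [0.1, 0.9, 5.0]))  # (1, 3)
print(select_best_span([(0, 0), (0, 0)], [1.0, 2.0]))               # (0, 0)
```

The high score on the padding slot in the first call is ignored, which is the point of the `masked_fill` step.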
1064
+ def extract_entities(input_ids, spans, labels, id2label):
1065
+ """
1066
+ Args:
1067
+ input_ids: (B, L)
1068
+ spans: (B, N, 2)
1069
+ labels: (B, N)
1070
+
1071
+ Returns:
1072
+ List[(bidx, (tuple(token_ids), label_name))]
1073
+ """
1074
+
1075
+ B, N, _ = spans.shape
1076
+
1077
+ results = []
1078
+ for bidx in range(B):
1079
+ for n in range(N):
1080
+ lb = labels[bidx, n].item()
1081
+ if lb <= 0:
1082
+ continue
1083
+
1084
+ s, e = spans[bidx, n].tolist()
1085
+ if s == 0 and e == 0:
1086
+ continue
1087
+
1088
+ token_ids = tuple(input_ids[bidx, s:e + 1].tolist())
1089
+ results.append((bidx, (token_ids, id2label[lb])))
1090
+
1091
+ return results
1092
+
1093
+ class Trainer:
1094
+ def __init__(
1095
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
1096
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
1097
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
1098
+ ):
1099
+ self.ema_net = None
1100
+
1101
+ self.training_time = self._time_str_to_seconds(training_time)
1102
+ self.mode = eval_mode
1103
+ self.topk = topk
1104
+ self.device = device
1105
+ self.logging = logging if logging < epochs else 1
1106
+ self.logging_file = logging_file
1107
+ self.checkpoints_dir = checkpoints_dir
1108
+ self.early_stopping = early_stopping
1109
+ self.eval_from_ratio = eval_from_ratio
1110
+ self.eval_every = eval_every
1111
+ self.save_name = save_name
1112
+ self.save_best = save_best
1113
+ self.save_last = save_last
1114
+ self.return_best = return_best
1115
+ self.return_last = return_last
1116
+ self.max_grad_norm = max_grad_norm
1117
+ self.schedule_in_step = schedule_in_step
1118
+ self.use_ema = use_ema
1119
+ self.ema_from_ratio = ema_from_ratio
1120
+ self.ema_decay = ema_decay
1121
+
1122
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
1123
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
1124
+
1125
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
1126
+ if eval_fn is None:
1127
+ if self.mode == "max":
1128
+ eval_fn = lambda *x: -loss_fn(*x)
1129
+ else:
1130
+ eval_fn = lambda *x: loss_fn(*x)
1131
+
1132
+ if torch.cuda.device_count() > 1:
1133
+ network = DataParallelProxy(network)
1134
+ network = network.to(self.device)
1135
+
1136
+ if not start_training_time:
1137
+ start_training_time = time.time()
1138
+
1139
+ start_ema = int(epochs * self.ema_from_ratio)
1140
+ start_eval = int(epochs * self.eval_from_ratio)
1141
+
1142
+ if val_loader is None:
1143
+ print(f'[Trainer CallBack] 📢 No validation set: evaluation and early stopping are disabled!')
1144
+ else:
1145
+ model_to_use_str = 'the EMA model' if self.use_ema else 'the original model'
+ start_model_update_str = f'EMA updates start at epoch {start_epoch + start_ema}!' if self.use_ema else ''
+ print(f'[Trainer CallBack] 📢 Evaluating with {model_to_use_str} from epoch {start_epoch + start_eval}!', start_model_update_str)
1148
+
1149
+ training_log = {}
1150
+ for epoch in range(start_epoch, epochs+start_epoch):
1151
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
1152
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1153
+
1154
+ try:
1155
+ teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
1156
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1157
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1158
+ logging_dict.update(train_loss_epoch_dict)
1159
+
1160
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
1161
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1162
+
1163
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
1164
+ update = self._update_best_network(eval_net, val_score, epoch)
1165
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
1166
+ logging_dict.update(val_score_dict)
1167
+ if not self.schedule_in_step and scheduler:
1168
+ scheduler.step()
1169
+
1170
+ except RuntimeError as e:
1171
+ if "out of memory" in str(e).lower():
1172
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
1173
+ torch.cuda.empty_cache()
1174
+ gc.collect()
1175
+ if torch.cuda.is_available():
1176
+ torch.cuda.synchronize()
1177
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
1178
+
1179
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
1180
+ if val_loader is not None:
1181
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
1182
+
1183
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
1184
+ else:
1185
+ raise
1186
+
1187
+ training_log[epoch] = logging_dict
1188
+ if self.is_early_stopping(epoch):
1189
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Overfitting detected! Stopping training...')
1190
+ break
1191
+ if self.logging:
1192
+ if epoch % self.logging == 0:
1193
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
1194
+ else:
1195
+ print(f'{epoch}...', end=' ')
1196
+
1197
+ if self._at_time_limit(start_training_time):
1198
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: The training time limit is {self.training_time} seconds; time ran out at epoch {epoch}/{epochs}')
1199
+ break
1200
+
1201
+ if self.logging_file:
1202
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
1203
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
1204
+ f.write(json.dumps(training_log))
1205
+
1206
+ if self.use_ema and self.ema_net is not None:
1207
+ self._save_state_dict(self.ema_net.module)
1208
+ else:
1209
+ self._save_state_dict(network)
1210
+ print(f'[Trainer CallBack] 📢 Training finished.\n')
1211
+
1212
+ best_model, last_model = None, None
1213
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1214
+ if self.return_best :
1215
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
1216
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
1217
+ if self.return_last:
1218
+ last_model = eval_net.state_dict()
1219
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
1220
+
1221
+ del network
1222
+ torch.cuda.empty_cache()
1223
+ gc.collect()
1224
+ return training_log, best_model, last_model
1225
+
1226
+ def _time_str_to_seconds(self, time_str):
1227
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
1228
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
1229
+
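Note on `_time_str_to_seconds`: the `training_time` argument is a four-field string, which the parsing code implies is `DD:HH:MM:SS`. The standalone sketch below copies the same arithmetic into a free function (the name `time_str_to_seconds` is used here only for illustration) to show what the default budget works out to.

```python
def time_str_to_seconds(time_str):
    # Exactly four colon-separated integer fields: days, hours, minutes, seconds.
    days, hours, minutes, seconds = map(int, time_str.split(":"))
    return days * 86400 + hours * 3600 + minutes * 60 + seconds


budget = time_str_to_seconds("00:11:30:00")
print(budget)  # 41400 seconds, i.e. 11 hours 30 minutes
```

A string with fewer or more than four fields raises `ValueError` during tuple unpacking, so a value such as `"11:30:00"` is rejected rather than silently misread.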
1230
+ def _update_best_network(self, network, val_score, epoch):
1231
+ topk = max(1, self.topk)
1232
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
1233
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
1234
+ if val_score in [x[0] for x in self.best_stage]:
1235
+ return True
1236
+ return False
1237
+
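Note on `_update_best_network`: the top-k bookkeeping is easier to follow with plain values in place of state dicts. The following is a simplified sketch under that assumption; `update_best` is a hypothetical free function, and strings stand in for the checkpoints.

```python
def update_best(best_stage, val_score, epoch, state, mode="max", topk=1):
    """Append a candidate, keep the top-k by score, and report whether the score survived."""
    topk = max(1, topk)
    best_stage = best_stage + [[val_score, epoch, state]]
    best_stage = sorted(best_stage, key=lambda x: x[0], reverse=(mode == "max"))[:topk]
    kept = val_score in [x[0] for x in best_stage]
    return best_stage, kept


stage = [[float("-inf"), None, None]]  # same sentinel as Trainer.__init__ in "max" mode
flags = []
for epoch, score in enumerate([0.5, 0.7, 0.6, 0.4], start=1):
    stage, kept = update_best(stage, score, epoch, f"ckpt{epoch}", topk=2)
    flags.append(kept)

print(flags)                   # [True, True, True, False]
print([s[0] for s in stage])   # [0.7, 0.6]
```

Because membership is tested on the score alone, a new score that ties with a retained entry is also reported as kept, even if the new entry itself was the one truncated.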
1238
+ def is_early_stopping(self, epoch):
1239
+ if self.best_stage[0][1] is None:
1240
+ return False
1241
+ if not self.early_stopping:
1242
+ return False
1243
+ return epoch - self.best_stage[0][1] >= self.early_stopping
1244
+
1245
+ def _at_time_limit(self, start_training_time):
1246
+ return time.time() - start_training_time >= self.training_time
1247
+
1248
+ def _save_state_dict(self, network):
1249
+ if self.topk <= 0:
1250
+ return
1251
+
1252
+ if self.save_best:
1253
+ for r in range(self.topk):
1254
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
1255
+
1256
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
1257
+ if state_dict is None:
1258
+ continue
1259
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
1260
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
1261
+ if self.save_last:
1262
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
1263
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
1264
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
1265
+
1266
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate):
1267
+ network.train()
1268
+ total_loss = 0
1269
+ total_loss_dict = {}
1270
+ for batch_idx, batch in enumerate(train_loader):
1271
+ optimizer.zero_grad()
1272
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
1273
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate)
1274
+
1275
+ for k, v in loss_dict.items():
1276
+ t = total_loss_dict.get(k, 0)
1277
+ total_loss_dict[k] = t + v
1278
+ self.grad_scaler.scale(loss).backward()
1279
+ self.grad_scaler.unscale_(optimizer)
1280
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
1281
+ # print(grad_norm) # Uncomment this line to see what value to choose for max_grad_norm...
1282
+ self.grad_scaler.step(optimizer)
1283
+ self.grad_scaler.update()
1284
+ if self.schedule_in_step and scheduler:
1285
+ scheduler.step()
1286
+ if self.use_ema and self.ema_net is not None:
1287
+ self.ema_net.update(network)
1288
+ total_loss += loss.detach()
1289
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
1290
+
1291
+ def _eval_epoch(self, network, val_loader, eval_fn, id2label):
1292
+ network.eval()
1293
+ total_score = 0.0
1294
+ total_score_dict = {}
1295
+ object_lists = None # initialized lazily below
1296
+
1297
+ with torch.no_grad():
1298
+ for batch_idx, batch in enumerate(val_loader):
1299
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
1300
+ total_score += score
1301
+
1302
+ for k, v in score_dict.items():
1303
+ t = total_score_dict.get(k, 0)
1304
+ total_score_dict[k] = t + v
1305
+
1306
+ if objects:
1307
+ if object_lists is None:
1308
+ object_lists = [[] for _ in range(len(objects))]
1309
+
1310
+ for i, obj in enumerate(objects):
1311
+ object_lists[i].append(obj.detach())
1312
+
1313
+ if object_lists is not None:
1314
+ object_arrays = [
1315
+ torch.concat(obj_list, dim=0).cpu().numpy()
1316
+ for obj_list in object_lists
1317
+ ]
1318
+ else:
1319
+ object_arrays = []
1320
+
1321
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
1322
+
1323
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate):
1324
+ # Override _cal_loss to compute the loss
1325
+ input_ids = batch['input_ids'].to(self.device)
1326
+ attention_mask = batch['attention_mask'].to(self.device)
1327
+ gold_spans = batch['gold_spans'].to(self.device)
1328
+ gold_labels = batch['gold_labels'].to(self.device)
1329
+ start_labels = batch['start_labels'].to(self.device)
1330
+ end_labels = batch['end_labels'].to(self.device)
1331
+
1332
+ choice = random.random()
1333
+ if choice < teaching_rate:
1334
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask, gold_spans)
1335
+ else:
1336
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask)
1337
+
1338
+ labels = align(spans, gold_spans, gold_labels)
1339
+
1340
+ loss_dict = loss_fn(
1341
+ start_logits, start_labels,
1342
+ end_logits, end_labels,
1343
+ span_scores, labels
1344
+ )
1345
+ return loss_dict['total'], loss_dict
1346
+
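Note on `teaching_rate`: `fit` computes it as `cos(pi/2 * epoch / epochs)`, and `_cal_loss` then feeds gold spans to the span scorer with that probability, falling back to the pointer branch's own predictions otherwise. The sketch below only reproduces the schedule and the coin flip; `teaching_rate` and `use_gold_spans` are illustrative helper names, not functions from the commit.

```python
import math
import random


def teaching_rate(epoch, epochs):
    # Close to 1 in the first epochs, decaying to roughly 0 at epoch == epochs.
    return math.cos(math.pi / 2 * epoch / epochs)


def use_gold_spans(epoch, epochs, rng=random):
    # True -> score expansions of gold spans; False -> score expansions of predicted spans.
    return rng.random() < teaching_rate(epoch, epochs)


rates = [round(teaching_rate(e, 10), 3) for e in (1, 5, 10)]
print(rates)  # [0.988, 0.707, 0.0]
```

One consequence of the formula is that if training resumes with `start_epoch > 1`, `epoch` can exceed `epochs`, the cosine goes negative, and the gold-span path is then never taken.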
1347
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
1348
+ # Override _cal_val_score to compute the validation score; the extra list returned alongside is for passing back targets or predictions (if needed)
1349
+ input_ids = batch['input_ids'].to(self.device)
1350
+ attention_mask = batch['attention_mask'].to(self.device)
1351
+ gold_entities = batch['gold_entities']
1352
+
1353
+ B, _, _ = input_ids.shape
1354
+
1355
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask, use_max_r=True)
1356
+ _, labels = extract_spans_and_labels(start_logits, end_logits)
1357
+ spans = select_best_spans(spans, span_scores)
1358
+
1359
+ pred_ids = extract_entities(input_ids.reshape(B, -1), spans, labels, id2label)
1360
+ pred_ids = list_to_tuple(pred_ids)
1361
+
1362
+ gold_ids = list_to_tuple(gold_entities)
1363
+
1364
+ score_dict = eval_fn(pred_ids, gold_ids)
1365
+ return score_dict['f1'], score_dict, []
1366
+
1367
+ # %% [code]
1368
+ class PhoBERTSpanAligner:
1369
+ def __init__(self, tokenizer, max_len):
1370
+ self.tokenizer = tokenizer
1371
+ self.max_len = max_len
1372
+
1373
+ # ===== 1. Extract discontinuous spans =====
1374
+ def extract_spans(self, sample):
1375
+ entity_spans = []
1376
+
1377
+ for event in sample["entities"]:
1378
+ entity_type = event["label"]
1379
+ spans = [tuple(event["offset"])]
1380
+ entity_spans.append({
1381
+ "spans": spans,
1382
+ "label": entity_type
1383
+ })
1384
+
1385
+ return entity_spans
1386
+
1387
+ # ===== 2. Word offsets =====
1388
+ def build_word_offsets(self, text, words):
1389
+ offsets = []
1390
+ pointer = 0
1391
+
1392
+ for word in words:
1393
+ start = text.find(word, pointer)
1394
+ end = start + len(word)
1395
+ offsets.append((start, end))
1396
+ pointer = end
1397
+
1398
+ return offsets
1399
+
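Note on `build_word_offsets`: the function walks the raw text left to right and locates each word with `str.find`. The standalone copy below shows the normal case and, as an assumption about possible inputs rather than something observed in this commit, what happens when a token does not occur verbatim in the text (for example if a segmenter joins syllables with an underscore).

```python
def build_word_offsets(text, words):
    offsets, pointer = [], 0
    for word in words:
        start = text.find(word, pointer)  # -1 when the word is not found
        end = start + len(word)
        offsets.append((start, end))
        pointer = end
    return offsets


text = "Ha Noi is big"
ok = build_word_offsets(text, ["Ha", "Noi", "is", "big"])
print(ok)   # [(0, 2), (3, 6), (7, 9), (10, 13)]

# Hypothetical mismatch: the token "Ha_Noi" is not a substring of the text.
bad = build_word_offsets(text, ["Ha_Noi", "is"])
print(bad)  # [(-1, 5), (7, 9)]
```

In the second call the first offset starts at -1, so a character span beginning at 0 would not map to that word in `char_span_to_word_span`; checking for `start == -1` would be one way to surface such cases.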
1400
+ # ===== 3. Char → word =====
1401
+ def char_span_to_word_span(self, word_offsets, start, end):
1402
+ start_word = None
1403
+ end_word = None
1404
+
1405
+ for i, (w_start, w_end) in enumerate(word_offsets):
1406
+ if w_start <= start < w_end:
1407
+ start_word = i
1408
+ if w_start < end <= w_end:
1409
+ end_word = i
1410
+
1411
+ return start_word, end_word
1412
+
1413
+ # ===== 4. Word → subword =====
1414
+ def word_to_subword_map(self, words):
1415
+ mapping = []
1416
+ subword_index = 1 # <s>
1417
+
1418
+ for word in words:
1419
+ sub_tokens = self.tokenizer.tokenize(word)
1420
+ start = subword_index
1421
+ end = subword_index + len(sub_tokens) - 1
1422
+ mapping.append((start, end))
1423
+ subword_index += len(sub_tokens)
1424
+
1425
+ return mapping
1426
+
1427
+ # ===== 5. Span → subword =====
1428
+ def span_to_subword(self, word_offsets, word_subword_map, spans):
1429
+ sub_spans = []
1430
+
1431
+ for span_start, span_end in spans:
1432
+ w_start, w_end = self.char_span_to_word_span(
1433
+ word_offsets, span_start, span_end
1434
+ )
1435
+ if w_start is None or w_end is None:
1436
+ continue
1437
+
1438
+ sub_start = word_subword_map[w_start][0]
1439
+ sub_end = word_subword_map[w_end][1]
1440
+ sub_spans.append((sub_start, sub_end))
1441
+
1442
+ return sub_spans
1443
+
1444
+ def extract_valid_spans(self, sub_spans):
1445
+ valid_spans = []
1446
+ for s, e in sub_spans:
1447
+ if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
1448
+ continue
1449
+ valid_spans.append((s, e))
1450
+ return valid_spans
1451
+
1452
+ def encode(self, sample):
1453
+ text = sample["text"]
1454
+ entities = self.extract_spans(sample)
1455
+
1456
+ # ===== 1. Word tokenize =====
1457
+ words = word_tokenize(text)
1458
+ sentence = " ".join(words)
1459
+
1460
+ # ===== 2. Mapping =====
1461
+ word_offsets = self.build_word_offsets(text, words)
1462
+ word_subword_map = self.word_to_subword_map(words)
1463
+
1464
+ # ===== 3. Tokenize FULL =====
1465
+ encoding = self.tokenizer(
1466
+ sentence,
1467
+ max_length=self.max_len,
1468
+ truncation=True,
1469
+ padding="max_length",
1470
+ return_tensors="pt"
1471
+ )
1472
+ input_ids = encoding["input_ids"][0]
1473
+ attention_mask = encoding["attention_mask"][0]
1474
+
1475
+ # ===== 4. Convert spans =====
1476
+ entities_gold_spans = []
1477
+
1478
+ for ent in entities:
1479
+ label = ent["label"]
1480
+
1481
+ sub_spans = self.span_to_subword(
1482
+ word_offsets,
1483
+ word_subword_map,
1484
+ ent["spans"]
1485
+ )
1486
+ valid_spans = self.extract_valid_spans(sub_spans)
1487
+ if len(valid_spans) == 0:
1488
+ continue
1489
+ entities_gold_spans.append((tuple(valid_spans), label))
1490
+
1491
+ return {
1492
+ "input_ids": input_ids,
1493
+ "attention_mask": attention_mask,
1494
+ "entities_gold_spans": entities_gold_spans,
1495
+ }
1496
+
1497
+ def generate_spans(attention_mask, max_span_len):
1498
+ seq_len = attention_mask.sum().item() - 2
1499
+ spans = []
1500
+ for i in range(1, seq_len+1):
1501
+ for j in range(i, min(i+max_span_len, seq_len+1)):
1502
+ spans.append((i, j))
1503
+ return spans
1504
+
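Note on `generate_spans`: it enumerates every inclusive `(start, end)` pair of at most `max_span_len` tokens over positions `1..seq_len`, where `seq_len` is the attention-mask sum minus 2 (presumably to skip the two special tokens). The sketch below takes `seq_len` directly as an integer, which is a simplification for illustration; `enumerate_spans` is a hypothetical name.

```python
def enumerate_spans(seq_len, max_span_len):
    """All inclusive (start, end) spans of length <= max_span_len over positions 1..seq_len."""
    spans = []
    for i in range(1, seq_len + 1):
        for j in range(i, min(i + max_span_len, seq_len + 1)):
            spans.append((i, j))
    return spans


spans = enumerate_spans(seq_len=3, max_span_len=2)
print(spans)  # [(1, 1), (1, 2), (2, 2), (2, 3), (3, 3)]
```

The count grows roughly as `seq_len * max_span_len`, which is why the dataset caps the span length (at 10 in `__getitem__`).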
1505
+ def match_gold_labels(
1506
+ gold_spans, # (N, 2)
1507
+ gold_labels, # (N,)
1508
+ pred_spans, # (M, 2)
1509
+ default_label=-100
1510
+ ):
1511
+ """
1512
+ Return:
1513
+ pred_labels: (M,)
1514
+ """
1515
+
1516
+ pred_labels = torch.full(
1517
+ (pred_spans.size(0),),
1518
+ default_label,
1519
+ dtype=gold_labels.dtype,
1520
+ device=gold_labels.device
1521
+ )
1522
+ if gold_spans.size(0) == 0:
1523
+ return pred_labels
1524
+
1525
+ # (M, N)
1526
+ matched = (pred_spans[:, None, :] == gold_spans[None, :, :]).all(dim=-1)
1527
+ has_match = matched.any(dim=1)
1528
+
1529
+ # take the index of the first matching gold span
1530
+ gold_idx = matched.float().argmax(dim=1)
1531
+
1532
+ pred_labels[has_match] = gold_labels[gold_idx[has_match]]
1533
+
1534
+ return pred_labels
1535
+
1536
+ class KLTNDataset(Dataset):
1537
+ def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
1538
+ super().__init__()
1539
+ self.tokenizer = tokenizer
1540
+ self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)
1541
+ self.all_data = all_data
1542
+ self.using_idxes = using_idxes
1543
+ self.label2id = label2id
1544
+ self.max_len = max_len
1545
+ self.max_n_parts = max_n_parts
1546
+
1547
+ def __len__(self):
1548
+ return len(self.using_idxes)
1549
+
1550
+ def __getitem__(self, idx):
1551
+ ridx = self.using_idxes[idx]
1552
+ sample = self.all_data[ridx]
1553
+ result = self.aligner.encode(sample)
1554
+
1555
+ input_ids = result["input_ids"].squeeze(0)
1556
+ attention_mask = result["attention_mask"].squeeze(0)
1557
+ entities_gold_spans = result["entities_gold_spans"]
1558
+
1559
+ all_spans = torch.tensor(generate_spans(attention_mask, 10))
1560
+ gold_spans = torch.tensor([spans[0] for spans, _ in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, 2, dtype=torch.long)
1561
+ gold_labels = torch.tensor([self.label2id[label] for _, label in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, dtype=torch.long)
1562
+ all_labels = match_gold_labels(
1563
+ gold_spans, # (N, 2)
1564
+ gold_labels, # (N,)
1565
+ all_spans, # (M, 2)
1566
+ default_label=0
1567
+ )
1568
+
1569
+ # Get label
1570
+ gold_entities = []
1571
+ start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1572
+ end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1573
+ for spans, label in entities_gold_spans:
1574
+ s, e = spans[0]
1575
+
1576
+ start_labels[s] = self.label2id[f'{label}']
1577
+ end_labels[e] = self.label2id[f'{label}']
1578
+
1579
+ gold_entities.append((tuple(input_ids[s:e+1].tolist()), label))
1580
+
1581
+ input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
1582
+ attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)
1583
+
1584
+ n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
1585
+ input_ids = input_ids[:n_valid_parts]
1586
+ attention_mask = attention_mask[:n_valid_parts]
1587
+ start_labels = start_labels[:n_valid_parts*self.max_len]
1588
+ end_labels = end_labels[:n_valid_parts*self.max_len]
1589
+
1590
+ return {
1591
+ "input_ids": input_ids,
1592
+ "attention_mask": attention_mask,
1593
+
1594
+ "all_spans": all_spans,
1595
+ "all_labels": all_labels,
1596
+
1597
+ "gold_spans": gold_spans,
1598
+ "gold_labels": gold_labels,
1599
+
1600
+ "start_labels": start_labels,
1601
+ "end_labels": end_labels,
1602
+
1603
+ "gold_entities": gold_entities,
1604
+ }
1605
+
1606
+ def _pad_batch(tensor_list, pad_value=0):
1607
+ """
1608
+ tensor_list: list of tensors
1609
+ each tensor has shape: (Nk, n_parts_i, max_len_i)
1610
+
1611
+ return:
1612
+ padded tensor shape: (B, max_Nk, max_n_parts, max_len)
1613
+ """
1614
+
1615
+ # take the maxima over the whole batch
1616
+ max_Nk = max(t.size(0) for t in tensor_list)
1617
+ max_n_parts = max(t.size(1) for t in tensor_list)
1618
+ max_len = max(t.size(2) for t in tensor_list)
1619
+
1620
+ padded = []
1621
+
1622
+ for t in tensor_list:
1623
+ Nk, n_parts_i, max_len_i = t.shape
1624
+
1625
+ # pad the n_parts and max_len dimensions first
1626
+ if n_parts_i < max_n_parts or max_len_i < max_len:
1627
+ new_t = t.new_full(
1628
+ (Nk, max_n_parts, max_len),
1629
+ pad_value
1630
+ )
1631
+ new_t[:, :n_parts_i, :max_len_i] = t
1632
+ t = new_t
1633
+
1634
+ # pad the Nk dimension
1635
+ if Nk < max_Nk:
1636
+ pad_tensor = t.new_full(
1637
+ (max_Nk - Nk, max_n_parts, max_len),
1638
+ pad_value
1639
+ )
1640
+ t = torch.cat([t, pad_tensor], dim=0)
1641
+
1642
+ padded.append(t)
1643
+
1644
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
1645
+
1646
+ def collate_fn(batch):
1647
+ gold_entities = []
1648
+ for bidx, b in enumerate(batch):
1649
+ for entity in b['gold_entities']:
1650
+ gold_entities.append([bidx, entity])
1651
+
1652
+ input_ids = [b["input_ids"].unsqueeze(-1) for b in batch]
1653
+ attention_mask = [b["attention_mask"].unsqueeze(-1) for b in batch]
1654
+
1655
+ all_spans = [b["all_spans"].unsqueeze(-1) for b in batch]
1656
+ all_labels = [b["all_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1657
+
1658
+ gold_spans = [b["gold_spans"].unsqueeze(-1) for b in batch]
1659
+ gold_labels = [b["gold_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1660
+
1661
+ start_labels = [b["start_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1662
+ end_labels = [b["end_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1663
+
1664
+ # pad along Nk
1665
+ input_ids = _pad_batch(input_ids, pad_value=0).squeeze(-1)
1666
+ attention_mask = _pad_batch(attention_mask, pad_value=0).squeeze(-1)
1667
+
1668
+ all_spans = _pad_batch(all_spans, pad_value=0).squeeze(-1)
1669
+ all_labels = _pad_batch(all_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1670
+
1671
+ gold_spans = _pad_batch(gold_spans, pad_value=0).squeeze(-1)
1672
+ gold_labels = _pad_batch(gold_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1673
+
1674
+ start_labels = _pad_batch(start_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1675
+ end_labels = _pad_batch(end_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1676
+
1677
+ return {
1678
+ "input_ids": input_ids,
1679
+ "attention_mask": attention_mask,
1680
+
1681
+ "all_spans": all_spans,
1682
+ "all_labels": all_labels,
1683
+
1684
+ "gold_spans": gold_spans,
1685
+ "gold_labels": gold_labels,
1686
+
1687
+ "start_labels": start_labels,
1688
+ "end_labels": end_labels,
1689
+
1690
+ "gold_entities": gold_entities,
1691
+ }
1692
+
1693
+ # %% [code]
1694
+ def shift_bidx(spans, batch_idx):
1695
+ shifted = []
1696
+ for bidx, ent in spans:
1697
+ new_bidx = bidx + batch_idx * batch_size
1698
+ shifted.append((new_bidx, ent))
1699
+ return shifted
1700
+
1701
+ def refactor_entities(entities, save_dict):
1702
+ i, c = [], []
1703
+ for bidx, (ids, lb) in entities:
1704
+ if (bidx, ids) not in i:
1705
+ i.append((bidx, ids))
1706
+
1707
+ if (bidx, (ids, lb)) not in c:
1708
+ c.append((bidx, (ids, lb)))
1709
+
1710
+ save_dict['Ent-I'].extend(i)
1711
+ save_dict['Ent-C'].extend(c)
1712
+
1713
+ def check_spans(
1714
+ all_spans, # (N, 2)
1715
+ all_labels, # (N,)
1716
+ ensemble_start_logits, # (L, C)
1717
+ ensemble_end_logits, # (L, C)
1718
+ expand_dict,
1719
+ max_expand=10
1720
+ ):
1721
+ """
1722
+ Check minimum expansion radius required to achieve recall = 1.
1723
+
1724
+ Args:
1725
+ all_spans: gold spans
1726
+ all_labels: gold labels
1727
+ ensemble_start_logits: (L, C)
1728
+ ensemble_end_logits: (L, C)
1729
+ expand_dict:
1730
+ dict like:
1731
+ {
1732
+ 0: count,
1733
+ 1: count,
1734
+ ...
1735
+ }
1736
+
1737
+ Return:
1738
+ matched_expand_k
1739
+ """
1740
+
1741
+ gold_set = set()
1742
+
1743
+ valid = all_labels > 0
1744
+
1745
+ valid_idxes = valid.nonzero(as_tuple=False).squeeze(-1)
1746
+
1747
+ for idx in valid_idxes:
1748
+
1749
+ s = int(all_spans[idx, 0])
1750
+ e = int(all_spans[idx, 1])
1751
+
1752
+ gold_set.add((s, e))
1753
+
1754
+ # no gold
1755
+ if len(gold_set) == 0:
1756
+ expand_dict[0] += 1
1757
+ return 0
1758
+
1759
+ start_labels = ensemble_start_logits.argmax(dim=-1) # (L,)
1760
+ end_labels = ensemble_end_logits.argmax(dim=-1) # (L,)
1761
+
1762
+ L = start_labels.shape[0]
1763
+
1764
+ base_pred = []
1765
+
1766
+ used_start = set()
1767
+ used_end = set()
1768
+
1769
+ for s in range(L):
1770
+
1771
+ s_label = start_labels[s].item()
1772
+
1773
+ if s_label == 0:
1774
+ continue
1775
+
1776
+ if s in used_start:
1777
+ continue
1778
+
1779
+ nearest_e = None
1780
+
1781
+ for e in range(s, L):
1782
+
1783
+ if e in used_end:
1784
+ continue
1785
+
1786
+ e_label = end_labels[e].item()
1787
+
1788
+ if e_label == s_label:
1789
+ nearest_e = e
1790
+ break
1791
+
1792
+ if nearest_e is None:
1793
+ continue
1794
+
1795
+ used_start.add(s)
1796
+ used_end.add(nearest_e)
1797
+
1798
+ base_pred.append((s, nearest_e))
1799
+
1800
+ for k in range(max_expand + 1):
1801
+
1802
+ pred_set = set()
1803
+
1804
+ for s, e in base_pred:
1805
+
1806
+ for ds in range(-k, k + 1):
1807
+ for de in range(-k, k + 1):
1808
+
1809
+ ns = s + ds
1810
+ ne = e + de
1811
+
1812
+ if ns > 0 and ne > 0 and ns <= ne:
1813
+ pred_set.add((ns, ne))
1814
+
1815
+ # recall = 1
1816
+ if gold_set.issubset(pred_set):
1817
+
1818
+ expand_dict[k] += 1
1819
+ return k
1820
+
1821
+ expand_dict[-1] = expand_dict.get(-1, 0) + 1
1822
+
1823
+ return -1
1824
+
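Note on `check_spans`: after pairing start and end predictions into base spans, the function searches for the smallest radius `k` such that shifting each boundary by up to `k` positions covers every gold span. The sketch below isolates that search in pure Python; `min_expand_radius` is a hypothetical name, and the `expand_dict` counting from the diff is left out.

```python
def min_expand_radius(base_pred, gold_set, max_expand=10):
    """Smallest k for which the expanded predictions contain every gold span, else -1."""
    for k in range(max_expand + 1):
        pred_set = set()
        for s, e in base_pred:
            for ds in range(-k, k + 1):
                for de in range(-k, k + 1):
                    ns, ne = s + ds, e + de
                    if ns > 0 and ne > 0 and ns <= ne:  # drop position 0 and inverted spans
                        pred_set.add((ns, ne))
        if gold_set.issubset(pred_set):
            return k
    return -1


print(min_expand_radius([(3, 5)], {(3, 5)}))                          # 0
print(min_expand_radius([(3, 5)], {(2, 6)}))                          # 1
print(min_expand_radius([(3, 5)], {(3, 5), (20, 22)}, max_expand=3))  # -1
```

A result of -1 means at least one gold span has no base prediction within `max_expand` positions, so no amount of reranking over the expanded candidates could recover it.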
1825
+ def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
1826
+ if torch.cuda.device_count() > 1:
1827
+ network = DataParallelProxy(network)
1828
+ network = network.to(device)
1829
+ network.eval()
1830
+
1831
+ eval_types = ['Ent-I', 'Ent-C']
1832
+
1833
+ all_pred = {eval_type: [] for eval_type in eval_types}
1834
+ all_gold = {eval_type: [] for eval_type in eval_types}
1835
+
1836
+ list_input_ids = []
1837
+ expand_dict = {i: 0 for i in range(13)}
1838
+
1839
+ with torch.no_grad():
1840
+ for batch_idx, batch in enumerate(test_loader):
1841
+ input_ids = batch['input_ids'].to(device)
1842
+ attention_mask = batch['attention_mask'].to(device)
1843
+ all_spans = batch['all_spans'].to(device)
1844
+ all_labels = batch['all_labels'].to(device)
1845
+ gold_entities = batch['gold_entities']
1846
+
1847
+ B, _, _ = input_ids.shape
1848
+ list_input_ids.extend(input_ids.reshape(B, -1).tolist())
1849
+
1850
+ list_start_logits = []
1851
+ list_end_logits = []
1852
+ list_scores = []
1853
+ for sd in state_dicts:
1854
+ if torch.cuda.device_count() > 1:
1855
+ network.module.load_state_dict(sd)
1856
+ else:
1857
+ network.load_state_dict(sd)
1858
+
1859
+ start_logits, end_logits = network.pointer(input_ids, attention_mask)
1860
+ list_start_logits.append(start_logits)
1861
+ list_end_logits.append(end_logits)
1862
+
1863
+ ensemble_start_logits = torch.stack(list_start_logits, dim=0).mean(dim=0)
1864
+ ensemble_end_logits = torch.stack(list_end_logits, dim=0).mean(dim=0)
1865
+ spans, pred_labels = extract_spans_and_labels(ensemble_start_logits, ensemble_end_logits)
1866
+
1867
+ for sd in state_dicts:
1868
+ if torch.cuda.device_count() > 1:
1869
+ network.module.load_state_dict(sd)
1870
+ else:
1871
+ network.load_state_dict(sd)
1872
+
1873
+ expanded_spans, expanded_span_scores = network.span_score(input_ids, attention_mask, spans, use_max_r=True)
1874
+ list_scores.append(expanded_span_scores)
1875
+
1876
+ ensemble_scores = torch.stack(list_scores, dim=0).mean(dim=0)
1877
+ pred_spans = select_best_spans(expanded_spans, ensemble_scores)
1878
+
1879
+ pred_entities = extract_entities(input_ids.reshape(B, -1), pred_spans, pred_labels, id2label)
1880
+ pred_entities = shift_bidx(pred_entities, batch_idx)
1881
+ refactor_entities(pred_entities, all_pred)
1882
+
1883
+ gold_entities = shift_bidx(gold_entities, batch_idx)
1884
+ refactor_entities(gold_entities, all_gold)
1885
+
1886
+ for b in range(B):
1887
+ check_spans(
1888
+ all_spans[b], # (N, 2)
1889
+ all_labels[b], # (N,)
1890
+ ensemble_start_logits[b], # (L, C)
1891
+ ensemble_end_logits[b], # (L, C)
1892
+ expand_dict,
1893
+ max_expand=10
1894
+ )
1895
+
1896
+ # ===== GLOBAL EVAL =====
1897
+ final_score = {}
1898
+ for eval_type in eval_types:
1899
+ score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))
1900
+ final_score[eval_type] = score
1901
+
1902
+ analyze_result = analyzer.analyze(list_to_tuple(all_pred['Ent-I']), list_to_tuple(all_gold['Ent-I']))
1903
+
1904
+ # ===== PREDICT =====
1905
+ predictions = []
1906
+ for input_ids in list_input_ids:
1907
+ predictions.append([tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
1908
+ for bidx, (ids, lb) in all_pred['Ent-C']:
1909
+ predictions[bidx].append((tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True), lb))
1910
+
1911
+ return final_score, analyze_result, predictions, expand_dict
1912
+
1913
+ # %% [code]
1914
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
1915
+ data_train = json.load(f)
1916
+
1917
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
1918
+ data_test = json.load(f)
1919
+
1920
+ print('Train:', len(data_train))
1921
+ print('Test:', len(data_test))
1922
+
1923
+ # %% [code]
1924
+ entity_types = ['O'] + sorted(list(set([e['label'] for d in data_train + data_test for e in d['entities']])))
1925
+ # bio_entity_type = ['O'] + [f'{prefix}-{ent}' for ent in entity_types for prefix in ['B', 'I']]
1926
+ label2id = {l: i for i, l in enumerate(entity_types)}
1927
+ id2label = {i: l for l, i in label2id.items()}
1928
+
1929
+ # %% [code]
1930
+ zero_entities_idxes = []
1931
+ for idx, d in enumerate(data_train):
1932
+ if len(d['entities']) == 0:
1933
+ zero_entities_idxes.append(idx)
1934
+
1935
+ n_zero_entities_samples = len(zero_entities_idxes)
1936
+ n_has_entities_samples = len(data_train) - n_zero_entities_samples
1937
+
1938
+ random.seed(42)
1939
+ k = min(int(n_has_entities_samples * zero_entities_rate), len(zero_entities_idxes))
1940
+ sampled_zero_entities_idxes = set(random.sample(zero_entities_idxes, k))
1941
+
1942
+ new_data_train = []
1943
+ for idx, d in enumerate(data_train):
1944
+ if len(d['entities']) == 0:
1945
+ if idx in sampled_zero_entities_idxes:
1946
+ new_data_train.append(d)
1947
+ else:
1948
+ new_data_train.append(d)
1949
+ data_train = new_data_train
1950
+
1951
+ print('Train:', len(data_train))
1952
+
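The cell above keeps every training sample that has entities and only a capped random subset of the samples without entities. A minimal sketch of that idea follows; the function name `subsample_empty` and the toy data are assumptions for illustration, and it uses a local `random.Random(seed)` rather than the script's global `random.seed(42)`, so the exact samples drawn may differ from the script's.

```python
import random


def subsample_empty(samples, zero_entities_rate, seed=42):
    """Keep all samples with entities plus at most
    int(n_with_entities * zero_entities_rate) samples without entities."""
    empty_idx = [i for i, d in enumerate(samples) if len(d["entities"]) == 0]
    n_with_entities = len(samples) - len(empty_idx)

    # Cap k at the number of empty samples actually available.
    k = min(int(n_with_entities * zero_entities_rate), len(empty_idx))

    # A set gives O(1) membership tests while filtering.
    keep = set(random.Random(seed).sample(empty_idx, k))
    return [d for i, d in enumerate(samples) if d["entities"] or i in keep]


# Toy data: 4 samples with one entity each, 10 samples with none.
toy = [{"entities": [{"label": "X"}]} for _ in range(4)]
toy += [{"entities": []} for _ in range(10)]

print(len(subsample_empty(toy, zero_entities_rate=0.5)))  # 4 with entities + 2 empty
print(len(subsample_empty(toy, zero_entities_rate=5)))    # 4 with entities + all 10 empty
```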
1953
+ # %% [code]
1954
+ if debug_only:
1955
+ data_train = data_train[:10]
1956
+ data_test = data_test[:10]
1957
+
1958
+ print('Train:', len(data_train))
1959
+ print('Test:', len(data_test))
1960
+
1961
+ # %% [code]
1962
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
1963
+
1964
+ # %% [code]
1965
+ print('Experiment name:', state_dict_save_name)
1966
+
1967
+ # %% [code]
1968
+ if not test_only:
1969
+ full_idxes = np.array(range(len(data_train)))
1970
+ training_logs, best_models, last_models = [], [], []
1971
+ start_training_time = time.time()
1972
+ for seed in SEEDS:
1973
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
1974
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
1975
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
1976
+ continue
1977
+ set_seed(seed)
1978
+
1979
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
1980
+
1981
+ trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
1982
+ valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)
1983
+
1984
+ generator = torch.Generator()
1985
+ generator.manual_seed(seed)
1986
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
1987
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1988
+
1989
+ my_model = IEModel(
1990
+ num_labels=len(label2id),
1991
+ **model_params
1992
+ )
1993
+ total_params = sum(p.numel() for p in my_model.parameters())
1994
+ print(f"Total params: {total_params:,}")
1995
+
1996
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
1997
+ encoder_params = set(map(id, my_model.pointer.encoder.parameters())) | set(map(id, my_model.span_score.encoder.parameters()))
1998
+ other_params = [
1999
+ p for p in my_model.parameters()
2000
+ if id(p) not in encoder_params
2001
+ ]
2002
+
2003
+ optimizer = optim.AdamW([
2004
+ {"params": my_model.pointer.encoder.parameters(), "lr": 2e-5},
2005
+ {"params": my_model.span_score.encoder.parameters(), "lr": 2e-5},
2006
+ {"params": other_params}
2007
+ ], lr=5e-4)
2008
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
2009
+
2010
+ loss_fn = CustomLoss(
2011
+ **loss_func_params
2012
+ )
2013
+ eval_fn = CustomEvalFn(**eval_func_params)
2014
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
2015
+ trainer = Trainer(**trainer_params)
2016
+
2017
+ print(f'Start Training Fold {fold_idx}...')
2018
+ training_log, best_model, last_model = trainer.fit(
2019
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
2020
+ start_epoch=1, start_training_time=start_training_time, id2label=id2label
2021
+ )
2022
+
2023
+ training_logs.append(training_log)
2024
+ best_models.append(best_model)
2025
+ last_models.append(last_model)
2026
+
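The training cell pairs `CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)` with three parameter groups (two encoders at `2e-5`, everything else at `5e-4`). As a sanity check on the `lr` triples recorded in the fold's `_logging.json`, the sketch below evaluates the closed-form cosine annealing schedule. It assumes the scheduler is stepped once per epoch, so the value logged at epoch `n` corresponds to step `n - 1`; the helper name `cosine_annealing_lr` is hypothetical.

```python
import math


def cosine_annealing_lr(base_lr, step, t_max=20, eta_min=1e-6):
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Base learning rates of the three parameter groups in the optimizer above.
base_lrs = (2e-5, 2e-5, 5e-4)

for epoch in (1, 2, 11, 18):
    step = epoch - 1  # assumption: one scheduler.step() per epoch
    print(epoch, [cosine_annealing_lr(lr, step) for lr in base_lrs])
```

With these assumptions the printed values should line up with the logged learning rates, for example roughly `1.9883e-05` for the encoders at epoch 2 and `2.505e-04` for the remaining parameters at epoch 11.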
2027
+ # %% [code]
2028
+ def load_all_state_dicts(folder):
2029
+ files = []
2030
+
2031
+ for file in os.listdir(folder):
2032
+ if file.endswith(".pt") or file.endswith(".pth"):
2033
+ m = re.search(r"f(\d+)", file) # find f<number>
2034
+ if m:
2035
+ fold = int(m.group(1))
2036
+ files.append((fold, file))
2037
+
2038
+ # sort by fold
2039
+ files.sort(key=lambda x: x[0])
2040
+
2041
+ state_dicts = []
2042
+ for fold, file in files:
2043
+ path = os.path.join(folder, file)
2044
+ print(f"Loading fold {fold}: {file}")
2045
+ state_dict = torch.load(path, map_location="cpu")
2046
+ state_dicts.append(state_dict)
2047
+
2048
+ return state_dicts
2049
+
2050
+ if test_only:
2051
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
2052
+ get_ipython().system('rm -rf .cache .gitattributes')
2053
+
2054
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
2055
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
2056
+
2057
+ # %% [code]
2058
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
2059
+ testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
2060
+ generator = torch.Generator()
2061
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2062
+ eval_fn = CustomEvalFn(**eval_func_params)
2063
+ analyzer = SpanErrorAnalyzer()
2064
+ my_model = IEModel(
2065
+ num_labels=len(label2id),
2066
+ **model_params
2067
+ )
2068
+ total_params = sum(p.numel() for p in my_model.parameters())
2069
+ print(f"Total params: {total_params:,}")
2070
+
2071
+ # %% [code]
2072
+ start_time = time.time()
2073
+ result_test = None
2074
+ analyze_result = None
2075
+
2076
+ best_score, best_analyze_result, best_pred_test, expand_dict = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2077
+ last_score, last_analyze_result, last_pred_test, _ = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2078
+
2079
+ result_test = {"Best model": best_score, "Last model": last_score}
2080
+ analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
2081
+ analyze_result_sumary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']}
2082
+ pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}
2083
+
2084
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
2085
+ json.dump(result_test, f, ensure_ascii=False, indent=2)
2086
+
2087
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", "w", encoding="utf-8") as f:
2088
+ json.dump(analyze_result, f, ensure_ascii=False, indent=2)
2089
+
2090
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", "w", encoding="utf-8") as f:
2091
+ json.dump(pred_test, f, ensure_ascii=False, indent=2)
2092
+
2093
+ print('Test:', time.time() - start_time, 's --> Done!')
2094
+ print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))
2095
+
2096
+ # %% [code]
2097
+ expand_dict
2098
+
2099
+ # %% [code]
2100
+ expand_dict_sum = sum(list(expand_dict.values()))
2101
+ {key: value / expand_dict_sum for key, value in expand_dict.items()}
2102
+
2103
+ # %% [code]
2104
+ best_pred_test[:10]
2105
+
2106
+ # %% [code]
2107
+ last_pred_test[:10]
2108
+
2109
+ # %% [code]
2110
+ def dict_to_df(data):
2111
+ row_tuples = []
2112
+ row_values = []
2113
+
2114
+ metrics = ["precision", "recall", "f1"]
2115
+
2116
+ # Take the first model
2117
+ first_model = next(iter(data.values()))
2118
+
2119
+ # eval_keys
2120
+ eval_keys = list(first_model.keys())
2121
+
2122
+ for eval_key in eval_keys:
2123
+ row_tuples.append(eval_key)
2124
+ row = {}
2125
+
2126
+ for model_name, model_data in data.items():
2127
+ for metric in metrics:
2128
+ row[(model_name, metric)] = model_data[eval_key][metric]
2129
+
2130
+ row_values.append(row)
2131
+
2132
+ # ===== DataFrame =====
2133
+ df = pd.DataFrame(row_values)
2134
+
2135
+ # MultiIndex columns
2136
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
2137
+
2138
+ # Index
2139
+ df.index = pd.Index(row_tuples, name="evaluation")
2140
+
2141
+ # ===== Sort =====
2142
+ sort_keys = []
2143
+ if ("Best model", "f1") in df.columns:
2144
+ sort_keys.append(("Best model", "f1"))
2145
+ if ("Last model", "f1") in df.columns:
2146
+ sort_keys.append(("Last model", "f1"))
2147
+
2148
+ if sort_keys:
2149
+ df = df.sort_values(by=sort_keys, ascending=False)
2150
+
2151
+ return df
2152
+
2153
+ result_test_df = dict_to_df(result_test)
2154
+ result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
2155
+ result_test_df
2156
+
2157
+ # %% [code]
2158
+ key = ("Best model", "f1")
2159
+ result_test_df_best = result_test_df.sort_values(by=key, ascending=False).groupby(level="evaluation").head(1)
2160
+ result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
2161
+ result_test_df_best
2162
+
2163
+ # %% [code]
2164
+ def get_avg_best_score(logs):
2165
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
2166
+
2167
+ def get_avg_log(logs, epochs):
2168
+ avg_log = {}
2169
+
2170
+ for epoch in range(1, epochs + 1):
2171
+ val_score = 0.0
2172
+ train_loss = 0.0
2173
+ n_eval = 0
2174
+
2175
+ for idx in range(len(logs)):
2176
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
2177
+ if log is None:
2178
+ continue
2179
+
2180
+ val_score += log.get('val_score', 0.0)
2181
+ train_loss += log.get('train_loss', 0.0)
2182
+ n_eval += 1
2183
+
2184
+ if n_eval == 0:
2185
+ continue
2186
+
2187
+ avg_log[epoch] = {
2188
+ 'train_loss': train_loss / n_eval,
2189
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
2190
+ }
2191
+
2192
+ return avg_log
2193
+
2194
+ def parse_label_key(label: str):
2195
+ try:
2196
+ first = float(label.split('_', 1)[0]) # leading number: the part before the first underscore
2197
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
2198
+ return first, last
2199
+ except (ValueError, IndexError):
2200
+ return (0, 0)
2201
+
2202
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
2203
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
2204
+
2205
+ # ===== Plot Train Loss =====
2206
+ for name, log in logs_dict.items():
2207
+ epochs = sorted(log.keys())
2208
+ train_loss = [log[e]['train_loss'] for e in epochs]
2209
+ axes[0].plot(epochs, train_loss, label=name)
2210
+
2211
+ axes[0].set_xlabel('Epoch')
2212
+ axes[0].set_ylabel('Train Loss')
2213
+ axes[0].set_title('Training Loss')
2214
+ axes[0].grid(True)
2215
+
2216
+ # ===== Plot Validation Score =====
2217
+ for name, log in logs_dict.items():
2218
+ epochs = sorted(log.keys())
2219
+ val_score = [log[e]['val_score'] for e in epochs]
2220
+ axes[1].plot(epochs, val_score, label=name)
2221
+
2222
+ axes[1].set_xlabel('Epoch')
2223
+ axes[1].set_ylabel('Validation Score')
2224
+ axes[1].set_title('Validation Score')
2225
+ axes[1].grid(True)
2226
+
2227
+ # ===== Shared Legend =====
2228
+ handles, labels = axes[0].get_legend_handles_labels()
2229
+ pairs = list(zip(handles, labels))
2230
+ pairs_sorted = sorted(
2231
+ pairs,
2232
+ key=lambda x: parse_label_key(x[1])
2233
+ )
2234
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
2235
+
2236
+ axes[0].legend(
2237
+ handles_sorted,
2238
+ labels_sorted,
2239
+ loc='center left',
2240
+ bbox_to_anchor=(1.01, 0.5),
2241
+ borderaxespad=0.
2242
+ )
2243
+
2244
+ plt.tight_layout(rect=[0, 0, 1, 1])
2245
+
2246
+ if save_path is not None:
2247
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
2248
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
2249
+
2250
+ plt.show()
2251
+
2252
+ # %% [code]
2253
+ # if not test_only:
2254
+ # snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*entities*.json"])
2255
+ # !rm -rf .cache .gitattributes
2256
+
2257
+ # %% [code]
2258
+ if not test_only:
2259
+ experiments = {}
2260
+ for experiment in os.listdir(pretrained_dir):
2261
+ if '.virtual_documents' in experiment:
2262
+ continue
2263
+ experiment_logs = []
2264
+ try:
2265
+ for seed in SEEDS:
2266
+ for fold_idx in range(nfolds):
2267
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
2268
+ experiment_log = json.load(f)
2269
+ experiment_logs.append(experiment_log)
2270
+ except (OSError, json.JSONDecodeError):
2271
+ pass
2272
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
2273
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
2274
+
2275
+ # %% [code]
2276
+ if not test_only:
2277
+ score = get_avg_best_score(training_logs)
2278
+ print(state_dict_save_name, score)
2279
+
2280
+ # %% [code]
2281
+ if not test_only:
2282
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
2283
+
4.2_2_clone_4.6/lasts/4.2_2_clone_4.6_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:318d57cfaff65518133fcc9b4c6aa4aa5bde0711804b06eee33c1175c3368361
+ size 1094347339
4.2_2_clone_4.6/logs/4.2_2_clone_4.6_log_plot.jpg ADDED

Git LFS Details

  • SHA256: e852634505c3e29dcdacd005c72093f1aa4a69efb608c3e99ac96ca8f0070e14
  • Pointer size: 131 Bytes
  • Size of remote file: 439 kB
4.2_2_clone_4.6/logs/4.2_2_clone_4.6_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
+ {"1": {"lr": [2e-05, 2e-05, 0.0005], "train_loss": 0.15576171875, "total": 0.1558133035215204, "token_margin_loss": 0.05345164896590274, "span_loss": 0.10522638345444382, "start_margin": 0.02552054220234768, "end_margin": 0.02780883174958077}, "2": {"lr": [1.988303923565381e-05, 1.988303923565381e-05, 0.0004969282409784868], "train_loss": 0.11346435546875, "total": 0.11347121296813863, "token_margin_loss": 0.03549468977082169, "span_loss": 0.07448295136948016, "start_margin": 0.017345584125209614, "end_margin": 0.01879541643376188}, "3": {"lr": [1.9535036904803962e-05, 1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 0.10601806640625, "total": 0.10599496925656791, "token_margin_loss": 0.034219536053661265, "span_loss": 0.06850894354387926, "start_margin": 0.01665560368921185, "end_margin": 0.017703675237562885}, "4": {"lr": [1.8964561979789496e-05, 1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 0.09228515625, "total": 0.09230016769144773, "token_margin_loss": 0.03318893236444941, "span_loss": 0.057364449413079936, "start_margin": 0.016271310788149803, "end_margin": 0.017127235885969816}, "5": {"lr": [1.8185661446562005e-05, 1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 0.08563232421875, "total": 0.08566238121855785, "token_margin_loss": 0.03236794298490777, "span_loss": 0.05187954164337619, "start_margin": 0.01590448574622694, "end_margin": 0.01655953046394634}, "6": {"lr": [1.7217514421272206e-05, 1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 0.0771484375, "total": 0.0771380659586361, "token_margin_loss": 0.03154695360536613, "span_loss": 0.04485746226942426, "start_margin": 0.015319312465064282, "end_margin": 0.016175237562884293}, "7": {"lr": [1.60839598967785e-05, 1.60839598967785e-05, 0.00039715242044697206], "train_loss": 0.071044921875, "total": 0.07105925097820011, "token_margin_loss": 0.03079583566238122, "span_loss": 0.040211011738401345, "start_margin": 0.015031092789267748, "end_margin": 
0.015764742873113472}, "8": {"lr": [1.4812909747525698e-05, 1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 0.066162109375, "total": 0.06613331470095025, "token_margin_loss": 0.02981763555058692, "span_loss": 0.03612353269983231, "start_margin": 0.014603130240357741, "end_margin": 0.015231973169368362, "val_score": 0.6385244907883434, "best_score": 0.6385244907883434, "new_best_model": true, "precision": 0.6936645134825964, "recall": 0.5935691103168296, "f1": 0.6385244907883434}, "9": {"lr": [1.3435661446562005e-05, 1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 0.0606689453125, "total": 0.06068334264952487, "token_margin_loss": 0.028944242593627725, "span_loss": 0.031372275013974285, "start_margin": 0.014096562325321409, "end_margin": 0.014760340972610397, "val_score": 0.6403727920350234, "best_score": 0.6403727920350234, "new_best_model": true, "precision": 0.6963604033207965, "recall": 0.594759423623009, "f1": 0.6403727920350234}, "10": {"lr": [1.1986127417882198e-05, 1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 0.05548095703125, "total": 0.055477920626048075, "token_margin_loss": 0.027913638904415873, "span_loss": 0.027179988820570152, "start_margin": 0.013511389044158748, "end_margin": 0.014297442705422023, "val_score": 0.6392255787070485, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6965781037179611, "recall": 0.5927149671752961, "f1": 0.6392255787070485}, "11": {"lr": [1.0500000000000003e-05, 1.0500000000000003e-05, 0.0002505], "train_loss": 0.050628662109375, "total": 0.05062185578535495, "token_margin_loss": 0.027459474566797093, "span_loss": 0.02386109558412521, "start_margin": 0.013354178311906093, "end_margin": 0.014140231973169368, "val_score": 0.6381648639893853, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6917411055010741, "recall": 0.5943519122939683, "f1": 0.6381648639893853}, "12": {"lr": [9.013872582117811e-06, 9.013872582117811e-06, 
0.00021146960097246246], "train_loss": 0.047454833984375, "total": 0.04744270542202347, "token_margin_loss": 0.02681316377864729, "span_loss": 0.02125838457238681, "start_margin": 0.012943683622135271, "end_margin": 0.013721003353828955, "val_score": 0.6382403697944274, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6900393046014547, "recall": 0.5956313292215475, "f1": 0.6382403697944274}, "13": {"lr": [7.564338553438001e-06, 7.564338553438001e-06, 0.00017340025990345064], "train_loss": 0.042510986328125, "total": 0.04251676914477362, "token_margin_loss": 0.02597470653996646, "span_loss": 0.017563932364449412, "start_margin": 0.012541922861934042, "end_margin": 0.013310508664058134, "val_score": 0.6379628637474125, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6886392683400498, "recall": 0.5961530485844763, "f1": 0.6379628637474125}, "14": {"lr": [6.1870902524743065e-06, 6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 0.041259765625, "total": 0.041259083286752374, "token_margin_loss": 0.025241056456120736, "span_loss": 0.017066098378982673, "start_margin": 0.012070290665176077, "end_margin": 0.012978619340413639, "val_score": 0.6384384557419203, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6945229568751328, "recall": 0.5927117533760684, "f1": 0.6384384557419203}, "15": {"lr": [4.916040103221507e-06, 4.916040103221507e-06, 0.00010384757955302797], "train_loss": 0.03759765625, "total": 0.03759083286752376, "token_margin_loss": 0.02457727780883175, "span_loss": 0.013223169368362214, "start_margin": 0.011782070989379542, "end_margin": 0.012472051425377306, "val_score": 0.6388037098636425, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6961475479130724, "recall": 0.5921893856038067, "f1": 0.6388037098636425}, "16": {"lr": [3.7824855787278e-06, 3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 0.0362548828125, "total": 0.03626327557294578, 
"token_margin_loss": 0.02394843487982113, "span_loss": 0.012341042481833427, "start_margin": 0.011397778088317496, "end_margin": 0.012279904974846283, "val_score": 0.6391129270956746, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6944115187441752, "recall": 0.5940124062359109, "f1": 0.6391129270956746}, "17": {"lr": [2.814338553438001e-06, 2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 0.034942626953125, "total": 0.03495318613750699, "token_margin_loss": 0.02340693124650643, "span_loss": 0.011362842370039128, "start_margin": 0.011196897708216882, "end_margin": 0.011982951369480157, "val_score": 0.6383061159007754, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6948955158852926, "recall": 0.5923005609982788, "f1": 0.6383061159007754}, "18": {"lr": [2.0354380202105066e-06, 2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 0.033355712890625, "total": 0.03334614309670207, "token_margin_loss": 0.0230051704863052, "span_loss": 0.009668460033538289, "start_margin": 0.011030953046394634, "end_margin": 0.011755869200670765, "val_score": 0.6385992596141453, "best_score": 0.6403727920350234, "new_best_model": false, "precision": 0.6945583416927207, "recall": 0.5930287509644723, "f1": 0.6385992596141453}}
4.2_2_clone_4.6/r1s/4.2_2_clone_4.6_s26092004_f0_r1_vs0.64037_ema.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cafd73d6958aaadd39106dd750581a9682bfece8ea6acead3cadb8edec2a45d6
+ size 1094350795
4.2_2_clone_4.6/results/4.2_2_clone_4.6_error_analyze_result.json ADDED
The diff for this file is too large to render. See raw diff
 
4.2_2_clone_4.6/results/4.2_2_clone_4.6_pred_test.json ADDED
The diff for this file is too large to render. See raw diff
 
4.2_2_clone_4.6/results/4.2_2_clone_4.6_test.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "Best model": {
+     "Ent-I": {
+       "precision": 0.7757224168117604,
+       "recall": 0.6587043405514836,
+       "f1": 0.7124403066686483
+     },
+     "Ent-C": {
+       "precision": 0.6243589743583585,
+       "recall": 0.5878911690959346,
+       "f1": 0.6055765409889282
+     }
+   },
+   "Last model": {
+     "Ent-I": {
+       "precision": 0.770265835246942,
+       "recall": 0.6544288502642862,
+       "f1": 0.7076381859871164
+     },
+     "Ent-C": {
+       "precision": 0.6215788953654563,
+       "recall": 0.5841768037881101,
+       "f1": 0.6022977451239089
+     }
+   }
+ }
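The test scores above appear consistent with F1 being the harmonic mean of the reported precision and recall. The short check below recomputes F1 from the four precision/recall pairs in this file. The helper `f1_from_pr` is an illustrative assumption; the actual `CustomEvalFn` is not shown in this part of the diff and may differ in detail (for example, by adding a small smoothing constant).

```python
def f1_from_pr(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported f1) copied from the test.json above.
reported = {
    ("Best model", "Ent-I"): (0.7757224168117604, 0.6587043405514836, 0.7124403066686483),
    ("Best model", "Ent-C"): (0.6243589743583585, 0.5878911690959346, 0.6055765409889282),
    ("Last model", "Ent-I"): (0.770265835246942, 0.6544288502642862, 0.7076381859871164),
    ("Last model", "Ent-C"): (0.6215788953654563, 0.5841768037881101, 0.6022977451239089),
}

for (model, eval_type), (p, r, f1) in reported.items():
    print(model, eval_type, "recomputed:", f1_from_pr(p, r), "reported:", f1)
```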
4.2_2_clone_4.6/results/4.2_2_clone_4.6_test_df.xlsx ADDED
Binary file (5.28 kB).
4.2_2_clone_4.6/results/4.2_2_clone_4.6_test_df_best.xlsx ADDED
Binary file (5.28 kB).