SS3M committed on
Commit
58d0cb8
·
verified ·
1 Parent(s): a023cd9

Upload 0_token_base_issue_1's state dict

Browse files
0_token_base_issue_1/0_token_base_issue_1.py ADDED
@@ -0,0 +1,2086 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
# Installs the extra packages this notebook needs by shelling out to pip via get_ipython().
# NOTE(review): get_ipython is not imported anywhere in the visible file — presumably this only runs inside an IPython/Jupyter kernel; confirm.
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch]')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+
19
+ from sklearn.metrics import f1_score
20
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
21
+ from scipy.spatial.transform import Rotation as R
22
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
23
+ from sklearn.metrics import precision_recall_fscore_support
24
+ from timm.utils import ModelEmaV3
25
+ import timm
26
+
27
+ import os
28
+ import gc
29
+ import json
30
+ from pathlib import Path
31
+ import pickle
32
+ from tqdm.auto import tqdm
33
+ import copy
34
+ import numpy as np
35
+ import pandas as pd
36
+ import polars as pl
37
+ from PIL import Image
38
+ import time
39
+ from tqdm import tqdm
40
+ from matplotlib import pyplot as plt
41
+ import seaborn as sns
42
+ from multiprocessing import Manager as MemoryManager
43
+ from functools import lru_cache
44
+ import shutil
45
+ import glob
46
+ import cv2
47
+ import random
48
+ import re
49
+ import joblib
50
+ import math
51
+ from huggingface_hub import HfApi, snapshot_download
52
+ import evaluate
53
+ from underthesea import word_tokenize as vi_tokenize_tool
54
+ import spacy
55
+ en_tokenize_tool = spacy.load("en_core_web_sm")
56
+ from collections import defaultdict, Counter
57
+
58
+ # %% [code]
59
# ---------------------------------------------------------------------------
# Global run configuration
# ---------------------------------------------------------------------------
SEEDS = [26092004]
topk = 1
nfolds = 5
only_fold_idx = 0
test_only = 0
debug_only = 0

# Dataset / directory configuration.
dataset = 'kltn/only_issues'  # vhe, bkee, casie, kltn/only_issues, kltn/only_actions
# Directory that holds the train / val / test files.
root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}'
train_dir = f'{root_dir}'
# val_dir = f'{root_dir}/val'
test_dir = f'{root_dir}'

# Training configuration.
epochs = 2 if debug_only else 18
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint / upload configuration.
repo_name = 'SS3M/kltn-experiments'
state_dict_save_name = "0_token_base_issue_1"
checkpoints_dir = state_dict_save_name
pretrained_dir = "/kaggle/working"
os.makedirs(f'{checkpoints_dir}', exist_ok=True)

# Backbone and word tokenizer are selected by dataset.
backbone_model_name = "bert-base-uncased" if dataset == "casie" else "vinai/phobert-base"
# The lambda reads the module-level `dataset` at call time, not at definition time.
word_tokenize = lambda text: (
    [token.text for token in en_tokenize_tool(text)]
    if dataset == "casie"
    else vi_tokenize_tool(text)
)

# Per-dataset maximum sequence length.
max_len_dict = {
    'kltn/only_issues': 52,
    'kltn/only_actions': 69,
    'vhe': 51,
    'bkee': 62,
    'casie': 40,
}
# Per-dataset rate for samples without events; 1000 means "keep all zero-event samples".
zero_events_rate_dict = {
    'kltn/only_issues': 0,
    'kltn/only_actions': 0.2,
    'vhe': 1000,
    'bkee': 1000,
    'casie': 1000,
}

max_len = max_len_dict[dataset]
max_n_parts = 1
max_span_len = 14
zero_events_rate = zero_events_rate_dict[dataset]

# Trainer keyword arguments.
trainer_params = dict(
    training_time="00:11:30:00",
    eval_mode="max",
    topk=topk,
    save_name=state_dict_save_name,
    save_best=True,
    save_last=True,
    device=device,
    logging=True,
    logging_file=True,
    checkpoints_dir=checkpoints_dir,
    early_stopping=30,
    eval_from_ratio=0.4,
    eval_every=1,
    schedule_in_step=False,
    use_ema=True,
    ema_from_ratio=0.3,
    ema_decay=0.9995,
    max_grad_norm=200.0,
    return_best=True,
    return_last=True,
)

# Memory (dataset) keyword arguments.
train_memory_params = dict(max_len=max_len, max_n_parts=max_n_parts)
val_memory_params = dict(max_len=max_len, max_n_parts=max_n_parts)


# Data loaders.
def seed_worker(worker_id):
    """Seed NumPy and `random` in a DataLoader worker from torch's initial seed."""
    derived_seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)


train_loader_params = dict(
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
    drop_last=False,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)
val_loader_params = dict(
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=1,
    drop_last=False,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)

# Model.
model_params = dict(backbone_model_name=backbone_model_name)

# Loss / evaluation functions.
loss_func_params = dict(lambda_trg_ce=1.0, lambda_arg_ce=1.0)
eval_func_params = {}

# Optimizer and LR scheduler.
optim_params = dict(name='AdamW', lr=1e-4, weight_decay=1e-4)
scheduler_params = dict(
    name='CosineAnnealingLR',
    T_max=20,      # number of epochs to complete one LR-decay cycle
    eta_min=1e-6,  # smallest learning rate within the cycle
)
191
+
192
+ # %% [code]
193
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: integer seed applied to every generator.

    Side effects:
        Sets ``PYTHONHASHSEED`` in ``os.environ``, sets
        ``torch.backends.cudnn.deterministic = True`` and
        ``torch.backends.cudnn.benchmark = False``, and calls
        ``torch.use_deterministic_algorithms(False)``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.use_deterministic_algorithms(False)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
203
+
204
+ # %% [code]
205
class CustomLoss(nn.Module):
    """Weighted sum of a trigger cross-entropy and an argument cross-entropy.

    Both terms use ``nn.CrossEntropyLoss(ignore_index=-100)``. When every label
    of a term equals ``-100`` that term contributes exactly ``0`` instead of the
    NaN that a mean over zero elements would produce.

    Args:
        lambda_trg_ce: weight of the trigger loss in the total.
        lambda_arg_ce: weight of the argument loss in the total.
    """

    def __init__(
        self,
        lambda_trg_ce=1.0,
        lambda_arg_ce=1.0,
    ):
        super().__init__()
        self.lambda_trg_ce = lambda_trg_ce
        self.lambda_arg_ce = lambda_arg_ce
        self.ce = nn.CrossEntropyLoss(ignore_index=-100)

    def _masked_ce(self, logits, labels):
        """Cross-entropy over flattened logits, or a zero scalar if all labels are ignored."""
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_labels = labels.reshape(-1)

        if not (flat_labels != -100).any():
            # Sum over an empty slice: exactly 0, never NaN/inf, and still
            # attached to the autograd graph of `logits`.
            return flat_logits[:0].sum()

        return self.ce(flat_logits, flat_labels)

    def forward(
        self,
        trg_logits, trg_labels,
        trg_arg_logits, pred_trg_arg_labels
    ):
        """Compute the combined loss.

        Args:
            trg_logits: 3-D tensor (B, N, C) of trigger logits.
            trg_labels: trigger labels with B * N elements; -100 entries are ignored.
            trg_arg_logits: 4-D tensor (B, K, M, C) of argument logits.
            pred_trg_arg_labels: argument labels with B * K * M elements;
                -100 entries are ignored.

        Returns:
            dict with scalar tensors under the keys "total", "trg_loss", "arg_loss".
        """
        # Unpacking keeps the original rank checks (3-D and 4-D inputs).
        B, N, C = trg_logits.shape
        trg_loss = self._masked_ce(trg_logits, trg_labels)

        B, K, M, C = trg_arg_logits.shape
        arg_loss = self._masked_ce(trg_arg_logits, pred_trg_arg_labels)

        total_loss = (
            self.lambda_trg_ce * trg_loss +
            self.lambda_arg_ce * arg_loss
        )

        return {
            "total": total_loss,
            "trg_loss": trg_loss,
            "arg_loss": arg_loss,
        }
253
+
254
+ # %% [code]
255
+ ## Write the eval_fn here
256
+
257
+ # Put every eval_fn and its weights here
258
class CustomEvalFn(nn.Module):
    """Set-based precision / recall / F1 between predicted and gold items."""

    def __init__(self):
        super().__init__()

    def compute_f1(self, tp, fp, fn):
        """Return (precision, recall, f1) from raw counts.

        A 1e-8 term is added to every denominator, so zero counts give 0.0
        rather than raising.
        """
        eps = 1e-8
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        return precision, recall, f1

    def forward(self, pred, gold):
        """Score `pred` against `gold` after de-duplicating both with set().

        Args:
            pred: iterable of hashable predicted items.
            gold: iterable of hashable gold items.

        Returns:
            dict with float values under "precision", "recall" and "f1".
        """
        predicted, expected = set(pred), set(gold)

        tp = len(predicted & expected)
        fp = len(predicted) - tp   # predicted items that are not gold
        fn = len(expected) - tp    # gold items that were not predicted

        precision, recall, f1 = self.compute_f1(tp, fp, fn)

        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }
283
+
284
class SpanErrorAnalyzer:
    """Break span-level prediction errors down into categories.

    Spans are given as ``(batch_key, ids)`` pairs, where ``ids`` is a sequence
    of token ids. Ids equal to ``pad_token_id`` are dropped before comparison.

    Args:
        pad_token_id: id treated as padding and removed from every span.
    """

    def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    # ===== helpers =====
    def _to_set(self, data):
        """Group spans by batch key.

        data: list of (b, ids)
        -> dict[b] = set of tuple(ids) with padding removed; spans that are
        empty after padding removal are skipped.
        """
        grouped = defaultdict(set)
        for b, ids in data:
            ids = tuple(i for i in ids if i != self.pad_token_id)
            if ids:
                grouped[b].add(ids)
        return grouped

    def _iou(self, a, b):
        """Intersection-over-union of the id sets of two spans (0.0 if both are empty)."""
        set_a, set_b = set(a), set(b)
        union = len(set_a | set_b)
        if union == 0:
            return 0.0
        return len(set_a & set_b) / union

    def _boundary_error(self, pred, gold):
        """Measure boundary drift via the lengths of the common prefix and suffix."""
        shortest = min(len(pred), len(gold))

        # Length of the matching prefix.
        left = 0
        for i in range(shortest):
            if pred[i] != gold[i]:
                break
            left += 1

        # Length of the matching suffix.
        right = 0
        for i in range(1, shortest + 1):
            if pred[-i] != gold[-i]:
                break
            right += 1

        return {
            "left_match": left,
            "right_match": right,
            "pred_len": len(pred),
            "gold_len": len(gold),
        }

    @staticmethod
    def _ratio(count, total):
        """count / total, or 0.0 when total is 0 (avoids ZeroDivisionError)."""
        return count / total if total else 0.0

    # ===== main =====
    def analyze(self, preds, golds):
        """Compare predicted spans with gold spans.

        Args:
            preds: list of (batch_key, ids) predicted spans.
            golds: list of (batch_key, ids) gold spans.

        Returns:
            dict with
              "summary": category -> (count, ratio). Prediction-side categories
                  are divided by len(preds), miss categories by len(golds);
                  a ratio is 0.0 when its denominator is 0.
              "details": list of dicts, one per non-exact prediction and one
                  per unmatched gold span.
        """
        pred_map = self._to_set(preds)
        gold_map = self._to_set(golds)

        all_batches = set(pred_map.keys()) | set(gold_map.keys())

        stats = Counter()
        detailed_errors = []

        for b in all_batches:
            pset = pred_map.get(b, set())
            gset = gold_map.get(b, set())

            matched_gold = set()

            # ===== check predictions =====
            for p in pset:
                if p in gset:
                    stats["exact_match"] += 1
                    matched_gold.add(p)
                    continue

                # Find the gold span with the highest overlap.
                best_iou = 0
                best_g = None
                for g in gset:
                    iou = self._iou(p, g)
                    if iou > best_iou:
                        best_iou = iou
                        best_g = g

                if best_iou > 0:
                    stats["partial_match"] += 1
                    detailed_errors.append({
                        "type": "boundary_error",
                        "batch": b,
                        "pred": p,
                        "gold": best_g,
                        "iou": best_iou,
                        **self._boundary_error(p, best_g)
                    })
                else:
                    # No overlap at all: either the sample has no gold span,
                    # or the prediction is unrelated to every gold span.
                    err_type = "no_event_sample" if b not in gold_map else "completely_wrong"
                    stats[err_type] += 1
                    detailed_errors.append({
                        "type": err_type,
                        "batch": b,
                        "pred": p
                    })

            # ===== check missing gold spans =====
            for g in gset:
                if g in matched_gold:
                    continue

                if any(self._iou(p, g) > 0 for p in pset):
                    stats["miss_with_overlap"] += 1
                else:
                    stats["miss"] += 1

                detailed_errors.append({
                    "type": "miss",
                    "batch": b,
                    "gold": g
                })

        n_pred, n_gold = len(preds), len(golds)
        summary = {}
        for key in ("exact_match", "partial_match", "no_event_sample", "completely_wrong"):
            summary[key] = (stats[key], self._ratio(stats[key], n_pred))
        for key in ("miss", "miss_with_overlap"):
            summary[key] = (stats[key], self._ratio(stats[key], n_gold))

        return {
            "summary": summary,
            "details": detailed_errors
        }
427
+
428
+ # %% [code]
429
+ ## Write the model architecture here
430
def fix_bio_ids_batch(label_ids):
    """Repair invalid BIO tag sequences in a batch.

    Tag encoding used by the code: 0 is O, odd ids are B- tags, even non-zero
    ids are I- tags, and ``(tag - 1) // 2`` is the type index. An I- tag is
    rewritten to the B- tag of the same type (``tag - 1``) when it starts the
    sequence, follows an O tag, or follows a tag of a different type. The check
    uses the already-repaired previous tag.

    Args:
        label_ids: (B, L) integer tensor of tag ids.

    Returns:
        New (B, L) tensor with the same dtype and device; the input is not modified.
    """
    B, L = label_ids.shape

    # One bulk transfer instead of a .item() call (device sync) per token.
    rows = label_ids.tolist()

    for row in rows:
        prev_tag = 0  # acts as the "O" before the first token
        for i, tag in enumerate(row):
            is_inside = tag != 0 and tag % 2 == 0
            if is_inside and (prev_tag == 0 or (prev_tag - 1) // 2 != (tag - 1) // 2):
                tag -= 1  # I-X -> B-X
                row[i] = tag
            prev_tag = tag

    fixed = torch.tensor(rows, dtype=label_ids.dtype, device=label_ids.device)
    # reshape keeps (B, L) even when B or L is 0.
    return fixed.reshape(B, L)
462
+
463
def extract_trigger_spans_batch_tensor(label_ids):
    """Extract BIO spans from a batch of tag ids into a padded tensor.

    A span starts at an odd (B-) tag and extends over the following even (I-)
    tags of the same type index ``(tag - 1) // 2``. I- tags that do not
    continue a span are skipped.

    Args:
        label_ids: (B, L) integer tensor of tag ids (0 is O).

    Returns:
        (B, N, 2) long tensor of inclusive (start, end) positions on the same
        device, where N is the largest span count in the batch (N may be 0).
        Rows with fewer spans are padded with (0, 0).
        NOTE(review): the (0, 0) padding is indistinguishable from a real
        single-token span at position 0 — confirm position 0 is never a trigger.
    """
    B, L = label_ids.shape
    device = label_ids.device

    # One bulk transfer instead of a .item() call (device sync) per token.
    rows = label_ids.tolist()

    all_spans = []
    for row in rows:
        spans = []
        i = 0
        while i < L:
            tag = row[i]

            # O tags and stray I- tags do not open a span.
            if tag == 0 or tag % 2 != 1:
                i += 1
                continue

            type_id = (tag - 1) // 2
            s = e = i
            i += 1

            # Extend over consecutive I- tags of the same type.
            while i < L:
                next_tag = row[i]
                if next_tag == 0 or next_tag % 2 != 0 or (next_tag - 1) // 2 != type_id:
                    break
                e = i
                i += 1

            spans.append((s, e))

        all_spans.append(spans)

    max_n = max((len(spans) for spans in all_spans), default=0)

    # With max_n == 0 this is already the empty (B, 0, 2) tensor.
    spans_tensor = torch.zeros((B, max_n, 2), dtype=torch.long, device=device)
    for b, spans in enumerate(all_spans):
        if spans:
            spans_tensor[b, :len(spans)] = torch.tensor(spans, dtype=torch.long, device=device)

    return spans_tensor
528
+
529
def get_span_repr(hidden, spans):
    """Build a representation for each span from its boundary token states.

    Args:
        hidden: (B, L, H) tensor of token states.
        spans: (B, K, 2) integer tensor of (start, end) token positions.

    Returns:
        (B, K, 4H) tensor: the concatenation of the start state, the end state,
        their difference (start - end) and their element-wise product.
    """
    _, _, hidden_dim = hidden.size()

    def states_at(positions):
        # positions: (B, K) -> gathered token states (B, K, H)
        index = positions.unsqueeze(-1).expand(-1, -1, hidden_dim)
        return torch.gather(hidden, 1, index)

    first = states_at(spans[:, :, 0])
    last = states_at(spans[:, :, 1])

    return torch.cat([first, last, first - last, first * last], dim=-1)
550
+
551
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_size: input feature size.
        hid_size: hidden layer size.
        out_size: output feature size.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # The layers live under the `model` attribute of an nn.Sequential, so
        # parameter names are model.0.* and model.2.*.
        layers = [
            nn.Linear(in_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, out_size),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to the last dimension of `x`."""
        return self.model(x)
562
+
563
class IEModel(nn.Module):
    """Encoder with a token-level trigger tagger and a trigger-conditioned argument tagger.

    Args:
        backbone_model_name: name or path passed to ``AutoModel.from_pretrained``.
        num_trg_labels: number of trigger tag classes.
        num_arg_labels: number of argument tag classes.
    """

    def __init__(self, backbone_model_name, num_trg_labels, num_arg_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone_model_name)
        width = self.encoder.config.hidden_size

        # Per-token trigger tag classifier.
        self.trg_classifier = MLP(width, width, num_trg_labels)

        # Projects the 4H span representation back to H, then each token state
        # is concatenated with it (2H) for argument classification.
        self.trg_repr_proj = MLP(width * 4, width, width)
        self.arg_classifier = MLP(width * 2, width, num_arg_labels)

    def encode(self, input_ids, attention_mask):
        """Encode every part independently and concatenate along the sequence axis.

        Args:
            input_ids: (B, n_parts, L) token ids.
            attention_mask: tensor viewable as (B * n_parts, L).

        Returns:
            (B, n_parts * L, H) hidden states.
        """
        batch, n_parts, part_len = input_ids.shape

        encoded = self.encoder(
            input_ids=input_ids.view(-1, part_len),
            attention_mask=attention_mask.view(-1, part_len),
        )
        states = encoded.last_hidden_state  # (B * n_parts, L, H)

        states = states.view(batch, n_parts, part_len, -1)
        return states.reshape(batch, n_parts * part_len, -1)

    def get_trg_logits(self, hidden_states):
        """Return per-token trigger logits computed by `trg_classifier`."""
        return self.trg_classifier(hidden_states)

    def get_arg_logits(self, hidden_states, trg_repr):
        """Score every token against every trigger representation.

        Args:
            hidden_states: (B, L, H) token states.
            trg_repr: (B, N, H') trigger representations.

        Returns:
            (B, N, L, C) argument logits.
        """
        _, seq_len, _ = hidden_states.shape
        _, n_trg, _ = trg_repr.shape

        tokens = hidden_states.unsqueeze(1).expand(-1, n_trg, -1, -1)
        triggers = trg_repr.unsqueeze(2).expand(-1, -1, seq_len, -1)

        paired = torch.cat([tokens, triggers], dim=-1)  # (B, N, L, 2H)
        return self.arg_classifier(paired)

    def forward(self, input_ids, attention_mask, trg_spans=None):
        """Run trigger tagging followed by argument tagging.

        Args:
            input_ids: (B, n_parts, L) token ids.
            attention_mask: attention mask matching `input_ids`.
            trg_spans: optional (B, N, 2) trigger spans. When None, spans are
                decoded from the argmax of the trigger logits after BIO repair.

        Returns:
            (trg_logits, arg_logits, trg_spans).
            NOTE(review): when spans are decoded here, N is the maximum span
            count of this call's batch, so it can differ between calls —
            confirm that any code concatenating outputs handles that.
        """
        hidden_states = self.encode(input_ids, attention_mask)
        trg_logits = self.get_trg_logits(hidden_states)

        if trg_spans is None:
            predicted_tags = fix_bio_ids_batch(torch.argmax(trg_logits, dim=-1))
            trg_spans = extract_trigger_spans_batch_tensor(predicted_tags)

        span_repr = get_span_repr(hidden_states, trg_spans)  # (B, N, 4H)
        span_repr = self.trg_repr_proj(span_repr)            # (B, N, H)
        arg_logits = self.get_arg_logits(hidden_states, span_repr)

        return trg_logits, arg_logits, trg_spans
617
+
618
def test():
    """Smoke test: build the model, print its parameter count and its output shapes."""
    model = nn.DataParallel(IEModel(backbone_model_name, 7, 5)).to(device)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {n_params:,}")

    encoder_config = model.module.encoder.config
    vocab_size = encoder_config.vocab_size
    max_len = encoder_config.max_position_embeddings  # local only; not used below

    bz = 32
    ids = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
    mask = torch.ones(bz, 5, 10).to(device)
    gold_spans = torch.ones(bz, 3, 2, dtype=torch.long).to(device)

    with torch.no_grad():
        r = model(ids, mask, gold_spans)

    if type(r) is tuple:
        print([part.shape for part in r])
    else:
        print(r.shape)


# Runs immediately when this cell / module is executed.
test()
641
+
642
+ # %% [code]
643
def configure_optimizers(network, optim_params, scheduler_params):
    """Build an optimizer and an optional LR scheduler from config mappings.

    Classes are looked up by name, first in this module's globals and then in
    ``torch.optim`` / ``torch.optim.lr_scheduler``. The input mappings are
    copied and never mutated.

    Args:
        network: module whose ``parameters()`` are optimized.
        optim_params: mapping with a 'name' key; the remaining entries are
            passed to the optimizer constructor as keyword arguments.
        scheduler_params: mapping with a 'name' key plus constructor keyword
            arguments. May be empty or None, in which case no scheduler is built.

    Returns:
        (optimizer, scheduler). ``scheduler`` is None when `scheduler_params`
        is empty/None or contains nothing besides 'name'.

    Raises:
        ValueError: if 'name' is missing from a non-empty config, or if the
            named optimizer or scheduler cannot be found.
    """
    optim_params = copy.copy(optim_params)
    scheduler_params = copy.copy(scheduler_params) if scheduler_params else {}

    # Only the name lookups are translated to ValueError, so a KeyError raised
    # inside a constructor is not misreported as a missing config key.
    try:
        optim_name = optim_params.pop('name')
        scheduler_name = scheduler_params.pop('name') if scheduler_params else None
    except KeyError as e:
        raise ValueError(f"Missing {e} in config!!")

    optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
    if optimizer_cls is None:
        raise ValueError(f"Optimizer '{optim_name}' is not available!")

    optimizer = optimizer_cls(network.parameters(), **optim_params)

    scheduler = None
    # A scheduler is only created when it has parameters besides its name.
    if scheduler_name is not None and scheduler_params:
        scheduler_cls = (
            globals().get(scheduler_name)
            or getattr(optim.lr_scheduler, scheduler_name, None)
        )
        if scheduler_cls is None:
            # Fail loudly rather than silently training without the requested scheduler.
            raise ValueError(f"Scheduler '{scheduler_name}' is not available!")
        scheduler = scheduler_cls(optimizer, **scheduler_params)

    return optimizer, scheduler
667
+
668
def freeze(self, model=None):
    """Put `model` in eval mode and disable gradients for all of its parameters.

    This is a module-level function whose first parameter is named ``self``.
    Both call styles are accepted:
        freeze(anything, model)   # original two-argument form; `self` is ignored
        freeze(model)             # single-argument form; `self` is the model

    Args:
        self: ignored when `model` is given; otherwise treated as the model.
        model: module to freeze.
    """
    target = self if model is None else model
    target.eval()
    for param in target.parameters():
        param.requires_grad = False
672
+
673
def unfreeze(self, model=None):
    """Put `model` in train mode and enable gradients for all of its parameters.

    This is a module-level function whose first parameter is named ``self``.
    Both call styles are accepted:
        unfreeze(anything, model)   # original two-argument form; `self` is ignored
        unfreeze(model)             # single-argument form; `self` is the model

    Args:
        self: ignored when `model` is given; otherwise treated as the model.
        model: module to unfreeze.
    """
    target = self if model is None else model
    target.train()
    for param in target.parameters():
        param.requires_grad = True
677
+
678
def reduce_batch_size(loader, ratio=0.5):
    """Create a copy of `loader` with a smaller batch size.

    Args:
        loader: source DataLoader; its dataset, collate function and worker
            settings are reused.
        ratio: multiplier applied to ``loader.batch_size``. The result is
            truncated to an int and never drops below 1.

    Returns:
        A new DataLoader. If the source used a RandomSampler the copy is built
        with ``shuffle=True``; otherwise the source sampler object is reused.
    """
    # The file-level import only brings in Dataset, TensorDataset and DataLoader,
    # so RandomSampler has to be imported here to avoid a NameError.
    from torch.utils.data import RandomSampler

    new_bs = max(1, int(loader.batch_size * ratio))

    shuffle = isinstance(loader.sampler, RandomSampler)

    new_loader = DataLoader(
        dataset=loader.dataset,
        batch_size=new_bs,
        shuffle=shuffle,
        # DataLoader rejects shuffle=True combined with an explicit sampler.
        sampler=None if shuffle else loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
        generator=loader.generator,
        prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
        persistent_workers=loader.persistent_workers,
        pin_memory_device=loader.pin_memory_device
    )

    return new_loader
702
+
703
def list_to_tuple(x):
    """Recursively convert lists and tuples into tuples.

    Any value that is neither a list nor a tuple is returned unchanged, so
    containers such as dicts are not descended into.
    """
    if not isinstance(x, (list, tuple)):
        return x
    return tuple(map(list_to_tuple, x))
707
+
708
def fmt(x):
    """Recursively round floats to 5 decimal places for display.

    Floats are rounded; dicts, lists and plain tuples are rebuilt with their
    values formatted recursively (dict keys are left as they are). Every other
    value, including tuple subclasses, is returned unchanged.

    Plain tuples are handled because ``SpanErrorAnalyzer.analyze`` reports its
    summary entries as ``(count, ratio)`` tuples.
    """
    if isinstance(x, float):
        return round(x, 5)
    if isinstance(x, dict):
        return {k: fmt(v) for k, v in x.items()}
    if isinstance(x, list):
        return [fmt(v) for v in x]
    if type(x) is tuple:
        return tuple(fmt(v) for v in x)
    return x
716
+
717
class ModelEmaV3Proxy(ModelEmaV3):
    """ModelEmaV3 that forwards unknown attribute lookups to ``self.module``."""

    def __getattr__(self, name):
        # Try the regular lookup first; fall back to the wrapped module.
        try:
            found = super().__getattr__(name)
        except AttributeError:
            found = getattr(self.module, name)
        return found
723
+
724
class DataParallelProxy(nn.DataParallel):
    """DataParallel that exposes the wrapped module's attributes and methods.

    Attributes not found on the DataParallel wrapper are looked up on
    ``self.module``. Callable attributes are returned as wrappers that run the
    call through scatter / replicate / parallel_apply / gather; non-callable
    attributes are returned directly.
    """

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            target = getattr(self.module, name)

            if not callable(target):
                return target

            def bound_call(*args, **kwargs):
                return self._parallel_apply_method(name, *args, **kwargs)

            return bound_call

    def _parallel_apply_method(self, method_name, *inputs, **kwargs):
        """Call `method_name` of the wrapped module across ``self.device_ids``.

        With no device ids the method is called directly on ``self.module``.
        Otherwise the inputs are scattered, the module is replicated, the
        method is applied on every replica and the outputs are gathered on
        ``self.output_device``.
        """
        if not self.device_ids:
            direct = getattr(self.module, method_name)
            return direct(*inputs, **kwargs)

        scattered_args, scattered_kwargs = self.scatter(inputs, kwargs, self.device_ids)
        replicas = self.replicate(self.module, self.device_ids)
        replica_methods = [getattr(replica, method_name) for replica in replicas]

        outputs = self.parallel_apply(replica_methods, scattered_args, scattered_kwargs)
        return self.gather(outputs, self.output_device)
753
+
754
def map_arg_labels(all_arg_labels, trg_spans, pred_spans):
    """Re-align gold argument labels to the predicted trigger spans.

    Args:
        all_arg_labels: (B, N, L) gold argument labels, one row per gold trigger.
        trg_spans: (B, N, 2) gold trigger spans.
        pred_spans: (B, M, 2) predicted trigger spans.

    Returns:
        (B, M, L) long tensor. A predicted span identical to a gold span gets
        that gold span's label row (the first one if several are identical).
        Any other predicted span gets a row of zeros, with -100 wherever the
        first gold row (index 0) is -100.
    """
    B, N, L = all_arg_labels.shape
    _, M, _ = pred_spans.shape

    # same_span[b, m, n] is True when predicted span m equals gold span n.
    same_span = (pred_spans[:, :, None, :] == trg_spans[:, None, :, :]).all(dim=-1)
    has_match = same_span.any(dim=2)               # (B, M)
    match_idx = same_span.float().argmax(dim=2)    # (B, M), first matching gold index

    # Pick the matched gold label row for every predicted span.
    row_index = match_idx.unsqueeze(-1).expand(-1, -1, L)   # (B, M, L)
    gathered = torch.gather(all_arg_labels, dim=1, index=row_index)

    # Fallback row: 0 everywhere, keeping the -100 mask of gold row 0
    # (the mask is assumed to be the same for all gold rows).
    ignore_mask = (all_arg_labels[:, 0] == -100).unsqueeze(1).expand(-1, M, -1)
    fallback = torch.zeros(
        (B, M, L), dtype=torch.long, device=all_arg_labels.device
    ).masked_fill(ignore_mask, -100)

    pred_arg_labels = torch.where(has_match.unsqueeze(-1), gathered, fallback)
    return pred_arg_labels.long()
803
+
804
def fix_bio_ids(label_ids):
    """Repair an invalid BIO tag sequence given as a list of ids.

    Encoding: 0 is O, odd ids are B- tags, even non-zero ids are I- tags and
    ``(tag - 1) // 2`` is the type index. An I- tag becomes the B- tag of the
    same type (``tag - 1``) when it opens the sequence, follows an O, or
    follows a tag of another type. The previous tag is the already-repaired one.

    Args:
        label_ids: iterable of integer tag ids.

    Returns:
        New list of repaired tag ids.
    """
    fixed = []

    for tag in label_ids:
        if tag == 0:
            fixed.append(0)
            continue

        if tag % 2 == 0:  # I- tag
            previous = fixed[-1] if fixed else 0
            starts_new_span = previous == 0 or (previous - 1) // 2 != (tag - 1) // 2
            if starts_new_span:
                tag = tag - 1  # I-X -> B-X

        fixed.append(tag)

    return fixed
833
+
834
def extract_trigger_spans_batch(label_ids: torch.Tensor):
    """Extract typed BIO spans from a batch of tag ids.

    A span starts at an odd (B-) tag and extends over the following even (I-)
    tags with the same type index ``(tag - 1) // 2``. I- tags that do not
    continue a span are skipped (they are expected to have been repaired
    beforehand).

    Args:
        label_ids: (B, L) integer tensor of tag ids (0 is O).

    Returns:
        List[List[(s, e, type_id)]]: one list per batch row, with inclusive
        start / end positions.
    """
    B, L = label_ids.shape

    # One bulk transfer instead of a .item() call (device sync) per token.
    rows = label_ids.tolist()

    results = []
    for row in rows:
        spans = []
        i = 0

        while i < L:
            tag = row[i]

            # O tags and stray I- tags do not open a span.
            if tag == 0 or tag % 2 != 1:
                i += 1
                continue

            type_id = (tag - 1) // 2
            s = e = i
            i += 1

            # Extend over consecutive I- tags of the same type.
            while i < L:
                next_tag = row[i]
                if next_tag == 0 or next_tag % 2 != 0 or (next_tag - 1) // 2 != type_id:
                    break
                e = i
                i += 1

            spans.append((s, e, type_id))

        results.append(spans)

    return results
884
+
885
def extract_arguments(
    input_ids,
    trg_logits,
    arg_logits,
    pred_trg_spans,
    id2label
):
    """Decode (trigger, argument) pairs from model outputs.

    input_ids: (B, L)
    trg_logits: (B, L, C_trg)
    arg_logits: (B, N, L, C_arg)
    pred_trg_spans: (B, N, 2)

    id2label = {
        'Trg': {id: 'B-XXX'/'I-XXX'/...},
        'Arg': {id: 'B-XXX'/'I-XXX'/...}
    }

    Both label maps are indexed with the B- tag id of a span. The trigger
    lookup uses ``2 * type_id + 1`` to turn the type index returned by
    extract_trigger_spans_batch back into that tag id; indexing with the type
    index directly would return the wrong entry.

    Returns:
        List of (batch_index, (trigger_token_ids, trigger_label),
        (argument_token_ids, argument_label)), with token ids as tuples and
        labels stripped of their B-/I- prefix.
    """

    def strip_bio(label):
        if label == 'O':
            return 'O'
        return label[2:]  # drop the B-/I- prefix

    B, L = input_ids.shape

    # ===== decode trigger tags and spans =====
    trg_ids = fix_bio_ids_batch(torch.argmax(trg_logits, dim=-1))  # (B, L)
    trg_spans = extract_trigger_spans_batch(trg_ids)

    # One bulk transfer instead of two .item() calls per span.
    span_rows = pred_trg_spans.tolist()

    results = []

    for bidx in range(B):
        # Map span -> label string without the BIO prefix.
        span2label = {
            (s, e): strip_bio(id2label['Trg'][2 * t_id + 1])
            for (s, e, t_id) in trg_spans[bidx]
        }

        for n, span in enumerate(span_rows[bidx]):
            s_trg, e_trg = span[0], span[1]

            # (0, 0) is the padding value of the span tensor.
            # NOTE(review): this also skips a real single-token trigger at
            # position 0 — confirm position 0 is never a trigger.
            if s_trg == 0 and e_trg == 0:
                continue

            # Only spans that the trigger tagger itself decoded are kept.
            if (s_trg, e_trg) not in span2label:
                continue

            trg_label = span2label[(s_trg, e_trg)]
            trg_tokens = input_ids[bidx, s_trg:e_trg+1].tolist()

            # ===== arguments of this trigger =====
            arg_ids = torch.argmax(arg_logits[bidx, n], dim=-1)  # (L,)
            arg_ids = fix_bio_ids(arg_ids.tolist())

            i = 0
            while i < L:
                tag = arg_ids[i]

                # O tags and stray I- tags do not open a span.
                if tag == 0 or tag % 2 != 1:
                    i += 1
                    continue

                arg_type_id = (tag - 1) // 2
                s_arg = e_arg = i
                i += 1

                # Extend over consecutive I- tags of the same type.
                while i < L:
                    next_tag = arg_ids[i]
                    if next_tag == 0 or next_tag % 2 != 0 or (next_tag - 1) // 2 != arg_type_id:
                        break
                    e_arg = i
                    i += 1

                # `tag` is the B- tag id that opened the span.
                arg_label = strip_bio(id2label['Arg'][tag])
                arg_tokens = input_ids[bidx, s_arg:e_arg+1].tolist()

                results.append((
                    bidx,
                    (tuple(trg_tokens), trg_label),
                    (tuple(arg_tokens), arg_label)
                ))

    return results
988
+
989
+ class Trainer:
990
+ def __init__(
991
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
992
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
993
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
994
+ ):
995
+ self.ema_net = None
996
+
997
+ self.training_time = self._time_str_to_seconds(training_time)
998
+ self.mode = eval_mode
999
+ self.topk = topk
1000
+ self.device = device
1001
+ self.logging = logging if logging < epochs else 1
1002
+ self.logging_file = logging_file
1003
+ self.checkpoints_dir = checkpoints_dir
1004
+ self.early_stopping = early_stopping
1005
+ self.eval_from_ratio = eval_from_ratio
1006
+ self.eval_every = eval_every
1007
+ self.save_name = save_name
1008
+ self.save_best = save_best
1009
+ self.save_last = save_last
1010
+ self.return_best = return_best
1011
+ self.return_last = return_last
1012
+ self.max_grad_norm = max_grad_norm
1013
+ self.schedule_in_step = schedule_in_step
1014
+ self.use_ema = use_ema
1015
+ self.ema_from_ratio = ema_from_ratio
1016
+ self.ema_decay = ema_decay
1017
+
1018
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
1019
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
1020
+
1021
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
1022
+ if eval_fn is None:
1023
+ if self.mode == "max":
1024
+ eval_fn = lambda *x: -loss_fn(*x)
1025
+ else:
1026
+ eval_fn = lambda *x: loss_fn(*x)
1027
+
1028
+ if torch.cuda.device_count() > 1:
1029
+ network = DataParallelProxy(network)
1030
+ network = network.to(self.device)
1031
+
1032
+ if not start_training_time:
1033
+ start_training_time = time.time()
1034
+
1035
+ start_ema = int(epochs * self.ema_from_ratio)
1036
+ start_eval = int(epochs * self.eval_from_ratio)
1037
+
1038
+ if val_loader is None:
1039
+ print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
1040
+ else:
1041
+ model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
1042
+ start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
1043
+ print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)
1044
+
1045
+ training_log = {}
1046
+ for epoch in range(start_epoch, epochs+start_epoch):
1047
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
1048
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1049
+
1050
+ try:
1051
+ teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
1052
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1053
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1054
+ logging_dict.update(train_loss_epoch_dict)
1055
+
1056
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
1057
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1058
+
1059
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
1060
+ update = self._update_best_network(eval_net, val_score, epoch)
1061
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
1062
+ logging_dict.update(val_score_dict)
1063
+ if not self.schedule_in_step and scheduler:
1064
+ scheduler.step()
1065
+
1066
+ except RuntimeError as e:
1067
+ if "out of memory" in str(e).lower():
1068
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
1069
+ torch.cuda.empty_cache()
1070
+ gc.collect()
1071
+ if torch.cuda.is_available():
1072
+ torch.cuda.synchronize()
1073
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
1074
+
1075
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
1076
+ if val_loader is not None:
1077
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
1078
+
1079
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
1080
+ else:
1081
+ raise
1082
+
1083
+ training_log[epoch] = logging_dict
1084
+ if self.is_early_stopping(epoch):
1085
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
1086
+ break
1087
+ if self.logging:
1088
+ if epoch % self.logging == 0:
1089
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
1090
+ else:
1091
+ print(f'{epoch}...', end=' ')
1092
+
1093
+ if self._at_time_limit(start_training_time):
1094
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
1095
+ break
1096
+
1097
+ if self.logging_file:
1098
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
1099
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
1100
+ f.write(json.dumps(training_log))
1101
+
1102
+ if self.use_ema and self.ema_net is not None:
1103
+ self._save_state_dict(self.ema_net.module)
1104
+ else:
1105
+ self._save_state_dict(network)
1106
+ print(f'[Trainer CallBack] 📢 Kết thúc training.\n')
1107
+
1108
+ best_model, last_model = None, None
1109
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1110
+ if self.return_best :
1111
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
1112
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
1113
+ if self.return_last:
1114
+ last_model = eval_net.state_dict()
1115
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
1116
+
1117
+ del network
1118
+ torch.cuda.empty_cache()
1119
+ gc.collect()
1120
+ return training_log, best_model, last_model
1121
+
1122
+ def _time_str_to_seconds(self, time_str):
1123
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
1124
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
1125
+
1126
+ def _update_best_network(self, network, val_score, epoch):
1127
+ topk = max(1, self.topk)
1128
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
1129
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
1130
+ if val_score in [x[0] for x in self.best_stage]:
1131
+ return True
1132
+ return False
1133
+
1134
+ def is_early_stopping(self, epoch):
1135
+ if self.best_stage[0][1] is None:
1136
+ return False
1137
+ if not self.early_stopping:
1138
+ return False
1139
+ return epoch - self.best_stage[0][1] >= self.early_stopping
1140
+
1141
+ def _at_time_limit(self, start_training_time):
1142
+ return time.time() - start_training_time >= self.training_time
1143
+
1144
+ def _save_state_dict(self, network):
1145
+ if self.topk <= 0:
1146
+ return
1147
+
1148
+ if self.save_best:
1149
+ for r in range(self.topk):
1150
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
1151
+
1152
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
1153
+ if state_dict is None:
1154
+ continue
1155
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
1156
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
1157
+ if self.save_last:
1158
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
1159
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
1160
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
1161
+
1162
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate):
1163
+ network.train()
1164
+ total_loss = 0
1165
+ total_loss_dict = {}
1166
+ for batch_idx, batch in enumerate(train_loader):
1167
+ optimizer.zero_grad()
1168
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
1169
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate)
1170
+
1171
+ for k, v in loss_dict.items():
1172
+ t = total_loss_dict.get(k, 0)
1173
+ total_loss_dict[k] = t + v
1174
+ self.grad_scaler.scale(loss).backward()
1175
+ self.grad_scaler.unscale_(optimizer)
1176
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
1177
+ # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu...
1178
+ self.grad_scaler.step(optimizer)
1179
+ self.grad_scaler.update()
1180
+ if self.schedule_in_step and scheduler:
1181
+ scheduler.step()
1182
+ if self.use_ema and self.ema_net is not None:
1183
+ self.ema_net.update(network)
1184
+ total_loss += loss
1185
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
1186
+
1187
+ def _eval_epoch(self, network, val_loader, eval_fn, id2label):
1188
+ network.eval()
1189
+ total_score = 0.0
1190
+ total_score_dict = {}
1191
+ object_lists = None # sẽ init sau
1192
+
1193
+ with torch.no_grad():
1194
+ for batch_idx, batch in enumerate(val_loader):
1195
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
1196
+ total_score += score
1197
+
1198
+ for k, v in score_dict.items():
1199
+ t = total_score_dict.get(k, 0)
1200
+ total_score_dict[k] = t + v
1201
+
1202
+ if objects:
1203
+ if object_lists is None:
1204
+ object_lists = [[] for _ in range(len(objects))]
1205
+
1206
+ for i, obj in enumerate(objects):
1207
+ object_lists[i].append(obj.detach())
1208
+
1209
+ if object_lists is not None:
1210
+ object_arrays = [
1211
+ torch.concat(obj_list, dim=0).cpu().numpy()
1212
+ for obj_list in object_lists
1213
+ ]
1214
+ else:
1215
+ object_arrays = []
1216
+
1217
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
1218
+
1219
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate):
1220
+ # Bạn cần override _cal_loss để tính loss
1221
+ input_ids = batch['input_ids'].to(self.device)
1222
+ attention_mask = batch['attention_mask'].to(self.device)
1223
+ trg_spans = batch['trg_spans'].to(self.device) # B, M, 2
1224
+ trg_labels = batch['trg_labels'].to(self.device) # B, L
1225
+ all_arg_labels = batch['all_arg_labels'].to(self.device) # B, M, L
1226
+
1227
+ hidden_states = network.encode(input_ids, attention_mask)
1228
+ trg_logits = network.get_trg_logits(hidden_states)
1229
+
1230
+ choice = random.random()
1231
+ if choice < teaching_rate:
1232
+ pred_trg_spans = trg_spans
1233
+ else:
1234
+ pred_trg_labels = torch.argmax(trg_logits, dim=-1)
1235
+ pred_trg_labels = fix_bio_ids_batch(pred_trg_labels)
1236
+ pred_trg_spans = extract_trigger_spans_batch_tensor(pred_trg_labels)
1237
+
1238
+ trg_repr = get_span_repr(hidden_states, pred_trg_spans) # B, N, 4H
1239
+
1240
+ trg_repr = network.trg_repr_proj(trg_repr) # B, N, H
1241
+ arg_logits = network.get_arg_logits(hidden_states, trg_repr)
1242
+ pred_arg_labels = map_arg_labels(all_arg_labels, trg_spans, pred_trg_spans)
1243
+
1244
+ loss_dict = loss_fn(
1245
+ trg_logits, trg_labels,
1246
+ arg_logits, pred_arg_labels,
1247
+ )
1248
+ return loss_dict['total'], loss_dict
1249
+
1250
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
1251
+ # Bạn cần override _cal_val_score để tính val score, list bên cạnh là để trả về y hay pred gì đó (nếu cần)
1252
+ input_ids = batch['input_ids'].to(self.device)
1253
+ attention_mask = batch['attention_mask'].to(self.device)
1254
+ gold_events = batch['gold_events']
1255
+
1256
+ B, _, _ = input_ids.shape
1257
+
1258
+ hidden_states = network.encode(input_ids, attention_mask)
1259
+ trg_logits = network.get_trg_logits(hidden_states)
1260
+
1261
+ pred_trg_labels = torch.argmax(trg_logits, dim=-1)
1262
+ pred_trg_labels = fix_bio_ids_batch(pred_trg_labels)
1263
+ pred_trg_spans = extract_trigger_spans_batch_tensor(pred_trg_labels)
1264
+ trg_repr = get_span_repr(hidden_states, pred_trg_spans) # B, N, 4H
1265
+
1266
+ trg_repr = network.trg_repr_proj(trg_repr) # B, N, H
1267
+ arg_logits = network.get_arg_logits(hidden_states, trg_repr)
1268
+
1269
+ pred_ids = extract_arguments(input_ids.reshape(B, -1), trg_logits, arg_logits, pred_trg_spans, id2label)
1270
+ pred_ids = list_to_tuple(pred_ids)
1271
+
1272
+ gold_ids = list_to_tuple(gold_events)
1273
+
1274
+ score_dict = eval_fn(pred_ids, gold_ids)
1275
+ return score_dict['f1'], score_dict, []
1276
+
1277
+ # %% [code]
1278
class PhoBERTSpanAligner:
    """Align character-level event annotations with tokenizer sub-word positions."""

    def __init__(self, tokenizer, max_len):
        """Keep the tokenizer and the maximum sequence length (in sub-words)."""
        self.tokenizer = tokenizer
        self.max_len = max_len

    # ===== 1. Extract discontinuous spans =====
    def extract_spans(self, sample):
        """Collect trigger spans and, per event, its argument spans from ``sample["events"]``.

        Returns:
            ``(trigger_spans, arg_spans)``: the first is a list of
            ``{"spans": [(start, end)], "label": ...}``, the second a parallel
            list holding one list of such dicts per event.
        """
        trigger_spans, arg_spans = [], []

        for event in sample["events"]:
            trigger_type = event["label"]
            spans = [tuple(event["offset"])]
            trigger_spans.append({
                "spans": spans,
                "label": trigger_type
            })
            event_arg_spans = []
            for arg in event['arguments']:
                arg_type = arg["role"]
                spans = [tuple(arg["offset"])]
                event_arg_spans.append({
                    "spans": spans,
                    "label": arg_type
                })
            arg_spans.append(event_arg_spans)

        return trigger_spans, arg_spans

    # ===== 2. Word offsets =====
    def build_word_offsets(self, text, words):
        """Locate every word of ``words`` in ``text``, scanning left to right.

        Returns:
            One ``(start, end)`` character range per word (``end`` exclusive).
            A word that cannot be found gets ``(-1, -1)``, which no character
            position can match, and the scan position is left untouched so
            the following words are still located correctly.
        """
        offsets = []
        pointer = 0

        for word in words:
            start = text.find(word, pointer)
            if start < 0:
                # str.find signals "not found" with -1; using it as an offset would
                # create a bogus range and move the pointer backwards.
                offsets.append((-1, -1))
                continue
            end = start + len(word)
            offsets.append((start, end))
            pointer = end

        return offsets

    # ===== 3. Char → word =====
    def char_span_to_word_span(self, word_offsets, start, end):
        """Map a character range to ``(first_word, last_word)`` indices; either may be ``None``."""
        start_word = None
        end_word = None

        for i, (w_start, w_end) in enumerate(word_offsets):
            if w_start <= start < w_end:
                start_word = i
            if w_start < end <= w_end:
                end_word = i

        return start_word, end_word

    # ===== 4. Word → subword =====
    def word_to_subword_map(self, words):
        """Return, per word, the inclusive ``(first, last)`` sub-word positions in the encoded sentence."""
        mapping = []
        subword_index = 1  # position 0 is taken by <s>

        for word in words:
            sub_tokens = self.tokenizer.tokenize(word)
            start = subword_index
            end = subword_index + len(sub_tokens) - 1
            mapping.append((start, end))
            subword_index += len(sub_tokens)

        return mapping

    # ===== 5. Span → subword =====
    def span_to_subword(self, word_offsets, word_subword_map, spans):
        """Convert character spans to sub-word spans, dropping spans that hit no word boundary."""
        sub_spans = []

        for span_start, span_end in spans:
            w_start, w_end = self.char_span_to_word_span(
                word_offsets, span_start, span_end
            )
            if w_start is None or w_end is None:
                continue

            sub_start = word_subword_map[w_start][0]
            sub_end = word_subword_map[w_end][1]
            sub_spans.append((sub_start, sub_end))

        return sub_spans

    def extract_valid_spans(self, sub_spans):
        """Keep only spans that are ordered and lie inside ``[0, max_len)``."""
        valid_spans = []
        for s, e in sub_spans:
            if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
                continue
            valid_spans.append((s, e))
        return valid_spans

    def encode(self, sample):
        """Tokenize ``sample["text"]`` and project its event annotations onto sub-words.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (padded/truncated to
            ``max_len``), ``triggers_gold_spans`` (list of ``(spans, label)``) and
            ``arguments_gold_spans`` (one tuple of ``(spans, label)`` per kept trigger).
        """
        text = sample["text"]
        triggers, arguments = self.extract_spans(sample)

        # ===== 1. Word tokenize =====
        words = word_tokenize(text)
        sentence = " ".join(words)

        # ===== 2. Mapping =====
        word_offsets = self.build_word_offsets(text, words)
        word_subword_map = self.word_to_subword_map(words)

        # ===== 3. Tokenize FULL =====
        encoding = self.tokenizer(
            sentence,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"][0]
        attention_mask = encoding["attention_mask"][0]

        # ===== 4. Convert spans =====
        triggers_gold_spans = []
        arguments_gold_spans = []

        for trg, args in zip(triggers, arguments):
            label = trg["label"]

            sub_spans = self.span_to_subword(
                word_offsets,
                word_subword_map,
                trg["spans"]
            )
            valid_spans = self.extract_valid_spans(sub_spans)
            if len(valid_spans) == 0:
                # A trigger that cannot be aligned is dropped together with its arguments.
                continue
            triggers_gold_spans.append((tuple(valid_spans), label))

            trg_args_gold_spans = []
            for arg in args:
                label = arg["label"]

                sub_spans = self.span_to_subword(
                    word_offsets,
                    word_subword_map,
                    arg["spans"]
                )
                valid_spans = self.extract_valid_spans(sub_spans)
                if len(valid_spans) == 0:
                    continue
                trg_args_gold_spans.append((tuple(valid_spans), label))
            arguments_gold_spans.append(tuple(trg_args_gold_spans))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "triggers_gold_spans": triggers_gold_spans,
            "arguments_gold_spans": arguments_gold_spans,
        }
1432
+
1433
def generate_candidate_spans(seq_len, max_span_len):
    """Enumerate ``(start, end)`` pairs with ``1 <= start <= end <= seq_len``.

    ``end`` is additionally limited to ``start + max_span_len - 1``. Pairs are
    produced ordered by start, then by end.
    """
    upper = seq_len + 1
    return [
        (begin, finish)
        for begin in range(1, upper)
        for finish in range(begin, min(begin + max_span_len, upper))
    ]
1439
+
1440
class KLTNDataset(Dataset):
    """Dataset producing BIO label tensors for triggers and, per trigger, for its arguments."""

    def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
        """
        Args:
            all_data: list of raw samples (dicts consumed by ``PhoBERTSpanAligner.encode``).
            using_idxes: indices into ``all_data`` that make up this split.
            label2id: ``{'Trg': {...}, 'Arg': {...}}`` maps from BIO tag to id.
            tokenizer: tokenizer handed to the aligner.
            max_len: length of one part (chunk) of the sequence.
            max_n_parts: maximum number of parts; the aligner encodes ``max_len * max_n_parts`` positions.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)
        self.all_data = all_data
        self.using_idxes = using_idxes
        self.label2id = label2id
        self.max_len = max_len
        self.max_n_parts = max_n_parts

    def __len__(self):
        return len(self.using_idxes)

    def __getitem__(self, idx):
        """Encode one sample.

        Returns a dict with:
            input_ids / attention_mask: ``(n_valid_parts, max_len)``
            trg_spans: ``(M, 2)`` first sub-word span of each of the M triggers
            trg_labels: ``(n_valid_parts * max_len,)`` BIO ids, -100 on padding
            all_arg_labels: ``(M, n_valid_parts * max_len)`` BIO ids per trigger
            gold_events: list of ``[(trigger_token_ids, label), (arg_token_ids, role), ...]``
        """
        ridx = self.using_idxes[idx]
        sample = self.all_data[ridx]
        result = self.aligner.encode(sample)

        input_ids = result["input_ids"].squeeze(0)
        attention_mask = result["attention_mask"].squeeze(0)
        triggers_gold_spans = result["triggers_gold_spans"]
        arguments_gold_spans = result["arguments_gold_spans"]

        # Get event label
        all_trg_spans = torch.tensor([list(trg_spans[0]) for trg_spans, _ in triggers_gold_spans], dtype=torch.long) if triggers_gold_spans else torch.empty(0, 2, dtype=torch.long)
        gold_events = []
        # 0 on real tokens, -100 on padded positions
        trg_labels = torch.ones(input_ids.size(0), dtype=torch.long) * (-100) * (1 - attention_mask)
        all_arg_labels = []
        for (trg_spans, trg_label), args in zip(triggers_gold_spans, arguments_gold_spans):
            s, e = trg_spans[0]

            trg_labels[s] = self.label2id['Trg'][f'B-{trg_label}']
            trg_labels[s+1:e+1] = self.label2id['Trg'][f'I-{trg_label}']

            event = [(tuple(input_ids[s:e+1].tolist()), trg_label)]

            arg_labels = torch.ones(input_ids.size(0), dtype=torch.long) * (-100) * (1 - attention_mask)
            for arg_spans, arg_label in args:
                s, e = arg_spans[0]

                arg_labels[s] = self.label2id['Arg'][f'B-{arg_label}']
                arg_labels[s+1:e+1] = self.label2id['Arg'][f'I-{arg_label}']

                event.append((tuple(input_ids[s:e+1].tolist()), arg_label))
            all_arg_labels.append(arg_labels)

            gold_events.append(event)

        input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
        attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)

        # Drop the parts that contain padding only.
        n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
        valid_len = n_valid_parts * self.max_len
        input_ids = input_ids[:n_valid_parts]
        attention_mask = attention_mask[:n_valid_parts]
        trg_labels = trg_labels[:valid_len]
        if all_arg_labels:
            all_arg_labels = torch.stack([arg_labels[:valid_len] for arg_labels in all_arg_labels], dim=0)
        else:
            # Must be long like every other label tensor, otherwise a batch mixing
            # event-less and regular samples ends up with mixed dtypes when collated.
            all_arg_labels = torch.empty(0, valid_len, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "trg_spans": all_trg_spans,
            "trg_labels": trg_labels,
            "all_arg_labels": all_arg_labels,
            "gold_events": gold_events,
        }
1506
+
1507
+ def _pad_batch(tensor_list, pad_value=0):
1508
+ """
1509
+ tensor_list: list of tensors
1510
+ mỗi tensor shape: (Nk, n_parts_i, max_len_i)
1511
+
1512
+ return:
1513
+ padded tensor shape: (B, max_Nk, max_n_parts, max_len)
1514
+ """
1515
+
1516
+ # lấy max toàn batch
1517
+ max_Nk = max(t.size(0) for t in tensor_list)
1518
+ max_n_parts = max(t.size(1) for t in tensor_list)
1519
+ max_len = max(t.size(2) for t in tensor_list)
1520
+
1521
+ padded = []
1522
+
1523
+ for t in tensor_list:
1524
+ Nk, n_parts_i, max_len_i = t.shape
1525
+
1526
+ # pad chiều n_parts và max_len trước
1527
+ if n_parts_i < max_n_parts or max_len_i < max_len:
1528
+ new_t = t.new_full(
1529
+ (Nk, max_n_parts, max_len),
1530
+ pad_value
1531
+ )
1532
+ new_t[:, :n_parts_i, :max_len_i] = t
1533
+ t = new_t
1534
+
1535
+ # pad chiều Nk
1536
+ if Nk < max_Nk:
1537
+ pad_tensor = t.new_full(
1538
+ (max_Nk - Nk, max_n_parts, max_len),
1539
+ pad_value
1540
+ )
1541
+ t = torch.cat([t, pad_tensor], dim=0)
1542
+
1543
+ padded.append(t)
1544
+
1545
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
1546
+
1547
def collate_fn(batch):
    """Merge dataset items into one padded batch.

    Tensors are padded with ``_pad_batch`` (labels with -100, everything else
    with 0). ``gold_events`` is flattened into ``[sample_index, trigger, argument]``
    triples; a trigger without arguments yields one triple whose argument is
    the placeholder ``((), 0)``.
    """
    gold_events = []
    for sample_pos, item in enumerate(batch):
        for event in item['gold_events']:
            trigger, arguments = event[0], event[1:]
            if arguments:
                gold_events.extend([sample_pos, trigger, argument] for argument in arguments)
            else:
                gold_events.append([sample_pos, trigger, (tuple([]), 0)])

    def _padded(key, pad_value, extra_dims):
        # _pad_batch expects 3-D tensors: append trailing singleton axes, pad, then drop them again.
        expanded = []
        for item in batch:
            tensor = item[key]
            for _ in range(extra_dims):
                tensor = tensor.unsqueeze(-1)
            expanded.append(tensor)
        merged = _pad_batch(expanded, pad_value=pad_value)
        for _ in range(extra_dims):
            merged = merged.squeeze(-1)
        return merged

    return {
        "input_ids": _padded("input_ids", 0, 1),
        "attention_mask": _padded("attention_mask", 0, 1),
        "trg_spans": _padded("trg_spans", 0, 1),
        "trg_labels": _padded("trg_labels", -100, 2),
        "all_arg_labels": _padded("all_arg_labels", -100, 1),
        "gold_events": gold_events,
    }
1579
+
1580
+ # %% [code]
1581
def shift_bidx(spans, batch_idx, size=None):
    """Turn within-batch sample indices into dataset-wide indices.

    Args:
        spans: iterable of ``(bidx, trg, arg)`` triples, ``bidx`` being the
            position of the sample inside its batch.
        batch_idx: index of the batch the triples come from.
        size: number of samples per batch. Defaults to the module-level
            ``batch_size``; pass it explicitly when the loader uses another value.

    Returns:
        list of ``(bidx + batch_idx * size, trg, arg)`` tuples.
    """
    step = batch_size if size is None else size
    offset = batch_idx * step
    return [(bidx + offset, trg, arg) for bidx, trg, arg in spans]
1587
+
1588
def refactor_events(events, save_dict):
    """Derive the six evaluation views from flat event triples and append them to ``save_dict``.

    Args:
        events: iterable of ``(bidx, (trg_ids, trg_label), (arg_ids, arg_label))``.
            All components must be hashable.
        save_dict: dict with list values under the keys ``'Trg-I'``, ``'Trg-C'``,
            ``'Arg-I'``, ``'Arg-C'``, ``'Soft-Event'`` and ``'Strict-Event'``;
            extended in place.

    Each view is de-duplicated and keeps first-seen order. ``'Strict-Event'``
    holds one ``(bidx, trigger, frozenset(arguments))`` entry per trigger.
    """
    # dicts serve as insertion-ordered sets: O(1) membership instead of scanning a list per event
    trg_i, trg_c, arg_i, arg_c, soft = {}, {}, {}, {}, {}
    strict_dict = {}
    for bidx, (trg_ids, trg_lb), (arg_k_ids, arg_k_lb) in events:
        trigger = (trg_ids, trg_lb)
        argument = (arg_k_ids, arg_k_lb)

        trg_i.setdefault((bidx, trg_ids))
        trg_c.setdefault((bidx, trigger))
        arg_i.setdefault((bidx, trg_ids, arg_k_ids))
        arg_c.setdefault((bidx, trg_ids, argument))
        soft.setdefault((bidx, trigger, argument))

        strict_dict.setdefault(bidx, {}).setdefault(trigger, []).append(argument)

    strict = [
        (bidx, trg, frozenset(args))
        for bidx, trg_dict in strict_dict.items()
        for trg, args in trg_dict.items()
    ]

    save_dict['Trg-I'].extend(trg_i)
    save_dict['Trg-C'].extend(trg_c)
    save_dict['Arg-I'].extend(arg_i)
    save_dict['Arg-C'].extend(arg_c)
    save_dict['Soft-Event'].extend(soft)
    save_dict['Strict-Event'].extend(strict)
1623
+
1624
def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
    """Evaluate an ensemble of checkpoints on ``test_loader``.

    For every batch each state dict is loaded into ``network`` in turn; trigger
    logits are averaged over the checkpoints, trigger spans are decoded from the
    average, and argument logits (computed per checkpoint for those spans) are
    averaged as well.

    Returns:
        ``(final_score, analyze_result, predictions)`` where ``final_score`` maps
        each evaluation type to the ``eval_fn`` result, ``analyze_result`` is the
        analyzer output on the 'Trg-I' view, and ``predictions`` holds, per input,
        the decoded text followed by decoded ``(span_text, label)`` pairs.
    """
    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        network = DataParallelProxy(network)
    network = network.to(device)
    network.eval()

    def _load_weights(state_dict):
        # With the multi-GPU proxy the weights live on the wrapped module.
        holder = network.module if multi_gpu else network
        holder.load_state_dict(state_dict)

    eval_types = ['Trg-I', 'Trg-C', 'Arg-I', 'Arg-C', 'Soft-Event', 'Strict-Event']

    all_pred = {eval_type: [] for eval_type in eval_types}
    all_gold = {eval_type: [] for eval_type in eval_types}

    list_input_ids = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gold_events = batch['gold_events']

            B, _, _ = input_ids.shape
            flat_input_ids = input_ids.reshape(B, -1)
            list_input_ids.extend(flat_input_ids.tolist())

            # ---- pass 1: trigger logits of every checkpoint ----
            per_model_trg_logits = []
            per_model_hidden = []
            for sd in state_dicts:
                _load_weights(sd)
                hidden_states = network.encode(input_ids, attention_mask)
                per_model_trg_logits.append(network.get_trg_logits(hidden_states))
                per_model_hidden.append(hidden_states)

            ensemble_trg_logits = torch.stack(per_model_trg_logits, dim=0).mean(dim=0)
            pred_trg_labels = fix_bio_ids_batch(torch.argmax(ensemble_trg_logits, dim=-1))
            pred_trg_spans = extract_trigger_spans_batch_tensor(pred_trg_labels)

            # ---- pass 2: argument logits of every checkpoint for the shared trigger spans ----
            per_model_arg_logits = []
            for sd, hidden_states in zip(state_dicts, per_model_hidden):
                _load_weights(sd)
                trg_repr = get_span_repr(hidden_states, pred_trg_spans)  # B, N, 4H
                trg_repr = network.trg_repr_proj(trg_repr)  # B, N, H
                per_model_arg_logits.append(network.get_arg_logits(hidden_states, trg_repr))

            ensemble_arg_logits = torch.stack(per_model_arg_logits, dim=0).mean(dim=0)

            pred_events = extract_arguments(flat_input_ids, ensemble_trg_logits, ensemble_arg_logits, pred_trg_spans, id2label)
            refactor_events(shift_bidx(pred_events, batch_idx), all_pred)
            refactor_events(shift_bidx(gold_events, batch_idx), all_gold)

    # ===== GLOBAL EVAL =====
    final_score = {}
    for eval_type in eval_types:
        final_score[eval_type] = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))

    analyze_result = analyzer.analyze(list_to_tuple(all_pred['Trg-I']), list_to_tuple(all_gold['Trg-I']))

    # ===== PREDICT =====
    decode_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    predictions = [[tokenizer.decode(token_ids, **decode_kwargs)] for token_ids in list_input_ids]
    for bidx, (trg_ids, trg_lb), arg_set in all_pred['Strict-Event']:
        predictions[bidx].append((tokenizer.decode(trg_ids, **decode_kwargs), trg_lb))

        for arg_ids, arg_lb in arg_set:
            predictions[bidx].append((tokenizer.decode(arg_ids, **decode_kwargs), arg_lb))

    return final_score, analyze_result, predictions
1711
+
1712
+ # %% [code]
1713
# Load the raw train / test annotations.
with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
    data_train = json.load(f)

with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
    data_test = json.load(f)

print('Train:', len(data_train))
print('Test:', len(data_test))

# %% [code]
# BIO tag inventories: id 0 is 'O', then a B-/I- pair per type in sorted order.
trigger_types = sorted({e['label'] for d in data_train + data_test for e in d['events']}) # NBR : Neighbor relation
bio_trigger_types = ['O'] + [f'{prefix}-{trg}' for trg in trigger_types for prefix in 'BI']
trigger_label2id = {label: idx for idx, label in enumerate(bio_trigger_types)}
trigger_id2label = dict(enumerate(bio_trigger_types))

argument_types = sorted({a['role'] for d in data_train + data_test for e in d['events'] for a in e['arguments']})
bio_argument_types = ['O'] + [f'{prefix}-{arg}' for arg in argument_types for prefix in 'BI']
argument_label2id = {label: idx for idx, label in enumerate(bio_argument_types)}
argument_id2label = dict(enumerate(bio_argument_types))

label2id = {'Trg': trigger_label2id, 'Arg': argument_label2id}

id2label = {'Trg': trigger_id2label, 'Arg': argument_id2label}
1742
+
1743
+ # %% [code]
1744
# Down-sample training samples without events: keep at most
# `zero_events_rate` times the number of samples that do have events.
zero_events_idxes = [idx for idx, d in enumerate(data_train) if len(d['events']) == 0]

n_zero_events_samples = len(zero_events_idxes)
n_has_events_samples = len(data_train) - n_zero_events_samples

random.seed(42)  # fixed seed so the kept subset is reproducible
k = min(int(n_has_events_samples * zero_events_rate), len(zero_events_idxes))
sampled_zero_events_idxes = random.sample(zero_events_idxes, k)

# Set membership keeps the filter linear in the dataset size.
_kept_zero_events = set(sampled_zero_events_idxes)
new_data_train = [
    d for idx, d in enumerate(data_train)
    if len(d['events']) != 0 or idx in _kept_zero_events
]
data_train = new_data_train

print('Train:', len(data_train))
1766
+
1767
+ # %% [code]
1768
# In debug mode work on the first 20 samples of each split only.
if debug_only:
    data_train, data_test = data_train[:20], data_test[:20]

print('Train:', len(data_train))
print('Test:', len(data_test))

# %% [code]
tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)

# %% [code]
print('Experiment name:', state_dict_save_name)
1780
+
1781
+ # %% [code]
1782
if not test_only:
    # Cross-validated training: one model per (seed, fold) pair, all sharing one wall-clock start.
    full_idxes = np.array(range(len(data_train)))
    training_logs, best_models, last_models = [], [], []
    start_training_time = time.time()
    for seed in SEEDS:
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
        for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
            # A non-negative only_fold_idx restricts the run to that single fold.
            if not (only_fold_idx is None or only_fold_idx < 0 or only_fold_idx == fold_idx):
                continue
            set_seed(seed)

            train_idxes = full_idxes[tr_idx]
            val_idxes = full_idxes[va_idx]

            trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
            valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)

            generator = torch.Generator()
            generator.manual_seed(seed)
            train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
            val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)

            my_model = IEModel(
                num_trg_labels=len(trigger_label2id),
                num_arg_labels=len(argument_label2id),
                **model_params
            )
            total_params = sum(p.numel() for p in my_model.parameters())
            print(f"Total params: {total_params:,}")

            # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
            # Two parameter groups: the encoder gets its own (smaller) learning rate,
            # everything else uses the optimizer default.
            encoder_params = {id(p) for p in my_model.encoder.parameters()}
            other_params = [p for p in my_model.parameters() if id(p) not in encoder_params]
            optimizer = optim.AdamW(
                [
                    {"params": my_model.encoder.parameters(), "lr": 2e-5},
                    {"params": other_params},
                ],
                lr=5e-4,
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)

            loss_fn = CustomLoss(**loss_func_params)
            eval_fn = CustomEvalFn(**eval_func_params)
            trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
            trainer = Trainer(**trainer_params)

            print(f'Start Training Fold {fold_idx}...')
            training_log, best_model, last_model = trainer.fit(
                my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
                start_epoch=1, start_training_time=start_training_time, id2label=id2label
            )

            training_logs.append(training_log)
            best_models.append(best_model)
            last_models.append(last_model)
1839
+
1840
+ # %% [code]
1841
def load_all_state_dicts(folder):
    """Load every ``.pt`` / ``.pth`` checkpoint found in ``folder``, ordered by fold.

    Checkpoints written by the training cell are named
    ``<experiment>_s<seed>_f<fold>...``, so the fold is taken from the last
    ``_s<digits>_f<digits>`` tag in the file name. Files without such a tag
    fall back to the first ``f<digits>`` occurrence (the previous behaviour).
    Files with no fold number at all are ignored.

    Args:
        folder: Directory that contains the checkpoint files.

    Returns:
        list: The objects returned by ``torch.load`` (mapped to CPU), sorted
        by ``(fold, seed, file name)`` so the order does not depend on the
        order in which ``os.listdir`` reports the files.

    Raises:
        OSError: If ``folder`` cannot be listed (propagated from ``os.listdir``).
    """
    candidates = []

    for file in os.listdir(folder):
        if not file.endswith((".pt", ".pth")):
            continue

        # Prefer the explicit seed/fold tag: a bare "f<number>" search can hit
        # digits inside the experiment name (e.g. "conf1_...").
        tagged = re.findall(r"_s(\d+)_f(\d+)", file)
        if tagged:
            seed, fold = (int(part) for part in tagged[-1])
        else:
            m = re.search(r"f(\d+)", file)  # legacy naming: first f<number>
            if m is None:
                continue
            seed, fold = -1, int(m.group(1))

        candidates.append((fold, seed, file))

    # Sort by fold, then seed, then name for a deterministic order.
    candidates.sort()

    state_dicts = []
    for fold, _seed, file in candidates:
        path = os.path.join(folder, file)
        print(f"Loading fold {fold}: {file}")
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from a trusted source; consider weights_only=True where the installed
        # torch version supports it.
        state_dict = torch.load(path, map_location="cpu")
        state_dicts.append(state_dict)

    return state_dicts
1862
+
1863
# Test-only mode: fetch this experiment's folder from the Hub repo and load
# the saved checkpoints instead of training.
if test_only:
    snapshot_download(
        repo_id=repo_name,
        local_dir="",
        repo_type="model",
        allow_patterns=[f"{state_dict_save_name}/**"],
    )
    # Remove the download bookkeeping left in the working directory.
    get_ipython().system('rm -rf .cache .gitattributes')

    # "r1s" and "lasts" are the sub-folders read for best and last checkpoints.
    best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
    last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
1869
+
1870
+ # %% [code]
1871
# Output folder for the test metrics, error analysis and predictions.
os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)

# The test set uses every test example with the validation-time dataset settings.
testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
generator = torch.Generator()
test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)

eval_fn = CustomEvalFn(**eval_func_params)
analyzer = SpanErrorAnalyzer()

# Model instance handed to `test(...)` together with the loaded checkpoints.
my_model = IEModel(
    num_trg_labels=len(trigger_label2id),
    num_arg_labels=len(argument_label2id),
    **model_params
)
total_params = sum(p.numel() for p in my_model.parameters())
print(f"Total params: {total_params:,}")
1884
+
1885
+ # %% [code]
1886
start_time = time.time()

# Evaluate the test set twice: once with the best checkpoints, once with the last ones.
best_score, best_analyze_result, best_pred_test = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
last_score, last_analyze_result, last_pred_test = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)

# Group every artifact under the same two keys.
result_test = {"Best model": best_score, "Last model": last_score}
analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
analyze_result_sumary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']}
pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}

# (destination, payload) pairs, written in this order.
results_to_dump = (
    (f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", result_test),
    (f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", analyze_result),
    (f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", pred_test),
)
for out_path, payload in results_to_dump:
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

print('Test:', time.time() - start_time, 's --> Done!')
print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))

# %% [code]
# Preview the first predictions of the best checkpoints.
best_pred_test[:10]

# %% [code]
# Preview the first predictions of the last checkpoints.
last_pred_test[:10]
1913
+
1914
+ # %% [code]
1915
def dict_to_df(data):
    """Turn ``{model: {evaluation: {metric: value}}}`` into a metrics table.

    Rows are the evaluation keys of the first model in ``data``; columns are a
    two-level index ``(model name, metric)`` with the metrics ``precision``,
    ``recall`` and ``f1``. Rows are sorted in descending order by the "f1"
    column of "Best model" and then of "Last model", for whichever of those
    columns exist.

    Args:
        data: Mapping of model name to a mapping of evaluation key to a dict
            containing "precision", "recall" and "f1".

    Returns:
        pandas.DataFrame indexed by "evaluation".
    """
    metric_names = ("precision", "recall", "f1")

    # The first model defines which evaluation keys become rows.
    reference_model = next(iter(data.values()))
    eval_names = list(reference_model.keys())

    # One record per evaluation key, keyed by (model, metric).
    records = [
        {
            (model_name, metric): model_scores[eval_name][metric]
            for model_name, model_scores in data.items()
            for metric in metric_names
        }
        for eval_name in eval_names
    ]

    frame = pd.DataFrame(records)
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)
    frame.index = pd.Index(eval_names, name="evaluation")

    # Sort only by the f1 columns that are actually present.
    wanted = (("Best model", "f1"), ("Last model", "f1"))
    sort_columns = [column for column in wanted if column in frame.columns]
    if sort_columns:
        frame = frame.sort_values(by=sort_columns, ascending=False)

    return frame
1957
+
1958
# Tabulate the test metrics and save them next to the JSON results.
result_test_df = dict_to_df(result_test)
result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
result_test_df

# %% [code]
# Keep, for every evaluation key, the row with the highest "Best model" f1.
key = ("Best model", "f1")
ranked_by_best_f1 = result_test_df.sort_values(by=key, ascending=False)
result_test_df_best = ranked_by_best_f1.groupby(level="evaluation").head(1)
result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
result_test_df_best
1967
+
1968
+ # %% [code]
1969
def get_avg_best_score(logs):
    """Average the final ``best_score`` over several training logs.

    For each log the entries are scanned from the end and the first one that
    carries a ``best_score`` is used, so a log whose last epoch was not
    evaluated no longer raises ``KeyError``. Logs without any ``best_score``
    are left out of the average.

    Args:
        logs: Iterable of per-run logs, each a dict mapping epoch to a dict of
            logged values (entries are assumed to be stored in epoch order).

    Returns:
        float: Mean of the selected scores, or ``nan`` when no log provides one.
    """
    final_scores = []
    for log in logs:
        for entry in reversed(list(log.values())):
            if 'best_score' in entry:
                final_scores.append(entry['best_score'])
                break

    if not final_scores:
        # Same value np.mean would give for an empty list, without the warning.
        return float('nan')
    return float(np.mean(final_scores))
1971
+
1972
def get_avg_log(logs, epochs):
    """Average ``train_loss`` and ``val_score`` per epoch across several logs.

    Each metric is averaged only over the logs that actually recorded it for
    that epoch. Previously the ``val_score`` sum was divided by the number of
    logs containing the epoch, and a genuine average of 0.0 was reported as
    ``inf``.

    Args:
        logs: Sequence of per-run logs; each maps an epoch (``int`` or its
            ``str`` form, as produced by JSON) to a dict of logged values.
        epochs: Highest epoch number to look for; epochs ``1..epochs`` are scanned.

    Returns:
        dict: ``{epoch: {'train_loss': float, 'val_score': float}}`` for every
        epoch present in at least one log. ``val_score`` is ``inf`` when no log
        recorded a validation score for that epoch; ``train_loss`` is ``0.0``
        when no log recorded a loss.
    """
    avg_log = {}

    for epoch in range(1, epochs + 1):
        loss_sum, loss_count = 0.0, 0
        score_sum, score_count = 0.0, 0
        epoch_seen = False

        for run_log in logs:
            # Logs reloaded from JSON have string keys; in-memory logs may use ints.
            entry = run_log.get(epoch, run_log.get(str(epoch)))
            if entry is None:
                continue
            epoch_seen = True

            if 'train_loss' in entry:
                loss_sum += entry['train_loss']
                loss_count += 1
            if 'val_score' in entry:
                score_sum += entry['val_score']
                score_count += 1

        if not epoch_seen:
            continue

        avg_log[epoch] = {
            'train_loss': loss_sum / loss_count if loss_count else 0.0,
            # inf marks "no validation ran for this epoch".
            'val_score': score_sum / score_count if score_count else float('inf'),
        }

    return avg_log
1998
+
1999
def parse_label_key(label: str):
    """Build a sort key from the leading and trailing numbers of a label.

    For a label such as ``"5_score_span_base_12"`` the key is ``(5.0, 12.0)``:
    the number before the first underscore and the number after the last one.

    Args:
        label: Label text, expected to look like ``<number>_..._<number>``.

    Returns:
        tuple: ``(first, last)`` as floats, or ``(0, 0)`` when the label does
        not follow that pattern or is not a string.
    """
    try:
        first = float(label.split('_', 1)[0])  # leading number: text before the first '_'
        last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
    except (AttributeError, TypeError, ValueError, IndexError):
        # Only parsing failures fall back to the neutral key; a bare `except`
        # would also swallow KeyboardInterrupt / SystemExit.
        return (0, 0)
    return first, last
2006
+
2007
def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
    """Plot training loss and validation score curves for several experiments.

    Draws two side-by-side panels (train loss on the left, validation score on
    the right) with one line per entry of ``logs_dict``. A legend ordered by
    ``parse_label_key`` is attached to the left panel. The figure is optionally
    saved and then shown.

    Args:
        logs_dict: Mapping of experiment name to a log of the form
            ``{epoch: {'train_loss': ..., 'val_score': ...}}``.
        save_path: Destination image path, or ``None`` to skip saving. Missing
            parent directories are created.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        None
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # (axis, metric key, y-axis label, title) for the two panels.
    panels = (
        (axes[0], 'train_loss', 'Train Loss', 'Training Loss'),
        (axes[1], 'val_score', 'Validation Score', 'Validation Score'),
    )
    for ax, metric, y_label, title in panels:
        for name, log in logs_dict.items():
            epochs = sorted(log.keys())
            ax.plot(epochs, [log[e][metric] for e in epochs], label=name)

        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True)

    # ===== Shared Legend =====
    handles, labels = axes[0].get_legend_handles_labels()
    # With no curves there is nothing to sort; unpacking zip(*[]) would raise
    # ValueError, so the legend is simply omitted in that case.
    if handles:
        pairs_sorted = sorted(
            zip(handles, labels),
            key=lambda pair: parse_label_key(pair[1])
        )
        handles_sorted, labels_sorted = zip(*pairs_sorted)

        axes[0].legend(
            handles_sorted,
            labels_sorted,
            loc='center left',
            bbox_to_anchor=(1.01, 0.5),
            borderaxespad=0.
        )

    plt.tight_layout(rect=[0, 0, 1, 1])

    if save_path is not None:
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
2056
+
2057
+ # %% [code]
2058
# After training, pull the JSON files of the experiments stored in the Hub
# repo (everything except the excluded folder) into the working directory.
if not test_only:
    snapshot_download(
        repo_id=repo_name,
        local_dir="",
        repo_type="model",
        allow_patterns=["**/*.json"],
        ignore_patterns=["5_score_span_base_12/**"],
    )
    # Remove the download bookkeeping left in the working directory.
    get_ipython().system('rm -rf .cache .gitattributes')
2061
+
2062
+ # %% [code]
2063
# Build {experiment name: per-epoch averaged log} for every experiment folder
# found in `pretrained_dir`, plus the run that was just trained.
if not test_only:
    experiments = {}
    for experiment in os.listdir(pretrained_dir):
        try:
            experiment_logs = []
            for seed in SEEDS:
                for fold_idx in range(nfolds):
                    with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
                        experiment_log = json.load(f)
                    experiment_logs.append(experiment_log)
            experiments[experiment] = get_avg_log(experiment_logs, 1000)
        except Exception as exc:
            # Best effort: entries that are not complete experiment folders are
            # skipped. `Exception` (instead of a bare `except`) lets
            # KeyboardInterrupt / SystemExit propagate, and the reason is
            # reported rather than silently discarded.
            print(f"Skipping '{experiment}': {type(exc).__name__}: {exc}")
            continue

    # The current run always comes from the in-memory training logs.
    experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
2077
+
2078
+ # %% [code]
2079
# Average best validation score over all trained folds/seeds; there are no
# training logs in test-only mode, so the score is None there.
# NOTE(review): the original indentation of the display line was lost in the
# source; this form shows the pair as the cell result after training and
# avoids a NameError on `score` in test-only mode.
score = get_avg_best_score(training_logs) if not test_only else None
state_dict_save_name, score

# %% [code]
# Plot the averaged curves of every collected experiment and save the figure.
if not test_only:
    plot_training_logs(
        experiments,
        save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg',
        figsize=(18, 7.5),
    )
2086
+
0_token_base_issue_1/lasts/0_token_base_issue_1_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d051cd5cbf2554478ab29da795bedc0a905452a4ba2288d74349529a643653b
3
+ size 559060868
0_token_base_issue_1/logs/0_token_base_issue_1_log_plot.jpg ADDED

Git LFS Details

  • SHA256: 876dedb7c6f6d40c5878d4ffa0af2e537a6a992d05bc17cd8f960dab44734763
  • Pointer size: 131 Bytes
  • Size of remote file: 478 kB
0_token_base_issue_1/logs/0_token_base_issue_1_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 1.0860533714294434, "total": 1.0860533494215745, "trg_loss": 0.2625841874342698, "arg_loss": 0.8234687805175781}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 0.7911608219146729, "total": 0.7911608182466947, "trg_loss": 0.18240927182711086, "arg_loss": 0.6087513850285456}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 0.7209492921829224, "total": 0.7209493196927584, "trg_loss": 0.16584016359769382, "arg_loss": 0.5551097576434796}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 0.6749942302703857, "total": 0.6749942486102765, "trg_loss": 0.1577667236328125, "arg_loss": 0.5172272315392128}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 0.6352885961532593, "total": 0.6352885906512921, "trg_loss": 0.1492153314443735, "arg_loss": 0.48607300978440504}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 0.6105099320411682, "total": 0.6105099017803486, "trg_loss": 0.14334932473989634, "arg_loss": 0.4671605917123648}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 0.5796421766281128, "total": 0.5796421931340144, "trg_loss": 0.13734904069166917, "arg_loss": 0.44229313777043266}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 0.5453628301620483, "total": 0.5453628540039063, "trg_loss": 0.13192813579852763, "arg_loss": 0.41343465951772834, "val_score": 0.008943983975020976, "best_score": 0.008943983975020976, "new_best_model": true, "precision": 0.008171760941926835, "recall": 0.009919884158367686, "f1": 0.008943983975020976}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 0.5109521150588989, "total": 0.5109521132249099, "trg_loss": 0.12609261732835036, "arg_loss": 0.38485952524038464, "val_score": 0.00936368127151777, "best_score": 0.00936368127151777, "new_best_model": true, "precision": 0.00858333145103755, 
"recall": 0.01033654863998915, "f1": 0.00936368127151777}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 0.49318966269493103, "total": 0.4931896503155048, "trg_loss": 0.12101867382342998, "arg_loss": 0.3721710791954627, "val_score": 0.009368108255106934, "best_score": 0.009368108255106934, "new_best_model": true, "precision": 0.00859450499868192, "recall": 0.01032669126062992, "f1": 0.009368108255106934}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 0.4708995223045349, "total": 0.47089952322152945, "trg_loss": 0.11534688656146709, "arg_loss": 0.35555249727689303, "val_score": 0.00919727910098337, "best_score": 0.009368108255106934, "new_best_model": false, "precision": 0.008552713247487152, "recall": 0.009989716584658715, "f1": 0.00919727910098337}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 0.4356004297733307, "total": 0.4356004274808444, "trg_loss": 0.11027498245239258, "arg_loss": 0.32532527630145736, "val_score": 0.00865802865410624, "best_score": 0.009368108255106934, "new_best_model": false, "precision": 0.008109137799155454, "recall": 0.009331600138932394, "f1": 0.00865802865410624}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 0.42093297839164734, "total": 0.4209329751821665, "trg_loss": 0.10636664904080904, "arg_loss": 0.3145663334773137, "val_score": 0.008765660856231407, "best_score": 0.009368108255106934, "new_best_model": false, "precision": 0.008370225862331897, "recall": 0.00924011738704956, "f1": 0.008765660856231407}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 0.39907169342041016, "total": 0.3990716787484976, "trg_loss": 0.10246439713698167, "arg_loss": 0.29660709087665266, "val_score": 0.008731740800888848, "best_score": 0.009368108255106934, "new_best_model": false, "precision": 0.008428924709651018, "recall": 0.009095223455246106, "f1": 0.008731740800888848}, "15": {"lr": [4.916040103221507e-06, 
0.00010384757955302797], "train_loss": 0.3795405626296997, "total": 0.3795405754676232, "trg_loss": 0.09897036185631385, "arg_loss": 0.28057025029109073, "val_score": 0.008555719326435339, "best_score": 0.009368108255106934, "new_best_model": false, "precision": 0.008418719723152402, "recall": 0.00874125076363765, "f1": 0.008555719326435339}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 0.3590869605541229, "total": 0.35908696101262016, "trg_loss": 0.09613694411057692, "arg_loss": 0.2629500169020433, "val_score": 0.008900738532043482, "best_score": 0.009368108255106934, "new_best_model": false, "precision": 0.008834576675114208, "recall": 0.009011293311694544, "f1": 0.008900738532043482}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 0.3371564745903015, "total": 0.3371564718393179, "trg_loss": 0.09345089839054987, "arg_loss": 0.24370544140155498, "val_score": 0.008781447082213366, "best_score": 0.009368108255106934, "new_best_model": false, "precision": 0.00884134346106322, "recall": 0.008753451655544054, "f1": 0.008781447082213366}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 0.3141079843044281, "total": 0.3141079829289363, "trg_loss": 0.09209175109863281, "arg_loss": 0.222016114455003, "val_score": 0.009035306205967566, "best_score": 0.009368108255106934, "new_best_model": false, "precision": 0.009398834151514941, "recall": 0.008748056069134905, "f1": 0.009035306205967566}}
0_token_base_issue_1/r1s/0_token_base_issue_1_s26092004_f0_r1_vs0.00937_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:048c8bb67aef68a3c3edb7af5c864e925a6645c45ea0219e5c8f9bf34e76e5dc
3
+ size 559062604
0_token_base_issue_1/results/0_token_base_issue_1_error_analyze_result.json ADDED
The diff for this file is too large to render. See raw diff
 
0_token_base_issue_1/results/0_token_base_issue_1_pred_test.json ADDED
The diff for this file is too large to render. See raw diff
 
0_token_base_issue_1/results/0_token_base_issue_1_test.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Best model": {
3
+ "Trg-I": {
4
+ "precision": 0.407704341023841,
5
+ "recall": 0.3618535230816351,
6
+ "f1": 0.38341301450950477
7
+ },
8
+ "Trg-C": {
9
+ "precision": 0.01700879765392569,
10
+ "recall": 0.01500776263582024,
11
+ "f1": 0.015945742820090092
12
+ },
13
+ "Arg-I": {
14
+ "precision": 0.20714931695929975,
15
+ "recall": 0.18519937018588653,
16
+ "f1": 0.19556034371691852
17
+ },
18
+ "Arg-C": {
19
+ "precision": 0.1876873305104012,
20
+ "recall": 0.16723896731520316,
21
+ "f1": 0.17687409831637366
22
+ },
23
+ "Soft-Event": {
24
+ "precision": 0.007279128407627728,
25
+ "recall": 0.006462513199574884,
26
+ "f1": 0.006846551602434481
27
+ },
28
+ "Strict-Event": {
29
+ "precision": 0.0035190615835708327,
30
+ "recall": 0.0031050543384455666,
31
+ "f1": 0.003299115254136828
32
+ }
33
+ },
34
+ "Last model": {
35
+ "Trg-I": {
36
+ "precision": 0.4390569395008028,
37
+ "recall": 0.342589378687361,
38
+ "f1": 0.38487034017191707
39
+ },
40
+ "Trg-C": {
41
+ "precision": 0.01645907473305948,
42
+ "recall": 0.01276522339138733,
43
+ "f1": 0.014378699053428494
44
+ },
45
+ "Arg-I": {
46
+ "precision": 0.2148328040349867,
47
+ "recall": 0.14681475807474922,
48
+ "f1": 0.17442741820119895
49
+ },
50
+ "Arg-C": {
51
+ "precision": 0.1974478680359804,
52
+ "recall": 0.1344694561023636,
53
+ "f1": 0.1599838555723224
54
+ },
55
+ "Soft-Event": {
56
+ "precision": 0.007656395891685243,
57
+ "recall": 0.005195353748677848,
58
+ "f1": 0.006190231720845471
59
+ },
60
+ "Strict-Event": {
61
+ "precision": 0.0024466192170764086,
62
+ "recall": 0.0018975332068278464,
63
+ "f1": 0.0021373699948784206
64
+ }
65
+ }
66
+ }
0_token_base_issue_1/results/0_token_base_issue_1_test_df.xlsx ADDED
Binary file (5.66 kB). View file
 
0_token_base_issue_1/results/0_token_base_issue_1_test_df_best.xlsx ADDED
Binary file (5.66 kB). View file