SS3M committed on
Commit
95580b8
·
verified ·
1 Parent(s): 1299746

Upload 1_pointer_base_actions_4's state dict

Browse files
.gitattributes CHANGED
@@ -47,3 +47,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
47
  1_pointer_base_entities_4/logs/1_pointer_base_entities_4_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
48
  1_pointer_base_issues_4/logs/1_pointer_base_issues_4_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
49
  4_doc_level_entities_5/logs/4_doc_level_entities_5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
50
+ 1_pointer_base_actions_4/logs/1_pointer_base_actions_4_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
1_pointer_base_actions_4/1_pointer_base_actions_4.py ADDED
@@ -0,0 +1,2190 @@
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch]')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+
19
+ from sklearn.metrics import f1_score
20
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
21
+ from scipy.spatial.transform import Rotation as R
22
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
23
+ from sklearn.metrics import precision_recall_fscore_support
24
+ from timm.utils import ModelEmaV3
25
+ import timm
26
+
27
+ import os
28
+ import gc
29
+ import json
30
+ from pathlib import Path
31
+ import pickle
32
+ from tqdm.auto import tqdm
33
+ import copy
34
+ import numpy as np
35
+ import pandas as pd
36
+ import polars as pl
37
+ from PIL import Image
38
+ import time
39
+ from tqdm import tqdm
40
+ from matplotlib import pyplot as plt
41
+ import seaborn as sns
42
+ from multiprocessing import Manager as MemoryManager
43
+ from functools import lru_cache
44
+ import shutil
45
+ import glob
46
+ import cv2
47
+ import random
48
+ import re
49
+ import joblib
50
+ import math
51
+ from huggingface_hub import HfApi, snapshot_download
52
+ import evaluate
53
+ from underthesea import word_tokenize as vi_tokenize_tool
54
+ import spacy
55
+ en_tokenize_tool = spacy.load("en_core_web_sm")
56
+ from collections import defaultdict, Counter
57
+
58
+ # %% [code]
59
+ # Global config
60
+ SEEDS = [26092004]
61
+ topk = 1
62
+ nfolds = 5
63
+ only_fold_idx = 0
64
+ test_only = 0
65
+ debug_only = 0
66
+
67
+ # Directory config
68
+ dataset = 'kltn/only_actions' # vhe, bkee, casie, kltn/only_issues, kltn/only_actions
69
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Directory containing the train, val, and test files
70
+ train_dir = f'{root_dir}'
71
+ # val_dir = f'{root_dir}/val'
72
+ test_dir = f'{root_dir}'
73
+
74
+ # Config checkpoints
75
+
76
+ # Config training
77
+ epochs = 18 if not debug_only else 2
78
+ batch_size = 32
79
+ device = "cuda" if torch.cuda.is_available() else "cpu"
80
+ # # Add any other global variables here
81
+ repo_name = 'SS3M/kltn-experiments'
82
+ state_dict_save_name = "1_pointer_base_actions_4"
83
+ checkpoints_dir = state_dict_save_name
84
+ pretrained_dir = "/kaggle/working"
85
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
86
+
87
+ backbone_model_name = "bert-base-uncased" if dataset == "casie" else "vinai/phobert-base"
88
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset == "casie" else vi_tokenize_tool(text)
89
+ max_len_dict = {
90
+ 'kltn/only_issues': 52,
91
+ 'kltn/only_actions': 69,
92
+ 'vhe': 51,
93
+ 'bkee': 62,
94
+ 'casie': 40,
95
+ }
96
+ zero_events_rate_dict = {
97
+ 'kltn/only_issues': 0,
98
+ 'kltn/only_actions': 0.2,
99
+ 'vhe': 1000, # means: keep all zero-event samples
100
+ 'bkee': 1000,
101
+ 'casie': 1000,
102
+ }
103
+
104
+ max_len = max_len_dict[dataset]
105
+ max_n_parts = 1
106
+ max_span_len = 14
107
+ zero_events_rate = zero_events_rate_dict[dataset]
108
+
109
+ # Trainer
110
+ trainer_params = {
111
+ "training_time": "00:11:30:00",
112
+ "eval_mode": "max",
113
+ "topk": topk,
114
+ "save_name": state_dict_save_name,
115
+ "save_best": True,
116
+ "save_last": True,
117
+ "device": device,
118
+ "logging": True,
119
+ "logging_file": True,
120
+ "checkpoints_dir": checkpoints_dir,
121
+ "early_stopping": 30,
122
+ "eval_from_ratio": 0.4,
123
+ "eval_every": 1,
124
+ "schedule_in_step": False,
125
+ "use_ema": True,
126
+ "ema_from_ratio": 0.3,
127
+ "ema_decay": 0.9995,
128
+ "max_grad_norm": 200.0,
129
+ "return_best": True,
130
+ "return_last": True,
131
+ }
132
+
133
+ # Memory
134
+ train_memory_params = {
135
+ 'max_len': max_len,
136
+ 'max_n_parts': max_n_parts,
137
+ }
138
+ val_memory_params = {
139
+ 'max_len': max_len,
140
+ 'max_n_parts': max_n_parts,
141
+ }
142
+
143
+ # Data Loader
144
+ def seed_worker(worker_id):
145
+ worker_seed = torch.initial_seed() % 2**32
146
+ np.random.seed(worker_seed)
147
+ random.seed(worker_seed)
148
+
149
+ train_loader_params = {
150
+ 'batch_size': batch_size,
151
+ 'shuffle': True,
152
+ 'pin_memory':True,
153
+ 'num_workers': 2,
154
+ 'drop_last': False,
155
+ 'worker_init_fn': seed_worker,
156
+ 'persistent_workers': False,
157
+ }
158
+ val_loader_params = {
159
+ 'batch_size': batch_size,
160
+ 'shuffle': False,
161
+ 'pin_memory':True,
162
+ 'num_workers': 1,
163
+ 'drop_last': False,
164
+ 'worker_init_fn': seed_worker,
165
+ 'persistent_workers': False,
166
+ }
167
+
168
+ # Model
169
+ model_params = {
170
+ 'backbone_model_name': backbone_model_name,
171
+ }
172
+
173
+ # Loss Func
174
+ loss_func_params = {
175
+ 'lambda_trg_ce': 1.0,
176
+ 'lambda_arg_ce': 1.0,
177
+ }
178
+ eval_func_params = {}
179
+
180
+ # Optim
181
+ optim_params = {
182
+ 'name': 'AdamW',
183
+ 'lr': 1e-4,
184
+ 'weight_decay': 1e-4,
185
+ }
186
+ scheduler_params = {
187
+ 'name': 'CosineAnnealingLR',
188
+ 'T_max': 20, # Number of epochs to complete one LR decay cycle
189
+ 'eta_min': 1e-6 # Minimum learning rate within the cycle
190
+ }
191
+
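The scheduler settings above can be checked with a small closed-form calculation. This is a minimal sketch, assuming the standard cosine annealing formula `eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2` and one scheduler step per epoch (as `schedule_in_step: False` suggests). `cosine_lr` is a hypothetical helper, not part of the notebook.

```python
import math

def cosine_lr(epoch, base_lr=1e-4, eta_min=1e-6, t_max=20):
    """Closed-form cosine annealing learning rate at a given epoch (hypothetical helper)."""
    cos_term = (1 + math.cos(math.pi * epoch / t_max)) / 2
    return eta_min + (base_lr - eta_min) * cos_term

# With epochs = 18 and T_max = 20, training stops before the LR reaches eta_min.
schedule = [cosine_lr(epoch) for epoch in range(18)]
```

Under these assumptions the learning rate at the last of the 18 epochs is still above `eta_min`, because the cosine cycle is configured for 20 epochs.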
192
+ # %% [code]
193
+ def set_seed(seed=42):
194
+ random.seed(seed)
195
+ np.random.seed(seed)
196
+ torch.manual_seed(seed)
197
+ torch.cuda.manual_seed(seed)
198
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
199
+ torch.use_deterministic_algorithms(False)
200
+ torch.backends.cudnn.deterministic = True
201
+ torch.backends.cudnn.benchmark = False
202
+ os.environ['PYTHONHASHSEED'] = str(seed)
203
+
204
+ # %% [code]
205
+ class CustomLoss(nn.Module):
206
+ def __init__(
207
+ self,
208
+ lambda_trg_ce=1.0,
209
+ lambda_arg_ce=1.0,
210
+ ):
211
+ super().__init__()
212
+ self.lambda_trg_ce = lambda_trg_ce
213
+ self.lambda_arg_ce = lambda_arg_ce
214
+ self.ce = nn.CrossEntropyLoss(ignore_index=-100)
215
+
216
+ def forward(
217
+ self,
218
+ trg_start_logits, trg_start_labels,
219
+ trg_end_logits, trg_end_labels,
220
+ arg_start_logits, pred_arg_start_labels,
221
+ arg_end_logits, pred_arg_end_labels,
222
+ ):
223
+ device = trg_start_logits.device
224
+
225
+ # ===== TRG START CE =====
226
+ B, N, C = trg_start_logits.shape
227
+
228
+ trg_start_logits_flat = trg_start_logits.view(B * N, C)
229
+ trg_start_labels_flat = trg_start_labels.view(-1)
230
+
231
+ trg_start_loss = self.ce(
232
+ trg_start_logits_flat,
233
+ trg_start_labels_flat
234
+ )
235
+
236
+ # ===== TRG END CE =====
237
+ B, N, C = trg_end_logits.shape
238
+
239
+ trg_end_logits_flat = trg_end_logits.view(B * N, C)
240
+ trg_end_labels_flat = trg_end_labels.view(-1)
241
+
242
+ trg_end_loss = self.ce(
243
+ trg_end_logits_flat,
244
+ trg_end_labels_flat
245
+ )
246
+
247
+ # ===== ARG CE =====
248
+ B, K, M, C = arg_start_logits.shape
249
+ arg_start_logits_flat = arg_start_logits.view(B * K * M, C)
250
+ arg_start_labels_flat = pred_arg_start_labels.view(-1)
251
+
252
+ arg_mask = (arg_start_labels_flat != -100)
253
+
254
+ if arg_mask.sum() == 0:
255
+ arg_start_loss = torch.tensor(0.0, device=device)
256
+ else:
257
+ arg_start_loss = self.ce(arg_start_logits_flat, arg_start_labels_flat) # (B*K*M,)
258
+
259
+ B, K, M, C = arg_end_logits.shape
260
+ arg_end_logits_flat = arg_end_logits.view(B * K * M, C)
261
+ arg_end_labels_flat = pred_arg_end_labels.view(-1)
262
+
263
+ arg_mask = (arg_end_labels_flat != -100)
264
+
265
+ if arg_mask.sum() == 0:
266
+ arg_end_loss = torch.tensor(0.0, device=device)
267
+ else:
268
+ arg_end_loss = self.ce(arg_end_logits_flat, arg_end_labels_flat) # (B*K*M,)
269
+
270
+ # ===== TOTAL =====
271
+ total_loss = (
272
+ self.lambda_trg_ce * (trg_start_loss + trg_end_loss) +
273
+ self.lambda_arg_ce * (arg_start_loss + arg_end_loss)
274
+ )
275
+
276
+ return {
277
+ "total": total_loss,
278
+ "trg_start_loss": trg_start_loss,
279
+ "trg_end_loss": trg_end_loss,
280
+ "arg_start_loss": arg_start_loss,
281
+ "arg_end_loss": arg_end_loss,
282
+ }
283
+
284
+ # %% [code]
285
+ ## Write eval_fn here
286
+
287
+ # Put all eval_fn logic and weights here
288
+ class CustomEvalFn(nn.Module):
289
+ def __init__(self):
290
+ super().__init__()
291
+
292
+ def compute_f1(self, tp, fp, fn):
293
+ precision = tp / (tp + fp + 1e-8)
294
+ recall = tp / (tp + fn + 1e-8)
295
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
296
+ return precision, recall, f1
297
+
298
+ def forward(self, pred, gold):
299
+ pred_set = set(pred)
300
+ gold_set = set(gold)
301
+
302
+ tp = len(pred_set & gold_set)
303
+ fp = len(pred_set - gold_set)
304
+ fn = len(gold_set - pred_set)
305
+
306
+ precision, recall, f1 = self.compute_f1(tp, fp, fn)
307
+
308
+ return {
309
+ "precision": precision,
310
+ "recall": recall,
311
+ "f1": f1,
312
+ }
313
+
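The set-based scoring in `CustomEvalFn` can be illustrated without PyTorch. This is a minimal sketch, assuming each span is encoded as a hashable tuple; the `(sample_index, start, end, label)` layout and the name `span_prf` are illustrative assumptions, not taken from the notebook.

```python
def span_prf(pred, gold, eps=1e-8):
    """Precision, recall and F1 over two collections of hashable spans."""
    pred_set, gold_set = set(pred), set(gold)
    tp = len(pred_set & gold_set)  # exact matches
    fp = len(pred_set - gold_set)  # predicted but not in gold
    fn = len(gold_set - pred_set)  # in gold but not predicted
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1

# Hypothetical spans: (sample_index, start, end, label)
pred = [(0, 1, 2, "Action"), (0, 5, 6, "Action"), (1, 0, 0, "Action")]
gold = [(0, 1, 2, "Action"), (1, 3, 4, "Action")]
p, r, f1 = span_prf(pred, gold)  # one true positive, two false positives, one false negative
```

Because both inputs are converted to sets, duplicate predictions count once, and the `eps` term keeps the scores at zero instead of raising when either side is empty.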
314
+ class SpanErrorAnalyzer:
315
+ def __init__(self, pad_token_id=0):
316
+ self.pad_token_id = pad_token_id
317
+
318
+ # ===== helper =====
319
+ def _to_set(self, data):
320
+ """
321
+ data: list of (b, tuple(ids))
322
+ -> dict[b] = set(tuple(ids))
323
+ """
324
+ res = defaultdict(set)
325
+ for b, ids in data:
326
+ ids = tuple([i for i in ids if i != self.pad_token_id])
327
+ if len(ids) > 0:
328
+ res[b].add(ids)
329
+ return res
330
+
331
+ def _iou(self, a, b):
332
+ """
333
+ a, b: tuple(ids)
334
+ """
335
+ set_a, set_b = set(a), set(b)
336
+ inter = len(set_a & set_b)
337
+ union = len(set_a | set_b)
338
+ if union == 0:
339
+ return 0.0
340
+ return inter / union
341
+
342
+ def _boundary_error(self, pred, gold):
343
+ """
344
+ Measure boundary mismatch based on prefix/suffix overlap
345
+ """
346
+ # left match
347
+ left = 0
348
+ for i in range(min(len(pred), len(gold))):
349
+ if pred[i] == gold[i]:
350
+ left += 1
351
+ else:
352
+ break
353
+
354
+ # right match
355
+ right = 0
356
+ for i in range(1, min(len(pred), len(gold)) + 1):
357
+ if pred[-i] == gold[-i]:
358
+ right += 1
359
+ else:
360
+ break
361
+
362
+ return {
363
+ "left_match": left,
364
+ "right_match": right,
365
+ "pred_len": len(pred),
366
+ "gold_len": len(gold),
367
+ }
368
+
369
+ # ===== main =====
370
+ def analyze(self, preds, golds):
371
+ pred_map = self._to_set(preds)
372
+ gold_map = self._to_set(golds)
373
+
374
+ all_batches = set(pred_map.keys()) | set(gold_map.keys())
375
+
376
+ stats = Counter()
377
+
378
+ detailed_errors = []
379
+
380
+ for b in all_batches:
381
+ pset = pred_map.get(b, set())
382
+ gset = gold_map.get(b, set())
383
+
384
+ matched_gold = set()
385
+
386
+ # ===== check predictions =====
387
+ for p in pset:
388
+ if p in gset:
389
+ stats["exact_match"] += 1
390
+ matched_gold.add(p)
391
+ else:
392
+ # find the closest gold span
393
+ best_iou = 0
394
+ best_g = None
395
+
396
+ for g in gset:
397
+ iou = self._iou(p, g)
398
+ if iou > best_iou:
399
+ best_iou = iou
400
+ best_g = g
401
+
402
+ if best_iou > 0:
403
+ stats["partial_match"] += 1
404
+
405
+ boundary = self._boundary_error(p, best_g)
406
+
407
+ detailed_errors.append({
408
+ "type": "boundary_error",
409
+ "batch": b,
410
+ "pred": p,
411
+ "gold": best_g,
412
+ "iou": best_iou,
413
+ **boundary
414
+ })
415
+ else:
416
+ if b not in gold_map:
417
+ stats["no_event_sample"] += 1
418
+ err_type = "no_event_sample"
419
+ else:
420
+ stats["completely_wrong"] += 1
421
+ err_type = "completely_wrong"
422
+
423
+ detailed_errors.append({
424
+ "type": err_type,
425
+ "batch": b,
426
+ "pred": p
427
+ })
428
+
429
+ # ===== check missing =====
430
+ for g in gset:
431
+ if g not in matched_gold:
432
+ # check if any pred overlaps
433
+ overlap = any(self._iou(p, g) > 0 for p in pset)
434
+
435
+ if overlap:
436
+ stats["miss_with_overlap"] += 1
437
+ else:
438
+ stats["miss"] += 1
439
+
440
+ detailed_errors.append({
441
+ "type": "miss",
442
+ "batch": b,
443
+ "gold": g
444
+ })
445
+
446
+ return {
447
+ "summary": {
448
+ "exact_match": (stats["exact_match"], stats["exact_match"] / max(len(preds), 1)),
449
+ "partial_match": (stats["partial_match"], stats["partial_match"] / max(len(preds), 1)),
450
+ "no_event_sample": (stats["no_event_sample"], stats["no_event_sample"] / max(len(preds), 1)),
451
+ "completely_wrong": (stats["completely_wrong"], stats["completely_wrong"] / max(len(preds), 1)),
452
+ "miss": (stats["miss"], stats["miss"] / max(len(golds), 1)),
453
+ "miss_with_overlap": (stats["miss_with_overlap"], stats["miss_with_overlap"] / max(len(golds), 1)),
454
+ },
455
+ "details": detailed_errors
456
+ }
457
+
458
+ # %% [code]
459
+ ## Write the model architecture here
460
+ def fix_bio_ids_batch(label_ids):
461
+ """
462
+ label_ids: (B, L)
463
+ return: (B, L) fixed
464
+ """
465
+ B, L = label_ids.shape
466
+ fixed = label_ids.clone()
467
+
468
+ for b in range(B):
469
+ for i in range(L):
470
+ tag = fixed[b, i].item()
471
+
472
+ if tag == 0:
473
+ continue
474
+
475
+ # I- (even)
476
+ if tag % 2 == 0:
477
+ if i == 0 or fixed[b, i-1].item() == 0:
478
+ fixed[b, i] = tag - 1
479
+ else:
480
+ prev_tag = fixed[b, i-1].item()
481
+
482
+ if prev_tag == 0:
483
+ fixed[b, i] = tag - 1
484
+ else:
485
+ prev_type = (prev_tag - 1) // 2
486
+ curr_type = (tag - 1) // 2
487
+
488
+ if prev_type != curr_type:
489
+ fixed[b, i] = tag - 1
490
+
491
+ return fixed
492
+
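The repair rule in `fix_bio_ids_batch` is easier to follow on a plain list. This is a minimal sketch, assuming the tag encoding implied by the function: `0` is `O`, odd ids are `B-` tags, even ids are `I-` tags, and `(tag - 1) // 2` is the type. `fix_bio_ids` is a hypothetical single-sequence variant.

```python
def fix_bio_ids(tags):
    """Promote any I- tag that does not continue a span of its own type to B-."""
    fixed = list(tags)
    for i, tag in enumerate(fixed):
        if tag == 0 or tag % 2 == 1:
            continue  # O and B- tags are always valid
        prev = fixed[i - 1] if i > 0 else 0  # already-repaired left neighbour
        if prev == 0 or (prev - 1) // 2 != (tag - 1) // 2:
            fixed[i] = tag - 1  # the matching B- id is one below the I- id
    return fixed

# I-0 at position 1 has no span to continue; I-1 at position 3 follows type 0.
repaired = fix_bio_ids([0, 2, 2, 4, 0, 1, 2])
```

Here the stray `I-` tags at positions 1 and 3 become `B-` tags, while the `I-` tags that continue a span of the same type are left unchanged.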
493
+ def extract_trigger_spans_batch_tensor(label_ids):
494
+ """
495
+ label_ids: (B, L)
496
+ return:
497
+ spans_tensor: (B, N, 2) # (s, e), pad = (0,0)
498
+ """
499
+ B, L = label_ids.shape
500
+ device = label_ids.device
501
+
502
+ all_spans = []
503
+ max_n = 0
504
+
505
+ # ===== extract spans (as lists first) =====
506
+ for b in range(B):
507
+ spans = []
508
+ i = 0
509
+
510
+ while i < L:
511
+ tag = label_ids[b, i].item()
512
+
513
+ if tag == 0:
514
+ i += 1
515
+ continue
516
+
517
+ # B- (odd)
518
+ if tag % 2 == 1:
519
+ type_id = (tag - 1) // 2
520
+ s = i
521
+ e = i
522
+ i += 1
523
+
524
+ while i < L:
525
+ next_tag = label_ids[b, i].item()
526
+
527
+ if next_tag == 0:
528
+ break
529
+
530
+ next_type = (next_tag - 1) // 2
531
+
532
+ if next_tag % 2 == 0 and next_type == type_id:
533
+ e = i
534
+ i += 1
535
+ else:
536
+ break
537
+
538
+ spans.append((s, e))
539
+ else:
540
+ i += 1
541
+
542
+ all_spans.append(spans)
543
+ max_n = max(max_n, len(spans))
544
+
545
+ # ===== build tensor =====
546
+ if max_n == 0:
547
+ # no spans at all: return an empty tensor with the correct shape
548
+ return torch.zeros((B, 0, 2), dtype=torch.long, device=device)
549
+
550
+ spans_tensor = torch.zeros((B, max_n, 2), dtype=torch.long, device=device)
551
+
552
+ for b in range(B):
553
+ for i, (s, e) in enumerate(all_spans[b]):
554
+ spans_tensor[b, i, 0] = s
555
+ spans_tensor[b, i, 1] = e
556
+
557
+ return spans_tensor
558
+
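The span scan in `extract_trigger_spans_batch_tensor` can be sketched for a single sequence without tensors or padding. This assumes the same tag encoding (`0` = `O`, odd = `B-`, even = `I-`); `extract_spans` is a hypothetical helper that returns inclusive `(start, end)` pairs.

```python
def extract_spans(tags):
    """Collect inclusive (start, end) spans that open on a B- tag."""
    spans, i, n = [], 0, len(tags)
    while i < n:
        tag = tags[i]
        if tag == 0 or tag % 2 == 0:
            i += 1  # skip O tags and stray I- tags
            continue
        type_id = (tag - 1) // 2
        start = end = i
        i += 1
        # extend the span over following I- tags of the same type
        while i < n and tags[i] != 0 and tags[i] % 2 == 0 and (tags[i] - 1) // 2 == type_id:
            end = i
            i += 1
        spans.append((start, end))
    return spans

spans = extract_spans([0, 1, 2, 3, 0, 1, 2])
```

The batched version in the file additionally pads every sample to the largest span count with `(0, 0)`, so a real span at position 0 with length 1 is indistinguishable from padding.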
559
+ def get_span_repr(hidden, spans):
560
+ B, L, H = hidden.size()
561
+ K = spans.size(1)
562
+ device = hidden.device
563
+
564
+ start = spans[:, :, 0] # (B, K)
565
+ end = spans[:, :, 1] # (B, K)
566
+
567
+ h_s = torch.gather(hidden, 1, start.unsqueeze(-1).expand(-1, -1, H))
568
+ h_e = torch.gather(hidden, 1, end.unsqueeze(-1).expand(-1, -1, H))
569
+
570
+ h_diff = h_s - h_e
571
+ h_prod = h_s * h_e
572
+
573
+ # ===== 6. concat =====
574
+ span_repr = torch.cat(
575
+ [h_s, h_e, h_diff, h_prod],
576
+ dim=-1
577
+ )
578
+
579
+ return span_repr
580
+
581
+ class MLP(nn.Module):
582
+ def __init__(self, in_size, hid_size, out_size):
583
+ super().__init__()
584
+ self.model = nn.Sequential(
585
+ nn.Linear(in_size, hid_size),
586
+ nn.ReLU(),
587
+ nn.Linear(hid_size, out_size)
588
+ )
589
+
590
+ def forward(self, x):
591
+ return self.model(x)
592
+
593
+ class IEModel(nn.Module):
594
+ def __init__(self, backbone_model_name, num_trg_labels, num_arg_labels):
595
+ super().__init__()
596
+ self.encoder = AutoModel.from_pretrained(backbone_model_name)
597
+ hidden_size = self.encoder.config.hidden_size
598
+
599
+ self.trg_start_classifier = MLP(hidden_size, hidden_size, num_trg_labels)
600
+ self.trg_end_classifier = MLP(hidden_size, hidden_size, num_trg_labels)
601
+
602
+ self.trg_repr_proj = MLP(hidden_size*4, hidden_size, hidden_size)
603
+ self.arg_start_classifier = MLP(hidden_size*2, hidden_size, num_arg_labels)
604
+ self.arg_end_classifier = MLP(hidden_size*2, hidden_size, num_arg_labels)
605
+
606
+ def encode(self, input_ids, attention_mask):
607
+ B, n_parts, L = input_ids.shape
608
+ input_ids = input_ids.view(-1, L)
609
+ attention_mask = attention_mask.view(-1, L)
610
+
611
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
612
+ hidden_states = outputs.last_hidden_state # B * n_parts, L, H
613
+
614
+ hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1) # B, n_parts*L, H
615
+ return hidden_states
616
+
617
+ def get_trg_logits(self, hidden_states):
618
+ trg_start_logits = self.trg_start_classifier(hidden_states) # B, N, trg_classes
619
+ trg_end_logits = self.trg_end_classifier(hidden_states) # B, N, trg_classes
620
+ return trg_start_logits, trg_end_logits
621
+
622
+ def get_arg_logits(self, hidden_states, trg_repr):
623
+ B, L, H = hidden_states.shape
624
+ _, N, _ = trg_repr.shape
625
+
626
+ hidden_expand = hidden_states.unsqueeze(1).expand(-1, N, -1, -1)
627
+ trg_expand = trg_repr.unsqueeze(2).expand(-1, -1, L, -1)
628
+
629
+ hidden_trg_repr = torch.cat([hidden_expand, trg_expand], dim=-1) # (B, N, L, 2H)
630
+ arg_start_logits = self.arg_start_classifier(hidden_trg_repr) # (B, N, L, C)
631
+ arg_end_logits = self.arg_end_classifier(hidden_trg_repr) # (B, N, L, C)
632
+
633
+ return arg_start_logits, arg_end_logits
634
+
635
+ def forward(self, input_ids, attention_mask, trg_spans=None):
636
+ hidden_states = self.encode(input_ids, attention_mask)
637
+
638
+ trg_start_logits, trg_end_logits = self.get_trg_logits(hidden_states)
639
+
640
+ if trg_spans is None:
641
+ trg_start_labels = torch.argmax(trg_start_logits, dim=-1)
642
+ trg_end_labels = torch.argmax(trg_end_logits, dim=-1)
643
+ trg_spans = decode_spans_batch(trg_start_labels, trg_end_labels) # pair start/end pointers into (B, N, 2) spans
644
+
645
+ trg_repr = get_span_repr(hidden_states, trg_spans) # B, N, 4H
646
+
647
+ trg_repr = self.trg_repr_proj(trg_repr) # B, N, H
648
+ arg_start_logits, arg_end_logits = self.get_arg_logits(hidden_states, trg_repr)
649
+
650
+ return trg_start_logits, trg_end_logits, arg_start_logits, arg_end_logits, trg_spans
651
+
652
+ def test():
653
+ model = nn.DataParallel(IEModel(backbone_model_name, 7, 5)).to(device)
654
+ model.eval()
655
+ total_params = sum(p.numel() for p in model.parameters())
656
+ print(f"Total params: {total_params:,}")
657
+
658
+ vocab_size = model.module.encoder.config.vocab_size
659
+ max_len = model.module.encoder.config.max_position_embeddings
660
+
661
+ bz = 32
662
+ i = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
663
+ a = torch.ones(bz, 5, 10).to(device)
664
+ g = torch.ones(bz, 3, 2, dtype=torch.long).to(device)
665
+
666
+ with torch.no_grad():
667
+ r = model(i, a, g)
668
+
669
+ if isinstance(r, tuple):
670
+ print([r[i].shape for i in range(len(r))])
671
+ else:
672
+ print(r.shape)
673
+
674
+ test()
675
+
676
+ # %% [code]
677
+ def configure_optimizers(network, optim_params, scheduler_params):
678
+ try:
679
+ optim_params = copy.copy(optim_params)
680
+ scheduler_params = copy.copy(scheduler_params)
681
+
682
+ optim_name = optim_params.pop('name')
683
+ scheduler_name = scheduler_params.pop('name')
684
+
685
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
686
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
687
+
688
+ if optimizer_cls is None:
689
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
690
+
691
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
692
+
693
+ scheduler = None
694
+ if scheduler_params and scheduler_cls: # Only create a scheduler if parameters are given
695
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
696
+
697
+ return optimizer, scheduler
698
+
699
+ except KeyError as e:
700
+ raise ValueError(f"Missing {e} in config!!")
701
+
702
+ def freeze(model):
703
+ model.eval()
704
+ for param in model.parameters():
705
+ param.requires_grad = False
706
+
707
+ def unfreeze(model):
708
+ model.train()
709
+ for param in model.parameters():
710
+ param.requires_grad = True
711
+
712
+ def reduce_batch_size(loader, ratio=0.5):
713
+ new_bs = max(1, int(loader.batch_size * ratio))
714
+
715
+ shuffle = isinstance(loader.sampler, RandomSampler)
716
+
717
+ new_loader = DataLoader(
718
+ dataset=loader.dataset,
719
+ batch_size=new_bs,
720
+ shuffle=shuffle,
721
+ sampler=None if shuffle else loader.sampler,
722
+ num_workers=loader.num_workers,
723
+ collate_fn=loader.collate_fn,
724
+ pin_memory=loader.pin_memory,
725
+ drop_last=loader.drop_last,
726
+ timeout=loader.timeout,
727
+ worker_init_fn=loader.worker_init_fn,
728
+ multiprocessing_context=loader.multiprocessing_context,
729
+ generator=loader.generator,
730
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
731
+ persistent_workers=loader.persistent_workers,
732
+ pin_memory_device=loader.pin_memory_device
733
+ )
734
+
735
+ return new_loader
736
+
737
+ def list_to_tuple(x):
738
+ if isinstance(x, (list, tuple)):
739
+ return tuple(list_to_tuple(i) for i in x)
740
+ return x
741
+
742
+ def fmt(x):
743
+ if isinstance(x, float):
744
+ return round(x, 5)
745
+ if isinstance(x, dict):
746
+ return {k: fmt(v) for k, v in x.items()}
747
+ if isinstance(x, list):
748
+ return [fmt(v) for v in x]
749
+ return x
750
+
751
+ class ModelEmaV3Proxy(ModelEmaV3):
752
+ def __getattr__(self, name):
753
+ try:
754
+ return super().__getattr__(name)
755
+ except AttributeError:
756
+ return getattr(self.module, name)
757
+
758
+ class DataParallelProxy(nn.DataParallel):
759
+ def __getattr__(self, name):
760
+ try:
761
+ return super().__getattr__(name)
762
+ except AttributeError:
763
+ attr = getattr(self.module, name)
764
+
765
+ if callable(attr):
766
+ def wrapper(*args, **kwargs):
767
+ return self._parallel_apply_method(name, *args, **kwargs)
768
+ return wrapper
769
+
770
+ return attr
771
+
772
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
773
+ if not self.device_ids:
774
+ return getattr(self.module, method_name)(*inputs, **kwargs)
775
+
776
+ inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids)
777
+
778
+ replicas = self.replicate(self.module, self.device_ids)
779
+
780
+ outputs = self.parallel_apply(
781
+ [getattr(replica, method_name) for replica in replicas],
782
+ inputs_scattered,
783
+ kwargs_scattered
784
+ )
785
+
786
+ return self.gather(outputs, self.output_device)
787
+
788
+ def map_arg_labels(all_arg_labels, trg_spans, pred_spans):
789
+ """
790
+ all_arg_labels: (B, N, L)
791
+ trg_spans: (B, N, 2)
792
+ pred_spans: (B, M, 2)
793
+
794
+ return:
795
+ pred_arg_labels: (B, M, L)
796
+ """
797
+ B, N, L = all_arg_labels.shape
798
+ _, M, _ = pred_spans.shape
799
+
800
+ device = all_arg_labels.device
801
+
802
+ # ===== match (B, M, N) =====
803
+ match = (
804
+ (pred_spans.unsqueeze(2) == trg_spans.unsqueeze(1))
805
+ .all(dim=-1)
806
+ )
807
+
808
+ # ===== index match =====
809
+ match_idx = match.float().argmax(dim=2) # (B, M)
810
+ has_match = match.any(dim=2) # (B, M)
811
+
812
+ # ===== gather =====
813
+ gather_idx = match_idx.unsqueeze(-1).expand(-1, -1, L) # (B, M, L)
814
+
815
+ gathered = torch.gather(
816
+ all_arg_labels,
817
+ dim=1,
818
+ index=gather_idx
819
+ ) # (B, M, L)
820
+
821
+ # ===== build output =====
822
+ # base = 0 but keep -100
823
+ base = torch.zeros((B, M, L), dtype=torch.long, device=device)
824
+
825
+ # mask the -100 positions from the source (taking n=0 is fine because the masks are usually identical)
826
+ ignore_mask = (all_arg_labels[:, 0] == -100).unsqueeze(1).expand(-1, M, -1)
827
+ base[ignore_mask] = -100
828
+
829
+ # ===== fill match =====
830
+ pred_arg_labels = torch.where(
831
+ has_match.unsqueeze(-1), # (B, M, 1)
832
+ gathered,
833
+ base
834
+ )
835
+
836
+ return pred_arg_labels.long()
837
+
838
+ def decode_spans(start_labels, end_labels):
839
+ """
840
+ start_labels/end_labels: (L,)
841
+ return: [(s, e, label_id)]
842
+ """
843
+
844
+ L = len(start_labels)
845
+
846
+ used_start = set()
847
+ used_end = set()
848
+
849
+ spans = []
850
+
851
+ for s in range(L):
852
+
853
+ s_label = start_labels[s]
854
+
855
+ if s_label == 0:
856
+ continue
857
+
858
+ if s in used_start:
859
+ continue
860
+
861
+ nearest_e = None
862
+
863
+ for e in range(s, L):
864
+
865
+ if e in used_end:
866
+ continue
867
+
868
+ e_label = end_labels[e]
869
+
870
+ if e_label == s_label:
871
+ nearest_e = e
872
+ break
873
+
874
+ if nearest_e is None:
875
+ continue
876
+
877
+ used_start.add(s)
878
+ used_end.add(nearest_e)
879
+
880
+ spans.append((s, nearest_e, s_label))
881
+
882
+ return spans
883
+
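The greedy pairing in `decode_spans` can be traced on small inputs. This is a minimal sketch of the same idea: each non-zero start label is paired with the nearest unused end position at or after it that carries the same label. `pair_pointers` is a hypothetical compact rewrite, and the label ids are made up for illustration.

```python
def pair_pointers(start_labels, end_labels):
    """Pair each start pointer with the nearest unused end pointer of the same label."""
    used_end = set()
    spans = []
    for s, s_label in enumerate(start_labels):
        if s_label == 0:
            continue  # position is not a span start
        for e in range(s, len(end_labels)):
            if e not in used_end and end_labels[e] == s_label:
                used_end.add(e)
                spans.append((s, e, s_label))
                break  # keep only the nearest matching end
    return spans

# The start at position 5 has no matching end to its right, so it is dropped.
spans_a = pair_pointers([0, 1, 0, 2, 0, 1, 0], [0, 0, 1, 2, 0, 0, 0])
# Two starts with the same label: the second must take the next free end.
spans_b = pair_pointers([1, 1, 0], [0, 1, 1])
```

Because each end position can be used only once, two adjacent starts with the same label produce overlapping spans instead of sharing one end.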
884
+ def decode_spans_batch(start_labels, end_labels):
885
+ """
886
+ Args:
887
+ start_labels: (B, L)
888
+ end_labels: (B, L)
889
+
890
+ Returns:
891
+ spans_tensor: (B, N, 2)
892
+
893
+ N = largest number of spans in the batch
894
+ padding = (0, 0)
895
+ """
896
+
897
+ B, L = start_labels.shape
898
+
899
+ all_spans = []
900
+ max_n = 0
901
+
902
+ for bidx in range(B):
903
+
904
+ used_start = set()
905
+ used_end = set()
906
+
907
+ spans = []
908
+
909
+ for s in range(L):
910
+
911
+ s_label = start_labels[bidx, s].item()
912
+
913
+ if s_label == 0:
914
+ continue
915
+
916
+ if s in used_start:
917
+ continue
918
+
919
+ nearest_e = None
920
+
921
+ for e in range(s, L):
922
+
923
+ if e in used_end:
924
+ continue
925
+
926
+ e_label = end_labels[bidx, e].item()
927
+
928
+ if e_label == s_label:
929
+ nearest_e = e
930
+ break
931
+
932
+ if nearest_e is None:
933
+ continue
934
+
935
+ used_start.add(s)
936
+ used_end.add(nearest_e)
937
+
938
+ spans.append((s, nearest_e))
939
+
940
+ all_spans.append(spans)
941
+ max_n = max(max_n, len(spans))
942
+
943
+ # ===== padding =====
944
+ spans_tensor = torch.zeros(
945
+ (B, max_n, 2),
946
+ dtype=torch.long,
947
+ device=start_labels.device
948
+ )
949
+
950
+ for bidx, spans in enumerate(all_spans):
951
+ for n, (s, e) in enumerate(spans):
952
+ spans_tensor[bidx, n, 0] = s
953
+ spans_tensor[bidx, n, 1] = e
954
+
955
+ return spans_tensor
956
+
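The padding step of `decode_spans_batch` can be illustrated without tensors. The sketch below uses plain Python lists and a hypothetical helper name `pad_spans`; it pads every sample to the longest span list with the sentinel `(0, 0)`.

```python
def pad_spans(batch_spans, pad=(0, 0)):
    """Pad each per-sample span list to the length of the longest one."""
    max_n = max((len(spans) for spans in batch_spans), default=0)
    return [list(spans) + [pad] * (max_n - len(spans)) for spans in batch_spans]


batch = [[(1, 2)], [(3, 3), (4, 6)], []]
print(pad_spans(batch))
# [[(1, 2), (0, 0)], [(3, 3), (4, 6)], [(0, 0), (0, 0)]]
```

Using `(0, 0)` as the sentinel means a genuine span covering only position 0 cannot be told apart from padding. In this script position 0 appears to hold the `<s>` token (see `subword_index = 1 # <s>`), so the collision should not affect real triggers.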
957
+ def extract_arguments(
958
+ input_ids,
959
+ trg_start_logits,
960
+ trg_end_logits,
961
+ arg_start_logits,
962
+ arg_end_logits,
963
+ pred_trg_spans,
964
+ id2label
965
+ ):
966
+ """
967
+ input_ids: (B, L)
968
+
969
+ trg_start_logits: (B, L, C_trg)
970
+ trg_end_logits: (B, L, C_trg)
971
+
972
+ arg_start_logits: (B, N, L, C_arg)
973
+ arg_end_logits: (B, N, L, C_arg)
974
+
975
+ pred_trg_spans: (B, N, 2)
976
+
977
+ id2label = {
978
+ 'Trg': {id: label},
979
+ 'Arg': {id: label}
980
+ }
981
+ """
982
+
983
+ B, L = input_ids.shape
984
+
985
+ # ===== decode trigger =====
986
+ trg_start_ids = torch.argmax(trg_start_logits, dim=-1) # (B, L)
987
+ trg_end_ids = torch.argmax(trg_end_logits, dim=-1) # (B, L)
988
+
989
+ # ===== extract trigger spans =====
990
+ trg_spans = []
991
+
992
+ for bidx in range(B):
993
+ spans = decode_spans(
994
+ trg_start_ids[bidx].tolist(),
995
+ trg_end_ids[bidx].tolist()
996
+ )
997
+ trg_spans.append(spans)
998
+
999
+ results = []
1000
+
1001
+ for bidx in range(B):
1002
+
1003
+ # map span -> label
1004
+ span2label = {
1005
+ (s, e): id2label['Trg'][t_id]
1006
+ for (s, e, t_id) in trg_spans[bidx]
1007
+ }
1008
+
1009
+ for n in range(pred_trg_spans.shape[1]):
1010
+
1011
+ s_trg = pred_trg_spans[bidx, n, 0].item()
1012
+ e_trg = pred_trg_spans[bidx, n, 1].item()
1013
+
1014
+ # skip padding
1015
+ if s_trg == 0 and e_trg == 0:
1016
+ continue
1017
+
1018
+ if (s_trg, e_trg) not in span2label:
1019
+ continue
1020
+
1021
+ trg_label = span2label[(s_trg, e_trg)]
1022
+
1023
+ trg_tokens = input_ids[
1024
+ bidx,
1025
+ s_trg:e_trg + 1
1026
+ ].tolist()
1027
+
1028
+ # ===== argument =====
1029
+ arg_start_ids = torch.argmax(
1030
+ arg_start_logits[bidx, n],
1031
+ dim=-1
1032
+ ).tolist()
1033
+
1034
+ arg_end_ids = torch.argmax(
1035
+ arg_end_logits[bidx, n],
1036
+ dim=-1
1037
+ ).tolist()
1038
+
1039
+ arg_spans = decode_spans(
1040
+ arg_start_ids,
1041
+ arg_end_ids
1042
+ )
1043
+
1044
+ for s_arg, e_arg, arg_label_id in arg_spans:
1045
+
1046
+ arg_label = id2label['Arg'][arg_label_id]
1047
+
1048
+ arg_tokens = input_ids[
1049
+ bidx,
1050
+ s_arg:e_arg + 1
1051
+ ].tolist()
1052
+
1053
+ results.append((
1054
+ bidx,
1055
+ (tuple(trg_tokens), trg_label),
1056
+ (tuple(arg_tokens), arg_label)
1057
+ ))
1058
+
1059
+ return results
1060
+
1061
+ class Trainer:
1062
+ def __init__(
1063
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
1064
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
1065
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
1066
+ ):
1067
+ self.ema_net = None
1068
+
1069
+ self.training_time = self._time_str_to_seconds(training_time)
1070
+ self.mode = eval_mode
1071
+ self.topk = topk
1072
+ self.device = device
1073
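+ self.logging = logging if logging < epochs else 1  # NOTE: `epochs` is not an argument of __init__, so it must already exist as a module-level variable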
1074
+ self.logging_file = logging_file
1075
+ self.checkpoints_dir = checkpoints_dir
1076
+ self.early_stopping = early_stopping
1077
+ self.eval_from_ratio = eval_from_ratio
1078
+ self.eval_every = eval_every
1079
+ self.save_name = save_name
1080
+ self.save_best = save_best
1081
+ self.save_last = save_last
1082
+ self.return_best = return_best
1083
+ self.return_last = return_last
1084
+ self.max_grad_norm = max_grad_norm
1085
+ self.schedule_in_step = schedule_in_step
1086
+ self.use_ema = use_ema
1087
+ self.ema_from_ratio = ema_from_ratio
1088
+ self.ema_decay = ema_decay
1089
+
1090
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
1091
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
1092
+
1093
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
1094
+ if eval_fn is None:
1095
+ if self.mode == "max":
1096
+ eval_fn = lambda *x: -loss_fn(*x)
1097
+ else:
1098
+ eval_fn = lambda *x: loss_fn(*x)
1099
+
1100
+ if torch.cuda.device_count() > 1:
1101
+ network = DataParallelProxy(network)
1102
+ network = network.to(self.device)
1103
+
1104
+ if not start_training_time:
1105
+ start_training_time = time.time()
1106
+
1107
+ start_ema = int(epochs * self.ema_from_ratio)
1108
+ start_eval = int(epochs * self.eval_from_ratio)
1109
+
1110
+ if val_loader is None:
1111
+ print(f'[Trainer CallBack] 📢 No validation set provided; evaluation and early stopping are disabled!')
1112
+ else:
1113
+ model_to_use_str = 'the EMA model' if self.use_ema else 'the original model'
1114
+ start_model_update_str = f'EMA updates start at epoch {start_epoch + start_ema}!' if self.use_ema else ''
1115
+ print(f'[Trainer CallBack] 📢 Evaluating with {model_to_use_str} from epoch {start_epoch + start_eval}!', start_model_update_str)
1116
+
1117
+ training_log = {}
1118
+ for epoch in range(start_epoch, epochs+start_epoch):
1119
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
1120
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1121
+
1122
+ try:
1123
+ teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
1124
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1125
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1126
+ logging_dict.update(train_loss_epoch_dict)
1127
+
1128
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
1129
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1130
+
1131
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
1132
+ update = self._update_best_network(eval_net, val_score, epoch)
1133
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
1134
+ logging_dict.update(val_score_dict)
1135
+ if not self.schedule_in_step and scheduler:
1136
+ scheduler.step()
1137
+
1138
+ except RuntimeError as e:
1139
+ if "out of memory" in str(e).lower():
1140
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
1141
+ torch.cuda.empty_cache()
1142
+ gc.collect()
1143
+ if torch.cuda.is_available():
1144
+ torch.cuda.synchronize()
1145
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
1146
+
1147
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
1148
+ if val_loader is not None:
1149
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
1150
+
1151
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
1152
+ else:
1153
+ raise
1154
+
1155
+ training_log[epoch] = logging_dict
1156
+ if self.is_early_stopping(epoch):
1157
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
1158
+ break
1159
+ if self.logging:
1160
+ if epoch % self.logging == 0:
1161
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
1162
+ else:
1163
+ print(f'{epoch}...', end=' ')
1164
+
1165
+ if self._at_time_limit(start_training_time):
1166
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: The training time limit is {self.training_time} seconds; time ran out at epoch {epoch}/{epochs}')
1167
+ break
1168
+
1169
+ if self.logging_file:
1170
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
1171
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
1172
+ f.write(json.dumps(training_log))
1173
+
1174
+ if self.use_ema and self.ema_net is not None:
1175
+ self._save_state_dict(self.ema_net.module)
1176
+ else:
1177
+ self._save_state_dict(network)
1178
+ print(f'[Trainer CallBack] 📢 Training finished.\n')
1179
+
1180
+ best_model, last_model = None, None
1181
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1182
+ if self.return_best:
1183
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
1184
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
1185
+ if self.return_last:
1186
+ last_model = eval_net.state_dict()
1187
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
1188
+
1189
+ del network
1190
+ torch.cuda.empty_cache()
1191
+ gc.collect()
1192
+ return training_log, best_model, last_model
1193
+
1194
+ def _time_str_to_seconds(self, time_str):
1195
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
1196
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
1197
+
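As a quick check of the `DD:HH:MM:SS` format that `_time_str_to_seconds` expects, here is a standalone copy of the same arithmetic (written as a free function purely for illustration):

```python
def time_str_to_seconds(time_str):
    """Convert a 'DD:HH:MM:SS' string into a number of seconds."""
    days, hours, minutes, seconds = map(int, time_str.split(":"))
    return days * 86400 + hours * 3600 + minutes * 60 + seconds


# The Trainer default "00:11:30:00" is 11 hours and 30 minutes.
print(time_str_to_seconds("00:11:30:00"))  # 41400
```

A string with fewer or more than four colon-separated fields raises `ValueError` at the tuple unpacking, so the format must be followed exactly.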
1198
+ def _update_best_network(self, network, val_score, epoch):
1199
+ topk = max(1, self.topk)
1200
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
1201
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
1202
+ if val_score in [x[0] for x in self.best_stage]:
1203
+ return True
1204
+ return False
1205
+
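The top-k bookkeeping in `_update_best_network` can be sketched without any model state. The helper below is a simplified illustration under stated assumptions: the name `update_topk` is hypothetical, it stores `(score, epoch)` pairs only, and it starts from an empty list rather than the `-inf`/`inf` sentinel entry used above.

```python
def update_topk(best, score, epoch, k=1, mode="max"):
    """Insert (score, epoch) into a ranked list and keep only the k best entries."""
    ranked = sorted(best + [(score, epoch)], key=lambda x: x[0], reverse=(mode == "max"))
    ranked = ranked[:max(1, k)]
    # Report whether the new entry itself survived the cut.
    return ranked, (score, epoch) in ranked


best = []
best, is_new = update_topk(best, 0.5, epoch=1, k=2)
best, is_new = update_topk(best, 0.7, epoch=2, k=2)
best, is_new = update_topk(best, 0.1, epoch=3, k=2)
print(best, is_new)  # [(0.7, 2), (0.5, 1)] False
```

One design difference worth noting: the method above tests `val_score in [x[0] for x in self.best_stage]`, so a score that merely ties an already stored score also reports `True`, whereas this sketch compares the whole `(score, epoch)` pair.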
1206
+ def is_early_stopping(self, epoch):
1207
+ if self.best_stage[0][1] is None:
1208
+ return False
1209
+ if not self.early_stopping:
1210
+ return False
1211
+ return epoch - self.best_stage[0][1] >= self.early_stopping
1212
+
1213
+ def _at_time_limit(self, start_training_time):
1214
+ return time.time() - start_training_time >= self.training_time
1215
+
1216
+ def _save_state_dict(self, network):
1217
+ if self.topk <= 0:
1218
+ return
1219
+
1220
+ if self.save_best:
1221
+ for r in range(self.topk):
1222
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
1223
+
1224
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
1225
+ if state_dict is None:
1226
+ continue
1227
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
1228
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
1229
+ if self.save_last:
1230
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
1231
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
1232
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
1233
+
1234
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate):
1235
+ network.train()
1236
+ total_loss = 0
1237
+ total_loss_dict = {}
1238
+ for batch_idx, batch in enumerate(train_loader):
1239
+ optimizer.zero_grad()
1240
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
1241
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate)
1242
+
1243
+ for k, v in loss_dict.items():
1244
+ t = total_loss_dict.get(k, 0)
1245
+ total_loss_dict[k] = t + v
1246
+ self.grad_scaler.scale(loss).backward()
1247
+ self.grad_scaler.unscale_(optimizer)
1248
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
1249
+ # print(grad_norm) # Uncomment this line to see which value to choose for max_grad_norm...
1250
+ self.grad_scaler.step(optimizer)
1251
+ self.grad_scaler.update()
1252
+ if self.schedule_in_step and scheduler:
1253
+ scheduler.step()
1254
+ if self.use_ema and self.ema_net is not None:
1255
+ self.ema_net.update(network)
1256
+ total_loss += loss.detach()  # accumulate the value only, without keeping autograd history
1257
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
1258
+
1259
+ def _eval_epoch(self, network, val_loader, eval_fn, id2label):
1260
+ network.eval()
1261
+ total_score = 0.0
1262
+ total_score_dict = {}
1263
+ object_lists = None # initialized later
1264
+
1265
+ with torch.no_grad():
1266
+ for batch_idx, batch in enumerate(val_loader):
1267
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
1268
+ total_score += score
1269
+
1270
+ for k, v in score_dict.items():
1271
+ t = total_score_dict.get(k, 0)
1272
+ total_score_dict[k] = t + v
1273
+
1274
+ if objects:
1275
+ if object_lists is None:
1276
+ object_lists = [[] for _ in range(len(objects))]
1277
+
1278
+ for i, obj in enumerate(objects):
1279
+ object_lists[i].append(obj.detach())
1280
+
1281
+ if object_lists is not None:
1282
+ object_arrays = [
1283
+ torch.concat(obj_list, dim=0).cpu().numpy()
1284
+ for obj_list in object_lists
1285
+ ]
1286
+ else:
1287
+ object_arrays = []
1288
+
1289
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
1290
+
1291
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate):
1292
+ # Override _cal_loss to compute the loss
1293
+ input_ids = batch['input_ids'].to(self.device)
1294
+ attention_mask = batch['attention_mask'].to(self.device)
1295
+ trg_spans = batch['trg_spans'].to(self.device) # B, M, 2
1296
+ trg_start_labels = batch['trg_start_labels'].to(self.device) # B, L
1297
+ trg_end_labels = batch['trg_end_labels'].to(self.device) # B, L
1298
+ all_arg_start_labels = batch['all_arg_start_labels'].to(self.device) # B, M, L
1299
+ all_arg_end_labels = batch['all_arg_end_labels'].to(self.device) # B, M, L
1300
+
1301
+ hidden_states = network.encode(input_ids, attention_mask)
1302
+ trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states)
1303
+
1304
+ choice = random.random()
1305
+ if choice < teaching_rate:
1306
+ pred_trg_spans = trg_spans
1307
+ else:
1308
+ trg_start_ids = torch.argmax(trg_start_logits, dim=-1) # (B, L)
1309
+ trg_end_ids = torch.argmax(trg_end_logits, dim=-1) # (B, L)
1310
+ pred_trg_spans = decode_spans_batch(trg_start_ids, trg_end_ids)
1311
+
1312
+ trg_repr = get_span_repr(hidden_states, pred_trg_spans) # B, N, 4H
1313
+
1314
+ trg_repr = network.trg_repr_proj(trg_repr) # B, N, H
1315
+ arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr)
1316
+
1317
+ pred_arg_start_labels = map_arg_labels(all_arg_start_labels, trg_spans, pred_trg_spans)
1318
+ pred_arg_end_labels = map_arg_labels(all_arg_end_labels, trg_spans, pred_trg_spans)
1319
+
1320
+ loss_dict = loss_fn(
1321
+ trg_start_logits, trg_start_labels,
1322
+ trg_end_logits, trg_end_labels,
1323
+ arg_start_logits, pred_arg_start_labels,
1324
+ arg_end_logits, pred_arg_end_labels,
1325
+ )
1326
+ return loss_dict['total'], loss_dict
1327
+
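The teacher-forcing choice in `_cal_loss` depends on the `teaching_rate` computed in `fit` as `cos(pi/2 * epoch / epochs)`. The sketch below reproduces that schedule on its own; the function names and the optional `rng` parameter are hypothetical and exist only so the example is testable.

```python
import math
import random


def teaching_rate(epoch, epochs):
    """Cosine decay from 1 (always use gold trigger spans) towards 0 (always use predictions)."""
    return math.cos(math.pi / 2 * epoch / epochs)


def use_gold_spans(epoch, epochs, rng=random):
    """Draw once per batch: True means feed the gold trigger spans to the argument head."""
    return rng.random() < teaching_rate(epoch, epochs)


for epoch in (1, 5, 10):
    print(epoch, round(teaching_rate(epoch, 10), 3))
```

Because the draw happens once per batch, a whole batch uses either gold or predicted trigger spans; the mix between the two only emerges across batches.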
1328
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
1329
+ # Override _cal_val_score to compute the validation score; the list returned alongside it can carry targets or predictions (if needed)
1330
+ input_ids = batch['input_ids'].to(self.device)
1331
+ attention_mask = batch['attention_mask'].to(self.device)
1332
+ gold_events = batch['gold_events']
1333
+
1334
+ B, _, _ = input_ids.shape
1335
+
1336
+ hidden_states = network.encode(input_ids, attention_mask)
1337
+ trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states)
1338
+
1339
+ trg_start_ids = torch.argmax(trg_start_logits, dim=-1) # (B, L)
1340
+ trg_end_ids = torch.argmax(trg_end_logits, dim=-1) # (B, L)
1341
+ pred_trg_spans = decode_spans_batch(trg_start_ids, trg_end_ids)
1342
+ trg_repr = get_span_repr(hidden_states, pred_trg_spans) # B, N, 4H
1343
+
1344
+ trg_repr = network.trg_repr_proj(trg_repr) # B, N, H
1345
+ arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr)
1346
+
1347
+ pred_ids = extract_arguments(input_ids.reshape(B, -1), trg_start_logits, trg_end_logits, arg_start_logits, arg_end_logits, pred_trg_spans, id2label)
1348
+ pred_ids = list_to_tuple(pred_ids)
1349
+
1350
+ gold_ids = list_to_tuple(gold_events)
1351
+
1352
+ score_dict = eval_fn(pred_ids, gold_ids)
1353
+ return score_dict['f1'], score_dict, []
1354
+
1355
+ # %% [code]
1356
+ class PhoBERTSpanAligner:
1357
+ def __init__(self, tokenizer, max_len):
1358
+ self.tokenizer = tokenizer
1359
+ self.max_len = max_len
1360
+
1361
+ # ===== 1. Extract discontinuous spans =====
1362
+ def extract_spans(self, sample):
1363
+ trigger_spans, arg_spans = [], []
1364
+
1365
+ for event in sample["events"]:
1366
+ trigger_type = event["label"]
1367
+ spans = [tuple(event["offset"])]
1368
+ trigger_spans.append({
1369
+ "spans": spans,
1370
+ "label": trigger_type
1371
+ })
1372
+ event_arg_spans = []
1373
+ for arg in event['arguments']:
1374
+ arg_type = arg["role"]
1375
+ spans = [tuple(arg["offset"])]
1376
+ event_arg_spans.append({
1377
+ "spans": spans,
1378
+ "label": arg_type
1379
+ })
1380
+ arg_spans.append(event_arg_spans)
1381
+
1382
+ return trigger_spans, arg_spans
1383
+
1384
+ # ===== 2. Word offsets =====
1385
+ def build_word_offsets(self, text, words):
1386
+ offsets = []
1387
+ pointer = 0
1388
+
1389
+ for word in words:
1390
+ start = text.find(word, pointer)
1391
+ end = start + len(word)
1392
+ offsets.append((start, end))
1393
+ pointer = end
1394
+
1395
+ return offsets
1396
+
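`build_word_offsets` relies on `str.find`, which returns `-1` when a word cannot be located at or after `pointer`. The variant below is a hedged sketch, not the author's method: it records `None` for such words instead of a negative offset. Any caller that unpacks the offsets as tuples (for example `char_span_to_word_span`) would need to skip those `None` entries.

```python
def build_word_offsets_safe(text, words):
    """Locate each word in text from left to right; unmatched words get None."""
    offsets = []
    pointer = 0
    for word in words:
        start = text.find(word, pointer)
        if start == -1:
            # Keep the list aligned with `words` and leave the pointer untouched.
            offsets.append(None)
            continue
        end = start + len(word)
        offsets.append((start, end))
        pointer = end
    return offsets


print(build_word_offsets_safe("the cat sat", ["the", "cat", "sat"]))
# [(0, 3), (4, 7), (8, 11)]
print(build_word_offsets_safe("the cat sat", ["the", "dog", "sat"]))
# [(0, 3), None, (8, 11)]
```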
1397
+ # ===== 3. Char → word =====
1398
+ def char_span_to_word_span(self, word_offsets, start, end):
1399
+ start_word = None
1400
+ end_word = None
1401
+
1402
+ for i, (w_start, w_end) in enumerate(word_offsets):
1403
+ if w_start <= start < w_end:
1404
+ start_word = i
1405
+ if w_start < end <= w_end:
1406
+ end_word = i
1407
+
1408
+ return start_word, end_word
1409
+
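A small worked example of the character-to-word mapping may help. The function body is copied from `char_span_to_word_span` above as a free function; the offsets are a made-up three-word sentence.

```python
def char_span_to_word_span(word_offsets, start, end):
    """Map a character span [start, end) to the indices of its first and last word."""
    start_word = None
    end_word = None
    for i, (w_start, w_end) in enumerate(word_offsets):
        if w_start <= start < w_end:
            start_word = i
        if w_start < end <= w_end:
            end_word = i
    return start_word, end_word


offsets = [(0, 3), (4, 7), (8, 11)]  # e.g. "the cat sat"
print(char_span_to_word_span(offsets, 4, 11))  # (1, 2)
print(char_span_to_word_span(offsets, 3, 7))   # (None, 1)
```

The second call starts on the space between two words, so no word contains the start character and `None` is returned for it; `span_to_subword` then skips such spans.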
1410
+ # ===== 4. Word → subword =====
1411
+ def word_to_subword_map(self, words):
1412
+ mapping = []
1413
+ subword_index = 1 # <s>
1414
+
1415
+ for word in words:
1416
+ sub_tokens = self.tokenizer.tokenize(word)
1417
+ start = subword_index
1418
+ end = subword_index + len(sub_tokens) - 1
1419
+ mapping.append((start, end))
1420
+ subword_index += len(sub_tokens)
1421
+
1422
+ return mapping
1423
+
1424
+ # ===== 5. Span → subword =====
1425
+ def span_to_subword(self, word_offsets, word_subword_map, spans):
1426
+ sub_spans = []
1427
+
1428
+ for span_start, span_end in spans:
1429
+ w_start, w_end = self.char_span_to_word_span(
1430
+ word_offsets, span_start, span_end
1431
+ )
1432
+ if w_start is None or w_end is None:
1433
+ continue
1434
+
1435
+ sub_start = word_subword_map[w_start][0]
1436
+ sub_end = word_subword_map[w_end][1]
1437
+ sub_spans.append((sub_start, sub_end))
1438
+
1439
+ return sub_spans
1440
+
1441
+ def extract_valid_spans(self, sub_spans):
1442
+ valid_spans = []
1443
+ for s, e in sub_spans:
1444
+ if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
1445
+ continue
1446
+ valid_spans.append((s, e))
1447
+ return valid_spans
1448
+
1449
+ def encode(self, sample):
1450
+ text = sample["text"]
1451
+ triggers, arguments = self.extract_spans(sample)
1452
+
1453
+ # ===== 1. Word tokenize =====
1454
+ words = word_tokenize(text)
1455
+ sentence = " ".join(words)
1456
+
1457
+ # ===== 2. Mapping =====
1458
+ word_offsets = self.build_word_offsets(text, words)
1459
+ word_subword_map = self.word_to_subword_map(words)
1460
+
1461
+ # ===== 3. Tokenize FULL =====
1462
+ encoding = self.tokenizer(
1463
+ sentence,
1464
+ max_length=self.max_len,
1465
+ truncation=True,
1466
+ padding="max_length",
1467
+ return_tensors="pt"
1468
+ )
1469
+ input_ids = encoding["input_ids"][0]
1470
+ attention_mask = encoding["attention_mask"][0]
1471
+
1472
+ # ===== 5. Convert spans =====
1473
+ triggers_gold_spans = []
1474
+ arguments_gold_spans = []
1475
+
1476
+ for trg, args in zip(triggers, arguments):
1477
+ label = trg["label"]
1478
+
1479
+ sub_spans = self.span_to_subword(
1480
+ word_offsets,
1481
+ word_subword_map,
1482
+ trg["spans"]
1483
+ )
1484
+ valid_spans = self.extract_valid_spans(sub_spans)
1485
+ if len(valid_spans) == 0:
1486
+ continue
1487
+ triggers_gold_spans.append((tuple(valid_spans), label))
1488
+
1489
+ trg_args_gold_spans = []
1490
+ for arg in args:
1491
+ label = arg["label"]
1492
+
1493
+ sub_spans = self.span_to_subword(
1494
+ word_offsets,
1495
+ word_subword_map,
1496
+ arg["spans"]
1497
+ )
1498
+ valid_spans = self.extract_valid_spans(sub_spans)
1499
+ if len(valid_spans) == 0:
1500
+ continue
1501
+ trg_args_gold_spans.append((tuple(valid_spans), label))
1502
+ arguments_gold_spans.append(tuple(trg_args_gold_spans))
1503
+
1504
+ return {
1505
+ "input_ids": input_ids,
1506
+ "attention_mask": attention_mask,
1507
+ "triggers_gold_spans": triggers_gold_spans,
1508
+ "arguments_gold_spans": arguments_gold_spans,
1509
+ }
1510
+
1511
+ def generate_candidate_spans(seq_len, max_span_len):
1512
+ spans = []
1513
+ for i in range(1, seq_len+1):
1514
+ for j in range(i, min(i+max_span_len, seq_len+1)):
1515
+ spans.append((i, j))
1516
+ return spans
1517
+
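For reference, this is what `generate_candidate_spans` enumerates on a tiny input. The function is repeated here so the snippet runs on its own; note that positions are 1-based and both ends are inclusive.

```python
def generate_candidate_spans(seq_len, max_span_len):
    """Enumerate all (start, end) pairs, 1-based and inclusive, up to max_span_len tokens long."""
    spans = []
    for i in range(1, seq_len + 1):
        for j in range(i, min(i + max_span_len, seq_len + 1)):
            spans.append((i, j))
    return spans


print(generate_candidate_spans(3, 2))
# [(1, 1), (1, 2), (2, 2), (2, 3), (3, 3)]
```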
1518
+ class KLTNDataset(Dataset):
1519
+ def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
1520
+ super().__init__()
1521
+ self.tokenizer = tokenizer
1522
+ self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)
1523
+ self.all_data = all_data
1524
+ self.using_idxes = using_idxes
1525
+ self.label2id = label2id
1526
+ self.max_len = max_len
1527
+ self.max_n_parts = max_n_parts
1528
+
1529
+ def __len__(self):
1530
+ return len(self.using_idxes)
1531
+
1532
+ def __getitem__(self, idx):
1533
+ ridx = self.using_idxes[idx]
1534
+ sample = self.all_data[ridx]
1535
+ result = self.aligner.encode(sample)
1536
+
1537
+ input_ids = result["input_ids"].squeeze(0)
1538
+ attention_mask = result["attention_mask"].squeeze(0)
1539
+ triggers_gold_spans = result["triggers_gold_spans"]
1540
+ arguments_gold_spans = result["arguments_gold_spans"]
1541
+
1542
+ # Get event label
1543
+ all_trg_spans = torch.tensor([list(trg_spans[0]) for trg_spans, _ in triggers_gold_spans], dtype=torch.long) if triggers_gold_spans else torch.empty(0, 2, dtype=torch.long)
1544
+ gold_events = []
1545
+ trg_start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1546
+ trg_end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1547
+ all_arg_start_labels, all_arg_end_labels = [], []
1548
+ for (trg_spans, trg_label), args in zip(triggers_gold_spans, arguments_gold_spans):
1549
+ s, e = trg_spans[0]
1550
+
1551
+ trg_start_labels[s] = self.label2id['Trg'][f'{trg_label}']
1552
+ trg_end_labels[e] = self.label2id['Trg'][f'{trg_label}']
1553
+
1554
+ event = [(tuple(input_ids[s:e+1].tolist()), trg_label)]
1555
+
1556
+ arg_start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1557
+ arg_end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1558
+ for arg_spans, arg_label in args:
1559
+ s, e = arg_spans[0]
1560
+
1561
+ arg_start_labels[s] = self.label2id['Arg'][f'{arg_label}']
1562
+ arg_end_labels[e] = self.label2id['Arg'][f'{arg_label}']
1563
+
1564
+ event.append((tuple(input_ids[s:e+1].tolist()), arg_label))
1565
+ all_arg_start_labels.append(arg_start_labels)
1566
+ all_arg_end_labels.append(arg_end_labels)
1567
+
1568
+ gold_events.append(event)
1569
+
1570
+ input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
1571
+ attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)
1572
+
1573
+ n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
1574
+ input_ids = input_ids[:n_valid_parts]
1575
+ attention_mask = attention_mask[:n_valid_parts]
1576
+ trg_start_labels = trg_start_labels[:n_valid_parts*self.max_len]
1577
+ trg_end_labels = trg_end_labels[:n_valid_parts*self.max_len]
1578
+ all_arg_start_labels = torch.stack([arg_labels[:n_valid_parts*self.max_len] for arg_labels in all_arg_start_labels], dim=0) if all_arg_start_labels else torch.empty(0, n_valid_parts*self.max_len)
1579
+ all_arg_end_labels = torch.stack([arg_labels[:n_valid_parts*self.max_len] for arg_labels in all_arg_end_labels], dim=0) if all_arg_end_labels else torch.empty(0, n_valid_parts*self.max_len)
1580
+
1581
+ return {
1582
+ "input_ids": input_ids,
1583
+ "attention_mask": attention_mask,
1584
+ "trg_spans": all_trg_spans,
1585
+ "trg_start_labels": trg_start_labels,
1586
+ "trg_end_labels": trg_end_labels,
1587
+ "all_arg_start_labels": all_arg_start_labels,
1588
+ "all_arg_end_labels": all_arg_end_labels,
1589
+ "gold_events": gold_events,
1590
+ }
1591
+
1592
+ def _pad_batch(tensor_list, pad_value=0):
1593
+ """
1594
+ tensor_list: list of tensors
1595
+ each tensor has shape: (Nk, n_parts_i, max_len_i)
1596
+
1597
+ return:
1598
+ padded tensor shape: (B, max_Nk, max_n_parts, max_len)
1599
+ """
1600
+
1601
+ # take the maxima over the whole batch
1602
+ max_Nk = max(t.size(0) for t in tensor_list)
1603
+ max_n_parts = max(t.size(1) for t in tensor_list)
1604
+ max_len = max(t.size(2) for t in tensor_list)
1605
+
1606
+ padded = []
1607
+
1608
+ for t in tensor_list:
1609
+ Nk, n_parts_i, max_len_i = t.shape
1610
+
1611
+ # pad the n_parts and max_len dimensions first
1612
+ if n_parts_i < max_n_parts or max_len_i < max_len:
1613
+ new_t = t.new_full(
1614
+ (Nk, max_n_parts, max_len),
1615
+ pad_value
1616
+ )
1617
+ new_t[:, :n_parts_i, :max_len_i] = t
1618
+ t = new_t
1619
+
1620
+ # pad the Nk dimension
1621
+ if Nk < max_Nk:
1622
+ pad_tensor = t.new_full(
1623
+ (max_Nk - Nk, max_n_parts, max_len),
1624
+ pad_value
1625
+ )
1626
+ t = torch.cat([t, pad_tensor], dim=0)
1627
+
1628
+ padded.append(t)
1629
+
1630
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
1631
+
1632
+ def collate_fn(batch):
1633
+ gold_events = []
1634
+ for bidx, b in enumerate(batch):
1635
+ for event in b['gold_events']:
1636
+ trg = event[0]
1637
+ if len(event) > 1:
1638
+ for arg in event[1:]:
1639
+ gold_events.append([bidx, trg, arg])
1640
+ else:
1641
+ gold_events.append([bidx, trg, (tuple([]), 0)])
1642
+
1643
+ input_ids = [b["input_ids"].unsqueeze(-1) for b in batch]
1644
+ attention_mask = [b["attention_mask"].unsqueeze(-1) for b in batch]
1645
+ trg_spans = [b["trg_spans"].unsqueeze(-1) for b in batch]
1646
+ trg_start_labels = [b["trg_start_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1647
+ trg_end_labels = [b["trg_end_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1648
+ all_arg_start_labels = [b["all_arg_start_labels"].unsqueeze(-1) for b in batch]
1649
+ all_arg_end_labels = [b["all_arg_end_labels"].unsqueeze(-1) for b in batch]
1650
+
1651
+ # pad along Nk
1652
+ input_ids = _pad_batch(input_ids, pad_value=0).squeeze(-1)
1653
+ attention_mask = _pad_batch(attention_mask, pad_value=0).squeeze(-1)
1654
+ trg_spans = _pad_batch(trg_spans, pad_value=0).squeeze(-1)
1655
+ trg_start_labels = _pad_batch(trg_start_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1656
+ trg_end_labels = _pad_batch(trg_end_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1657
+ all_arg_start_labels = _pad_batch(all_arg_start_labels, pad_value=-100).squeeze(-1)
1658
+ all_arg_end_labels = _pad_batch(all_arg_end_labels, pad_value=-100).squeeze(-1)
1659
+
1660
+ return {
1661
+ "input_ids": input_ids,
1662
+ "attention_mask": attention_mask,
1663
+ "trg_spans": trg_spans,
1664
+ "trg_start_labels": trg_start_labels,
1665
+ "trg_end_labels": trg_end_labels,
1666
+ "all_arg_start_labels": all_arg_start_labels,
1667
+ "all_arg_end_labels": all_arg_end_labels,
1668
+ "gold_events": gold_events,
1669
+ }
1670
+
1671
+ # %% [code]
1672
+ def shift_bidx(spans, batch_idx):
1673
+ shifted = []
1674
+ for bidx, trg, arg in spans:
1675
+ new_bidx = bidx + batch_idx * batch_size
1676
+ shifted.append((new_bidx, trg, arg))
1677
+ return shifted
1678
+
1679
+ def refactor_events(events, save_dict):
1680
+ trg_i, trg_c, arg_i, arg_c, soft, strict_dict = [], [], [], [], [], {}
1681
+ for bidx, (trg_ids, trg_lb), (arg_k_ids, arg_k_lb) in events:
1682
+ if (bidx, trg_ids) not in trg_i:
1683
+ trg_i.append((bidx, trg_ids))
1684
+
1685
+ if (bidx, (trg_ids, trg_lb)) not in trg_c:
1686
+ trg_c.append((bidx, (trg_ids, trg_lb)))
1687
+
1688
+ if (bidx, trg_ids, arg_k_ids) not in arg_i:
1689
+ arg_i.append((bidx, trg_ids, arg_k_ids))
1690
+
1691
+ if (bidx, trg_ids, (arg_k_ids, arg_k_lb)) not in arg_c:
1692
+ arg_c.append((bidx, trg_ids, (arg_k_ids, arg_k_lb)))
1693
+
1694
+ if (bidx, (trg_ids, trg_lb), (arg_k_ids, arg_k_lb)) not in soft:
1695
+ soft.append((bidx, (trg_ids, trg_lb), (arg_k_ids, arg_k_lb)))
1696
+
1697
+ if bidx not in strict_dict:
1698
+ strict_dict[bidx] = {}
1699
+ if (trg_ids, trg_lb) not in strict_dict[bidx]:
1700
+ strict_dict[bidx][(trg_ids, trg_lb)] = []
1701
+ strict_dict[bidx][(trg_ids, trg_lb)].append((arg_k_ids, arg_k_lb))
1702
+
1703
+ strict = []
1704
+ for bidx, trg_dict in strict_dict.items():
1705
+ for trg, args in trg_dict.items():
1706
+ strict.append((bidx, trg, frozenset(args)))
1707
+
1708
+ save_dict['Trg-I'].extend(trg_i)
1709
+ save_dict['Trg-C'].extend(trg_c)
1710
+ save_dict['Arg-I'].extend(arg_i)
1711
+ save_dict['Arg-C'].extend(arg_c)
1712
+ save_dict['Soft-Event'].extend(soft)
1713
+ save_dict['Strict-Event'].extend(strict)
1714
+
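The "strict" grouping in `refactor_events` collects all arguments of one trigger into a single `frozenset`, so an event only counts as correct when its whole argument set matches. A minimal sketch of that grouping, with a hypothetical helper name and made-up token ids and labels:

```python
from collections import defaultdict


def group_strict(events):
    """Group (bidx, trigger, argument) triples into one entry per trigger."""
    grouped = defaultdict(list)
    for bidx, trg, arg in events:
        grouped[(bidx, trg)].append(arg)
    # frozenset makes each entry hashable and independent of argument order.
    return [(bidx, trg, frozenset(args)) for (bidx, trg), args in grouped.items()]


events = [
    (0, ((5, 6), "Attack"), ((9,), "Victim")),
    (0, ((5, 6), "Attack"), ((12,), "Place")),
    (1, ((3,), "Move"), ((7,), "Agent")),
]
for item in group_strict(events):
    print(item)
```

Because each entry is hashable, predicted and gold strict events can be compared with ordinary set operations.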
1715
+ def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
1716
+ if torch.cuda.device_count() > 1:
1717
+ network = DataParallelProxy(network)
1718
+ network = network.to(device)
1719
+ network.eval()
1720
+
1721
+ eval_types = ['Trg-I', 'Trg-C', 'Arg-I', 'Arg-C', 'Soft-Event', 'Strict-Event']
1722
+
1723
+ all_pred = {eval_type: [] for eval_type in eval_types}
1724
+ all_gold = {eval_type: [] for eval_type in eval_types}
1725
+
1726
+ list_input_ids = []
1727
+
1728
+ with torch.no_grad():
1729
+ for batch_idx, batch in enumerate(test_loader):
1730
+ input_ids = batch['input_ids'].to(device)
1731
+ attention_mask = batch['attention_mask'].to(device)
1732
+ gold_events = batch['gold_events']
1733
+
1734
+ B, _, _ = input_ids.shape
1735
+ list_input_ids.extend(input_ids.reshape(B, -1).tolist())
1736
+
1737
+ list_trg_start_logits = []
1738
+ list_trg_end_logits = []
1739
+ list_hidden_states = []
1740
+ list_arg_start_logits = []
1741
+ list_arg_end_logits = []
1742
+
1743
+ for sd in state_dicts:
1744
+ if torch.cuda.device_count() > 1:
1745
+ network.module.load_state_dict(sd)
1746
+ else:
1747
+ network.load_state_dict(sd)
1748
+
1749
+ hidden_states = network.encode(input_ids, attention_mask)
1750
+ trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states)
1751
+ list_trg_start_logits.append(trg_start_logits)
1752
+ list_trg_end_logits.append(trg_end_logits)
1753
+ list_hidden_states.append(hidden_states)
1754
+
1755
+ ensemble_trg_start_logits = torch.stack(list_trg_start_logits, dim=0).mean(dim=0)
1756
+ ensemble_trg_end_logits = torch.stack(list_trg_end_logits, dim=0).mean(dim=0)
1757
+ trg_start_ids = torch.argmax(ensemble_trg_start_logits, dim=-1) # (B, L)
1758
+ trg_end_ids = torch.argmax(ensemble_trg_end_logits, dim=-1) # (B, L)
1759
+ pred_trg_spans = decode_spans_batch(trg_start_ids, trg_end_ids)
1760
+
1761
+ for sd, hidden_states in zip(state_dicts, list_hidden_states):
1762
+ if torch.cuda.device_count() > 1:
1763
+ network.module.load_state_dict(sd)
1764
+ else:
1765
+ network.load_state_dict(sd)
1766
+
1767
+ trg_repr = get_span_repr(hidden_states, pred_trg_spans) # B, N, 4H
1768
+ trg_repr = network.trg_repr_proj(trg_repr) # B, N, H
1769
+ arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr)
1770
+
1771
+ list_arg_start_logits.append(arg_start_logits)
1772
+ list_arg_end_logits.append(arg_end_logits)
1773
+
1774
+ ensemble_arg_start_logits = torch.stack(list_arg_start_logits, dim=0).mean(dim=0)
1775
+ ensemble_arg_end_logits = torch.stack(list_arg_end_logits, dim=0).mean(dim=0)
1776
+
1777
+ pred_events = extract_arguments(
1778
+ input_ids.reshape(B, -1),
1779
+ ensemble_trg_start_logits, ensemble_trg_end_logits,
1780
+ ensemble_arg_start_logits, ensemble_arg_end_logits,
1781
+ pred_trg_spans, id2label
1782
+ )
1783
+ pred_events = shift_bidx(pred_events, batch_idx)
1784
+ refactor_events(pred_events, all_pred)
1785
+
1786
+ gold_events = shift_bidx(gold_events, batch_idx)
1787
+ refactor_events(gold_events, all_gold)
1788
+
1789
+ # ===== GLOBAL EVAL =====
1790
+ final_score = {}
1791
+ for eval_type in eval_types:
1792
+ score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))
1793
+ final_score[eval_type] = score
1794
+
1795
+ analyze_result = analyzer.analyze(list_to_tuple(all_pred['Trg-I']), list_to_tuple(all_gold['Trg-I']))
1796
+
1797
+ # ===== PREDICT =====
1798
+ predictions = []
1799
+ for input_ids in list_input_ids:
1800
+ predictions.append([tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
1801
+ for event in all_pred['Strict-Event']:
1802
+ bidx = event[0]
1803
+ trg = tokenizer.decode(event[1][0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
1804
+ trg_lb = event[1][1]
1805
+ predictions[bidx].append((trg, trg_lb))
1806
+
1807
+ for arg_infor in event[2]:
1808
+ arg = tokenizer.decode(arg_infor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
1809
+ arg_lb = arg_infor[1]
1810
+ predictions[bidx].append((arg, arg_lb))
1811
+
1812
+ return final_score, analyze_result, predictions
1813
+
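Review note: the `test` function above ensembles checkpoints by averaging their logits before taking the argmax, rather than voting on per-model predictions. A minimal pure-Python sketch of that averaging step follows; `mean_logits`, `argmax`, and the toy `fold_logits` values are hypothetical names and data used only for illustration, not part of the committed script.

```python
from typing import List, Sequence


def mean_logits(per_model: Sequence[Sequence[float]]) -> List[float]:
    """Average logits position-wise across models (the ensembling step)."""
    n_models = len(per_model)
    return [sum(column) / n_models for column in zip(*per_model)]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value, mirroring argmax over the label axis."""
    return max(range(len(values)), key=values.__getitem__)


# Hypothetical logits from three checkpoints for one token over four labels.
fold_logits = [
    [0.1, 2.0, 0.3, 0.0],
    [0.2, 0.5, 1.9, 0.0],
    [0.0, 1.5, 1.0, 0.1],
]

ensemble = mean_logits(fold_logits)
predicted_label = argmax(ensemble)
```

In this toy case the second checkpoint alone would pick label 2, while the averaged logits pick label 1, which is the kind of disagreement the averaging is meant to smooth out.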
1814
+ # %% [code]
1815
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
1816
+ data_train = json.load(f)
1817
+
1818
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
1819
+ data_test = json.load(f)
1820
+
1821
+ print('Train:', len(data_train))
1822
+ print('Test:', len(data_test))
1823
+
1824
+ # %% [code]
1825
+ trigger_types = ['O'] + sorted(list(set([e['label'] for d in data_train + data_test for e in d['events']]))) # NBR : Neighbor relation
1826
+ # bio_trigger_types = [f'{prefix}-{trg}' for trg in trigger_types for prefix in ['B', 'I']]
1827
+ trigger_label2id = {l: i for i, l in enumerate(trigger_types)}
1828
+ trigger_id2label = {i: l for l, i in trigger_label2id.items()}
1829
+
1830
+ argument_types = ['O'] + sorted(list(set([a['role'] for d in data_train + data_test for e in d['events'] for a in e['arguments']])))
1831
+ # bio_argument_types = [f'{prefix}-{arg}' for arg in argument_types for prefix in ['B', 'I']]
1832
+ argument_label2id = {l: i for i, l in enumerate(argument_types)}
1833
+ argument_id2label = {i: l for l, i in argument_label2id.items()}
1834
+
1835
+ label2id = {
1836
+ 'Trg': trigger_label2id,
1837
+ 'Arg': argument_label2id,
1838
+ }
1839
+
1840
+ id2label = {
1841
+ 'Trg': trigger_id2label,
1842
+ 'Arg': argument_id2label,
1843
+ }
1844
+
1845
+ # %% [code]
1846
+ zero_events_idxes = []
1847
+ for idx, d in enumerate(data_train):
1848
+ if len(d['events']) == 0:
1849
+ zero_events_idxes.append(idx)
1850
+
1851
+ n_zero_events_samples = len(zero_events_idxes)
1852
+ n_has_events_samples = len(data_train) - n_zero_events_samples
1853
+
1854
+ random.seed(42)
1855
+ k = min(int(n_has_events_samples * zero_events_rate), len(zero_events_idxes))
1856
+ sampled_zero_events_idxes = random.sample(zero_events_idxes, k)
1857
+
1858
+ new_data_train = []
1859
+ for idx, d in enumerate(data_train):
1860
+ if len(d['events']) == 0:
1861
+ if idx in sampled_zero_events_idxes:
1862
+ new_data_train.append(d)
1863
+ else:
1864
+ new_data_train.append(d)
1865
+ data_train = new_data_train
1866
+
1867
+ print('Train:', len(data_train))
1868
+
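Review note: the cell above keeps every sample that has events and only a seeded random subset of the event-free samples, capped at `zero_events_rate` times the number of samples with events. The sketch below expresses the same idea as a function, assuming the `{'events': [...]}` sample layout seen in the script; the function name `downsample_empty`, its use of a private `random.Random` instance, and the toy data are illustrative assumptions.

```python
import random
from typing import Dict, List


def downsample_empty(samples: List[Dict], rate: float, seed: int = 42) -> List[Dict]:
    """Keep all samples with events plus a seeded subset of event-free ones."""
    with_events = [i for i, d in enumerate(samples) if d["events"]]
    empty = [i for i, d in enumerate(samples) if not d["events"]]

    # Cap the number of event-free samples at rate * (#samples with events).
    k = min(int(len(with_events) * rate), len(empty))

    # A set gives O(1) membership checks when filtering below.
    keep = set(with_events) | set(random.Random(seed).sample(empty, k))

    # Preserve the original ordering of the dataset.
    return [d for i, d in enumerate(samples) if i in keep]


# Toy data: 4 samples with one event each, 10 samples without events.
toy = [{"events": [{"label": "X"}]} for _ in range(4)] + [{"events": []} for _ in range(10)]
reduced = downsample_empty(toy, rate=0.5)
```

Using a local `random.Random(seed)` avoids reseeding the global generator, which may matter if other cells rely on its state.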
1869
+ # %% [code]
1870
+ if debug_only:
1871
+ data_train = data_train[:20]
1872
+ data_test = data_test[:20]
1873
+
1874
+ print('Train:', len(data_train))
1875
+ print('Test:', len(data_test))
1876
+
1877
+ # %% [code]
1878
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
1879
+
1880
+ # %% [code]
1881
+ print('Experiment name:', state_dict_save_name)
1882
+
1883
+ # %% [code]
1884
+ if not test_only:
1885
+ full_idxes = np.array(range(len(data_train)))
1886
+ training_logs, best_models, last_models = [], [], []
1887
+ start_training_time = time.time()
1888
+ for seed in SEEDS:
1889
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
1890
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
1891
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
1892
+ continue
1893
+ set_seed(seed)
1894
+
1895
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
1896
+
1897
+ trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
1898
+ valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)
1899
+
1900
+ generator = torch.Generator()
1901
+ generator.manual_seed(seed)
1902
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
1903
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1904
+
1905
+ my_model = IEModel(
1906
+ num_trg_labels=len(trigger_label2id),
1907
+ num_arg_labels=len(argument_label2id),
1908
+ **model_params
1909
+ )
1910
+ total_params = sum(p.numel() for p in my_model.parameters())
1911
+ print(f"Total params: {total_params:,}")
1912
+
1913
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
1914
+ encoder_params = set(map(id, my_model.encoder.parameters()))
1915
+ other_params = [
1916
+ p for p in my_model.parameters()
1917
+ if id(p) not in encoder_params
1918
+ ]
1919
+ optimizer = optim.AdamW([
1920
+ {"params": my_model.encoder.parameters(), "lr": 2e-5},
1921
+ {"params": other_params}
1922
+ ], lr=5e-4)
1923
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
1924
+
1925
+ loss_fn = CustomLoss(
1926
+ **loss_func_params
1927
+ )
1928
+ eval_fn = CustomEvalFn(**eval_func_params)
1929
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
1930
+ trainer = Trainer(**trainer_params)
1931
+
1932
+ print(f'Start Training Fold {fold_idx}...')
1933
+ training_log, best_model, last_model = trainer.fit(
1934
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
1935
+ start_epoch=1, start_training_time=start_training_time, id2label=id2label
1936
+ )
1937
+
1938
+ training_logs.append(training_log)
1939
+ best_models.append(best_model)
1940
+ last_models.append(last_model)
1941
+
1942
+ # %% [code]
1943
+ def load_all_state_dicts(folder):
1944
+ files = []
1945
+
1946
+ for file in os.listdir(folder):
1947
+ if file.endswith(".pt") or file.endswith(".pth"):
1948
+ m = re.search(r"f(\d+)", file) # find f<number>
1949
+ if m:
1950
+ fold = int(m.group(1))
1951
+ files.append((fold, file))
1952
+
1953
+ # sort by fold
1954
+ files.sort(key=lambda x: x[0])
1955
+
1956
+ state_dicts = []
1957
+ for fold, file in files:
1958
+ path = os.path.join(folder, file)
1959
+ print(f"Loading fold {fold}: {file}")
1960
+ state_dict = torch.load(path, map_location="cpu")
1961
+ state_dicts.append(state_dict)
1962
+
1963
+ return state_dicts
1964
+
1965
+ if test_only:
1966
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
1967
+ get_ipython().system('rm -rf .cache .gitattributes')
1968
+
1969
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
1970
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
1971
+
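Review note: `load_all_state_dicts` extracts the fold index with `re.search(r"f(\d+)", file)`, which matches the first `f<digits>` anywhere in the filename and ignores the seed. A stricter parser, assuming checkpoint names follow the `..._s<seed>_f<fold>_...` pattern visible in this commit, could look like the sketch below; `CKPT_RE` and `parse_seed_fold` are hypothetical names.

```python
import re
from typing import Optional, Tuple

# Assumed naming scheme: <experiment>_s<seed>_f<fold>_<suffix>.pth
CKPT_RE = re.compile(r"_s(\d+)_f(\d+)_")


def parse_seed_fold(filename: str) -> Optional[Tuple[int, int]]:
    """Return (seed, fold) parsed from a checkpoint filename, or None."""
    match = CKPT_RE.search(filename)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


names = [
    "1_pointer_base_actions_4_s26092004_f0_last_ema.pth",
    "1_pointer_base_actions_4_s26092004_f0_r1_vs0.37010_ema.pth",
    "unrelated.pth",
]
parsed = [parse_seed_fold(n) for n in names]
```

Sorting on the returned `(seed, fold)` tuple would keep checkpoints from different seeds grouped, should more than one seed be used.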
1972
+ # %% [code]
1973
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
1974
+ testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
1975
+ generator = torch.Generator()
1976
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1977
+ eval_fn = CustomEvalFn(**eval_func_params)
1978
+ analyzer = SpanErrorAnalyzer()
1979
+ my_model = IEModel(
1980
+ num_trg_labels=len(trigger_label2id),
1981
+ num_arg_labels=len(argument_label2id),
1982
+ **model_params
1983
+ )
1984
+ total_params = sum(p.numel() for p in my_model.parameters())
1985
+ print(f"Total params: {total_params:,}")
1986
+
1987
+ # %% [code]
1988
+ start_time = time.time()
1989
+
1990
+ best_score, best_analyze_result, best_pred_test = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
1991
+ last_score, last_analyze_result, last_pred_test = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
1992
+
1993
+ result_test = {"Best model": best_score, "Last model": last_score}
1994
+ analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
1995
+ analyze_result_summary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']}
1996
+ pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}
1997
+
1998
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
1999
+ json.dump(result_test, f, ensure_ascii=False, indent=2)
2000
+
2001
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", "w", encoding="utf-8") as f:
2002
+ json.dump(analyze_result, f, ensure_ascii=False, indent=2)
2003
+
2004
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", "w", encoding="utf-8") as f:
2005
+ json.dump(pred_test, f, ensure_ascii=False, indent=2)
2006
+
2007
+ print('Test:', time.time() - start_time, 's --> Done!')
2008
+ print(json.dumps(analyze_result_summary, ensure_ascii=False, indent=4))
2009
+
2010
+ # %% [code]
2011
+ best_pred_test[:10]
2012
+
2013
+ # %% [code]
2014
+ last_pred_test[:10]
2015
+
2016
+ # %% [code]
2017
+ def dict_to_df(data):
2018
+ row_tuples = []
2019
+ row_values = []
2020
+
2021
+ metrics = ["precision", "recall", "f1"]
2022
+
2023
+ # Take the first model
2024
+ first_model = next(iter(data.values()))
2025
+
2026
+ # eval_keys
2027
+ eval_keys = list(first_model.keys())
2028
+
2029
+ for eval_key in eval_keys:
2030
+ row_tuples.append(eval_key)
2031
+ row = {}
2032
+
2033
+ for model_name, model_data in data.items():
2034
+ for metric in metrics:
2035
+ row[(model_name, metric)] = model_data[eval_key][metric]
2036
+
2037
+ row_values.append(row)
2038
+
2039
+ # ===== DataFrame =====
2040
+ df = pd.DataFrame(row_values)
2041
+
2042
+ # MultiIndex columns
2043
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
2044
+
2045
+ # Index
2046
+ df.index = pd.Index(row_tuples, name="evaluation")
2047
+
2048
+ # ===== Sort =====
2049
+ sort_keys = []
2050
+ if ("Best model", "f1") in df.columns:
2051
+ sort_keys.append(("Best model", "f1"))
2052
+ if ("Last model", "f1") in df.columns:
2053
+ sort_keys.append(("Last model", "f1"))
2054
+
2055
+ if sort_keys:
2056
+ df = df.sort_values(by=sort_keys, ascending=False)
2057
+
2058
+ return df
2059
+
2060
+ result_test_df = dict_to_df(result_test)
2061
+ result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
2062
+ result_test_df
2063
+
2064
+ # %% [code]
2065
+ key = ("Best model", "f1")
2066
+ result_test_df_best = result_test_df.sort_values(by=key, ascending=False).groupby(level="evaluation").head(1)
2067
+ result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
2068
+ result_test_df_best
2069
+
2070
+ # %% [code]
2071
+ def get_avg_best_score(logs):
2072
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
2073
+
2074
+ def get_avg_log(logs, epochs):
2075
+ avg_log = {}
2076
+
2077
+ for epoch in range(1, epochs + 1):
2078
+ val_score = 0.0
2079
+ train_loss = 0.0
2080
+ n_eval = 0
2081
+
2082
+ for idx in range(len(logs)):
2083
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
2084
+ if log is None:
2085
+ continue
2086
+
2087
+ val_score += log.get('val_score', 0.0)
2088
+ train_loss += log.get('train_loss', 0.0)
2089
+ n_eval += 1
2090
+
2091
+ if n_eval == 0:
2092
+ continue
2093
+
2094
+ avg_log[epoch] = {
2095
+ 'train_loss': train_loss / n_eval,
2096
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
2097
+ }
2098
+
2099
+ return avg_log
2100
+
2101
+ def parse_label_key(label: str):
2102
+ try:
2103
+ first = float(label.split('_', 1)[0]) # leading number: the part before the first underscore
2104
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
2105
+ return first, last
2106
+ except (ValueError, IndexError):
2107
+ return (0, 0)
2108
+
2109
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
2110
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
2111
+
2112
+ # ===== Plot Train Loss =====
2113
+ for name, log in logs_dict.items():
2114
+ epochs = sorted(log.keys())
2115
+ train_loss = [log[e]['train_loss'] for e in epochs]
2116
+ axes[0].plot(epochs, train_loss, label=name)
2117
+
2118
+ axes[0].set_xlabel('Epoch')
2119
+ axes[0].set_ylabel('Train Loss')
2120
+ axes[0].set_title('Training Loss')
2121
+ axes[0].grid(True)
2122
+
2123
+ # ===== Plot Validation Score =====
2124
+ for name, log in logs_dict.items():
2125
+ epochs = sorted(log.keys())
2126
+ val_score = [log[e]['val_score'] for e in epochs]
2127
+ axes[1].plot(epochs, val_score, label=name)
2128
+
2129
+ axes[1].set_xlabel('Epoch')
2130
+ axes[1].set_ylabel('Validation Score')
2131
+ axes[1].set_title('Validation Score')
2132
+ axes[1].grid(True)
2133
+
2134
+ # ===== Shared Legend =====
2135
+ handles, labels = axes[0].get_legend_handles_labels()
2136
+ pairs = list(zip(handles, labels))
2137
+ pairs_sorted = sorted(
2138
+ pairs,
2139
+ key=lambda x: parse_label_key(x[1])
2140
+ )
2141
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
2142
+
2143
+ axes[0].legend(
2144
+ handles_sorted,
2145
+ labels_sorted,
2146
+ loc='center left',
2147
+ bbox_to_anchor=(1.01, 0.5),
2148
+ borderaxespad=0.
2149
+ )
2150
+
2151
+ plt.tight_layout(rect=[0, 0, 1, 1])
2152
+
2153
+ if save_path is not None:
2154
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
2155
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
2156
+
2157
+ plt.show()
2158
+
2159
+ # %% [code]
2160
+ if not test_only:
2161
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*.json"])
2162
+ get_ipython().system('rm -rf .cache .gitattributes')
2163
+
2164
+ # %% [code]
2165
+ if not test_only:
2166
+ experiments = {}
2167
+ for experiment in os.listdir(pretrained_dir):
2168
+ if '.virtual_documents' in experiment:
2169
+ continue
2170
+ experiment_logs = []
2171
+ try:
2172
+ for seed in SEEDS:
2173
+ for fold_idx in range(nfolds):
2174
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
2175
+ experiment_log = json.load(f)
2176
+ experiment_logs.append(experiment_log)
2177
+ except (OSError, json.JSONDecodeError):
2178
+ pass
2179
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
2180
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
2181
+
2182
+ # %% [code]
2183
+ if not test_only:
2184
+ score = get_avg_best_score(training_logs)
2185
+ state_dict_save_name, score
2186
+
2187
+ # %% [code]
2188
+ if not test_only:
2189
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
2190
+
1_pointer_base_actions_4/lasts/1_pointer_base_actions_4_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0d2a61cfb2456f4fcf74c27b9024e1ba887982da4a754a121026f992422a803
+ size 566179472
1_pointer_base_actions_4/logs/1_pointer_base_actions_4_log_plot.jpg ADDED

Git LFS Details

  • SHA256: 346a325e8100065c60c8e62cebc876d8ba8c338a8ba32c34c1bacc0035ccf2a2
  • Pointer size: 131 Bytes
  • Size of remote file: 554 kB
1_pointer_base_actions_4/logs/1_pointer_base_actions_4_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 0.5427724719047546, "total": 0.5427724685004235, "trg_start_loss": 0.08683479213068354, "trg_end_loss": 0.08729431359142441, "arg_start_loss": 0.19451468542492287, "arg_end_loss": 0.17412894323742287}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 0.3449680209159851, "total": 0.3449680147420287, "trg_start_loss": 0.056600022292898795, "trg_end_loss": 0.05468429462030342, "arg_start_loss": 0.12636818013371248, "arg_end_loss": 0.10731538106026516}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 0.30685120820999146, "total": 0.3068512051518635, "trg_start_loss": 0.051538380625633394, "trg_end_loss": 0.04912136184973038, "arg_start_loss": 0.11248089475530161, "arg_end_loss": 0.09371035373684052}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 0.28036218881607056, "total": 0.28036218368072363, "trg_start_loss": 0.047697386690861716, "trg_end_loss": 0.04493536122886163, "arg_start_loss": 0.10297534865077444, "arg_end_loss": 0.08475420528090588}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 0.2587788701057434, "total": 0.258778870278845, "trg_start_loss": 0.04411329212170256, "trg_end_loss": 0.04110915709326597, "arg_start_loss": 0.09573212113061925, "arg_end_loss": 0.07782418914824464}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 0.23932349681854248, "total": 0.23932348585544228, "trg_start_loss": 0.04078841601775207, "trg_end_loss": 0.03751262875286446, "arg_start_loss": 0.08920418812252315, "arg_end_loss": 0.07181828250497262}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 0.22157159447669983, "total": 0.22157159087041686, "trg_start_loss": 0.037683773133092326, "trg_end_loss": 0.034125395458180816, "arg_start_loss": 0.08317480982792458, "arg_end_loss": 0.06658761983688664}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 
0.20351392030715942, "total": 0.20351391447940617, "trg_start_loss": 0.034693161261854885, "trg_end_loss": 0.030822643121237444, "arg_start_loss": 0.07721389667108468, "arg_end_loss": 0.060784390681249245, "val_score": 0.3701047453894544, "best_score": 0.3701047453894544, "new_best_model": true, "precision": 0.4424849613204414, "recall": 0.3212647367659032, "f1": 0.3701047453894544}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 0.1903642863035202, "total": 0.19036428002858785, "trg_start_loss": 0.03195148700666105, "trg_end_loss": 0.02785607344414119, "arg_start_loss": 0.07295295737489034, "arg_end_loss": 0.05760369757830462, "val_score": 0.3692381361873917, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.440200826484632, "recall": 0.3213434695036083, "f1": 0.3692381361873917}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 0.17567050457000732, "total": 0.17567049453011557, "trg_start_loss": 0.029223659031619637, "trg_end_loss": 0.02512053298580912, "arg_start_loss": 0.06797870292552789, "arg_end_loss": 0.05334783962135278, "val_score": 0.3701026078138847, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.4400244761923354, "recall": 0.3227855835757088, "f1": 0.3701026078138847}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 0.16482838988304138, "total": 0.16482839372012645, "trg_start_loss": 0.026876604683978515, "trg_end_loss": 0.022636180925692386, "arg_start_loss": 0.06467665961266487, "arg_end_loss": 0.05063901681521509, "val_score": 0.3699961152878146, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.441237750699725, "recall": 0.32197137031189327, "f1": 0.3699961152878146}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 0.151790589094162, "total": 0.1517905930466481, "trg_start_loss": 0.024764006504823254, "trg_end_loss": 0.020361362099070517, "arg_start_loss": 0.05991982444327429, 
"arg_end_loss": 0.046745385228145045, "val_score": 0.3668802698978058, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.44013435276182367, "recall": 0.3179195753312056, "f1": 0.3668802698978058}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 0.14348207414150238, "total": 0.1434820716171043, "trg_start_loss": 0.023153757425845536, "trg_end_loss": 0.01871082051102674, "arg_start_loss": 0.057142891011418125, "arg_end_loss": 0.04447455096914138, "val_score": 0.363958680694572, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.4387899882878758, "recall": 0.3140481661345762, "f1": 0.363958680694572}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 0.1350085288286209, "total": 0.13500852483285938, "trg_start_loss": 0.021544224756615667, "trg_end_loss": 0.017338526560514732, "arg_start_loss": 0.054125147890606094, "arg_end_loss": 0.04200059238961913, "val_score": 0.36283423225025196, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.4418018029422258, "recall": 0.31077071962368646, "f1": 0.36283423225025196}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 0.12671814858913422, "total": 0.1267181426027045, "trg_start_loss": 0.020146003420332164, "trg_end_loss": 0.01588509735964159, "arg_start_loss": 0.051025050884294834, "arg_end_loss": 0.03966203155960718, "val_score": 0.36074218029781185, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.44210523477838565, "recall": 0.3076709658499786, "f1": 0.36074218029781185}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 0.1200651153922081, "total": 0.12006511115843191, "trg_start_loss": 0.018968966208030503, "trg_end_loss": 0.014787089559254087, "arg_start_loss": 0.04851426367376573, "arg_end_loss": 0.03779484249384567, "val_score": 0.35989326351730144, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 
0.4444530405417498, "recall": 0.30540623589526067, "f1": 0.35989326351730144}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 0.11547578126192093, "total": 0.11547578292802366, "trg_start_loss": 0.018174899859211453, "trg_end_loss": 0.013982757132604993, "arg_start_loss": 0.04677155264811041, "arg_end_loss": 0.03654645788704206, "val_score": 0.3589282160684469, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.447412934145732, "recall": 0.30264166530897607, "f1": 0.3589282160684469}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 0.11075928062200546, "total": 0.1107592808888704, "trg_start_loss": 0.017650351288934874, "trg_end_loss": 0.013565790964811607, "arg_start_loss": 0.04475618124238981, "arg_end_loss": 0.03478696662481849, "val_score": 0.3580132020836549, "best_score": 0.3701047453894544, "new_best_model": false, "precision": 0.44968723667318006, "recall": 0.30041886219716696, "f1": 0.3580132020836549}}
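Review note: the `lr` pairs in the log above appear to follow the closed-form cosine annealing schedule configured in the script (`T_max=20`, `eta_min=1e-6`, base rates `2e-5` for the encoder and `5e-4` for the other parameters). The sketch below recomputes two logged epochs under that assumption; `cosine_annealing_lr` is a hypothetical helper, and it assumes the scheduler has been stepped `epoch - 1` times when an epoch is logged.

```python
import math


def cosine_annealing_lr(base_lr: float, step: int, t_max: int = 20, eta_min: float = 1e-6) -> float:
    """Closed-form cosine annealing: eta_min + (base - eta_min) * (1 + cos(pi * t / T)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Epoch 2 in the log corresponds to one scheduler step.
encoder_lr_epoch2 = cosine_annealing_lr(2e-5, step=1)
head_lr_epoch2 = cosine_annealing_lr(5e-4, step=1)

# Epoch 11 corresponds to ten steps, the midpoint of the schedule.
encoder_lr_epoch11 = cosine_annealing_lr(2e-5, step=10)
head_lr_epoch11 = cosine_annealing_lr(5e-4, step=10)
```

Under these assumptions the recomputed values agree with the logged `lr` entries for epochs 2 and 11 to within floating-point error.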
1_pointer_base_actions_4/r1s/1_pointer_base_actions_4_s26092004_f0_r1_vs0.37010_ema.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d475136c5c40b13f37444afdf3ff2a6bbacf31ac303b5194d67be453a70e846e
+ size 566181336
1_pointer_base_actions_4/results/1_pointer_base_actions_4_error_analyze_result.json ADDED
The diff for this file is too large to render. See raw diff
 
1_pointer_base_actions_4/results/1_pointer_base_actions_4_pred_test.json ADDED
The diff for this file is too large to render. See raw diff
 
1_pointer_base_actions_4/results/1_pointer_base_actions_4_test.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "Best model": {
+     "Trg-I": {
+       "precision": 0.30360706062885096,
+       "recall": 0.5674125071699156,
+       "f1": 0.39556043941346375
+     },
+     "Trg-C": {
+       "precision": 0.2801227935529085,
+       "recall": 0.5233725265255997,
+       "f1": 0.36492701005460676
+     },
+     "Arg-I": {
+       "precision": 0.2434387371095042,
+       "recall": 0.427994500772961,
+       "f1": 0.31035234280272317
+     },
+     "Arg-C": {
+       "precision": 0.23473926005560103,
+       "recall": 0.4122746781112341,
+       "f1": 0.29914981942410246
+     },
+     "Soft-Event": {
+       "precision": 0.21621621621611053,
+       "recall": 0.37974248927006027,
+       "f1": 0.27554420141159813
+     },
+     "Strict-Event": {
+       "precision": 0.08779739063685679,
+       "recall": 0.16403785488911948,
+       "f1": 0.11437712003311234
+     }
+   },
+   "Last model": {
+     "Trg-I": {
+       "precision": 0.2986111111106503,
+       "recall": 0.5550774526662218,
+       "f1": 0.3883202844330266
+     },
+     "Trg-C": {
+       "precision": 0.2746489739233534,
+       "recall": 0.5104674505290782,
+       "f1": 0.35714285259322476
+     },
+     "Arg-I": {
+       "precision": 0.24609130706678856,
+       "recall": 0.40573981783776786,
+       "f1": 0.30636475232302995
+     },
+     "Arg-C": {
+       "precision": 0.2377006462371077,
+       "recall": 0.3915021459224107,
+       "f1": 0.29580387363048766
+     },
+     "Soft-Event": {
+       "precision": 0.21867834062944616,
+       "recall": 0.3601716738194333,
+       "f1": 0.2721317807580729
+     },
+     "Strict-Event": {
+       "precision": 0.08347477241153607,
+       "recall": 0.15514769142484902,
+       "f1": 0.10854734697574614
+     }
+   }
+ }
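Review note: the `f1` values in this results file look consistent with the usual harmonic mean of the reported precision and recall. The sketch below rechecks two "Best model" entries using numbers copied from the JSON; `f1_from_pr` is a hypothetical helper, and the exact formula used by `CustomEvalFn` (for example any smoothing epsilon) is not shown in this part of the diff, so agreement is only expected approximately.

```python
def f1_from_pr(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    total = precision + recall
    return 2 * precision * recall / total if total > 0 else 0.0


# Values copied from the "Best model" block of the results JSON.
trg_i_f1 = f1_from_pr(0.30360706062885096, 0.5674125071699156)
strict_event_f1 = f1_from_pr(0.08779739063685679, 0.16403785488911948)
```

Both recomputed scores agree with the stored `f1` fields to roughly six decimal places, which suggests the reported triples are internally consistent.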
1_pointer_base_actions_4/results/1_pointer_base_actions_4_test_df.xlsx ADDED
Binary file (5.66 kB). View file
 
1_pointer_base_actions_4/results/1_pointer_base_actions_4_test_df_best.xlsx ADDED
Binary file (5.66 kB). View file