SS3M committed on
Commit bf23767 · verified · 1 Parent(s): e2a67cd

Upload 20_entities_top_50_25's state dict

.gitattributes CHANGED
@@ -83,3 +83,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  12_13_lr_16/logs/12_13_lr_16_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
  8_entities_top_span_self_ensemble_no_weight_20/logs/8_entities_top_span_self_ensemble_no_weight_20_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
  20_weight_5_21/logs/20_weight_5_21_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
+ 20_entities_top_50_25/logs/20_entities_top_50_25_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
20_entities_top_50_25/20_entities_top_50_25.py ADDED
@@ -0,0 +1,2419 @@
+ # %% [code]
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch] pytorch-crf')
+
+ # %% [code]
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ # RandomSampler is needed by reduce_batch_size() further down
+ from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler
+ import torch.nn.functional as F
+ import albumentations as albu
+ from transformers import AutoTokenizer, AutoModel
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from positional_encodings.torch_encodings import PositionalEncoding1D
+ from torchcrf import CRF
+
+ from sklearn.metrics import f1_score
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
+ from scipy.spatial.transform import Rotation as R
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
+ from sklearn.metrics import precision_recall_fscore_support
+ from timm.utils import ModelEmaV3
+ import timm
+
+ import os
+ import gc
+ import json
+ from pathlib import Path
+ import pickle
+ from tqdm.auto import tqdm
+ import copy
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+ from PIL import Image
+ import time
+ from tqdm import tqdm
+ from matplotlib import pyplot as plt
+ import seaborn as sns
+ from multiprocessing import Manager as MemoryManager
+ from functools import lru_cache
+ import shutil
+ import glob
+ import cv2
+ import random
+ import re
+ import joblib
+ import math
+ from huggingface_hub import HfApi, snapshot_download
+ import evaluate
+ from underthesea import word_tokenize as vi_tokenize_tool
+ import spacy
+ en_tokenize_tool = spacy.load("en_core_web_sm")
+ from collections import defaultdict, Counter
+
+ # %% [code]
+ # Global config
+ SEEDS = [26092004]
+ topk = 1
+ nfolds = 5
+ only_fold_idx = 0
+ test_only = 0
+ debug_only = 0
+
+ # Directory config
+ dataset = 'kltn/only_entities' # conll003, ontonotes, phoner, vietbio, vietmed, vimed, kltn/only_entities, kltn/raw
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Directory containing the train, val, and test files
+ train_dir = f'{root_dir}'
+ # val_dir = f'{root_dir}/val'
+ test_dir = f'{root_dir}'
+
+ # Config checkpoints
+
+ # Config training
+ epochs = 18 if not debug_only else 2
+ batch_size = 32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ # # Add any other global variables here
+ repo_name = 'SS3M/kltn-experiments'
+ state_dict_save_name = "20_entities_top_50_25"
+ checkpoints_dir = state_dict_save_name
+ pretrained_dir = "/kaggle/working"
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
+
+ backbone_model_name = "bert-base-uncased" if dataset in ["conll003", "ontonotes"] else "vinai/phobert-base"
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset in ["conll003", "ontonotes"] else vi_tokenize_tool(text)
+ max_len_dict = {
+     'kltn/raw': 256,
+     'kltn/only_entities': 68,
+     'conll003': 46,
+     'ontonotes': 61,
+     'phoner': 68,
+     'vietbio': 125,
+     'vietmed': 36,
+     'vimed': 100,
+ }
+ zero_entities_rate_dict = {
+     'kltn/raw': 1000,
+     'kltn/only_entities': 0.2,
+     'conll003': 1000, # 1000 means: keep all zero-entity samples
+     'ontonotes': 1000,
+     'phoner': 1000,
+     'vietbio': 1000,
+     'vietmed': 1000,
+     'vimed': 1000,
+ }
+
+ max_len = max_len_dict[dataset]
+ max_n_parts = 3 if dataset in ['kltn/raw'] else 1
+ max_span_len = 10
+ zero_entities_rate = zero_entities_rate_dict[dataset]
+
+ # Trainer
+ trainer_params = {
+     "training_time": "00:11:30:00",
+     "eval_mode": "max",
+     "topk": topk,
+     "save_name": state_dict_save_name,
+     "save_best": True,
+     "save_last": True,
+     "device": device,
+     "logging": True,
+     "logging_file": True,
+     "checkpoints_dir": checkpoints_dir,
+     "early_stopping": 30,
+     "eval_from_ratio": 0.4,
+     "eval_every": 1,
+     "schedule_in_step": False,
+     "use_ema": True,
+     "ema_from_ratio": 0.3,
+     "ema_decay": 0.9995,
+     "max_grad_norm": 200.0,
+     "return_best": True,
+     "return_last": True,
+ }
+
+ # Memory
+ train_memory_params = {
+     'max_len': max_len,
+     'max_n_parts': max_n_parts,
+     'max_span_len': max_span_len,
+     'weight_rate': None,
+ }
+ val_memory_params = {
+     'max_len': max_len,
+     'max_n_parts': max_n_parts,
+     'max_span_len': max_span_len,
+     'weight_rate': None,
+ }
+
+ # Data Loader
+ def seed_worker(worker_id):
+     worker_seed = torch.initial_seed() % 2**32
+     np.random.seed(worker_seed)
+     random.seed(worker_seed)
+
+ train_loader_params = {
+     'batch_size': batch_size,
+     'shuffle': True,
+     'pin_memory': True,
+     'num_workers': 2,
+     'drop_last': False,
+     'worker_init_fn': seed_worker,
+     'persistent_workers': False,
+ }
+ val_loader_params = {
+     'batch_size': batch_size,
+     'shuffle': False,
+     'pin_memory': True,
+     'num_workers': 1,
+     'drop_last': False,
+     'worker_init_fn': seed_worker,
+     'persistent_workers': False,
+ }
+
+ # Model
+ model_params = {
+     'backbone_model_name': backbone_model_name,
+     'max_span_len': max_span_len,
+     'topk_spans': 50,
+     'keep_neighbor': 2,
+ }
+
+ # Loss Func
+ loss_func_params = {
+     'lambda_ce': 1.0,
+ }
+ eval_func_params = {}
+
+ # Optim
+ optim_params = {
+     'name': 'AdamW',
+     'lr': 1e-4,
+     'weight_decay': 1e-4,
+ }
+ scheduler_params = {
+     'name': 'CosineAnnealingLR',
+     'T_max': 20, # Number of epochs to complete one LR decay cycle
+     'eta_min': 1e-6 # Minimum learning rate within the cycle
+ }
+
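The scheduler settings above pair `T_max = 20` with `epochs = 18`, so the run ends shortly before the cosine cycle reaches `eta_min`. The sketch below evaluates the closed-form cosine annealing formula from the PyTorch `CosineAnnealingLR` documentation for those values. It is not part of the commit: the function name `cosine_annealing_lr` is hypothetical, and it assumes the scheduler is stepped once per epoch (as `"schedule_in_step": False` suggests) with nothing else modifying the learning rate.

```python
import math

def cosine_annealing_lr(epoch, base_lr=1e-4, eta_min=1e-6, t_max=20):
    """Closed-form cosine annealing:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2
    """
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))

# Learning rate at the start of each of the 18 training epochs (epoch index 0..17).
schedule = [cosine_annealing_lr(e) for e in range(18)]

if __name__ == "__main__":
    for e, lr in enumerate(schedule):
        print(f"epoch {e:2d}: lr = {lr:.3e}")
```

Under these assumptions the learning rate decreases monotonically across all 18 epochs and never quite reaches `eta_min`.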
+ # %% [code]
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed) # if using multi-GPU
+     torch.use_deterministic_algorithms(False)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     os.environ['PYTHONHASHSEED'] = str(seed)
+
+ # %% [code]
+ class CustomLoss(nn.Module):
+     def __init__(self, lambda_ce=1.0):
+         super().__init__()
+         self.lambda_ce = lambda_ce
+
+         self.ce = nn.CrossEntropyLoss(
+             ignore_index=-100,
+             reduction='none'
+         )
+
+     def forward(
+         self,
+         logits, labels, weights, # weights: (B, N)
+         start_logits, start_labels,
+         end_logits, end_labels,
+     ):
+         # =====================================
+         # SPAN LOSS
+         # =====================================
+
+         B, N, C = logits.shape
+
+         flat_logits = logits.reshape(-1, C)
+         flat_labels = labels.reshape(-1)
+         flat_weights = weights.reshape(-1)
+
+         valid_mask = flat_labels != -100
+
+         if valid_mask.any():
+             ce_loss = self.ce(flat_logits, flat_labels) # (B*N,)
+
+             ce_loss = ce_loss[valid_mask]
+             valid_weights = flat_weights[valid_mask]
+
+             loss = (ce_loss * valid_weights).sum() / valid_weights.sum().clamp(min=1e-8)
+
+         else:
+             loss = logits.sum() * 0.0
+
+         # =====================================
+         # START LOSS
+         # =====================================
+
+         B, L, C = start_logits.shape
+
+         start_logits_flat = start_logits.reshape(B * L, C)
+         start_labels_flat = start_labels.reshape(-1)
+
+         start_loss = self.ce(start_logits_flat, start_labels_flat)
+         start_valid = start_labels_flat != -100
+
+         if start_valid.any():
+             start_loss = start_loss[start_valid].mean()
+         else:
+             start_loss = logits.sum() * 0.0
+
+         # =====================================
+         # END LOSS
+         # =====================================
+
+         end_logits_flat = end_logits.reshape(B * L, C)
+         end_labels_flat = end_labels.reshape(-1)
+
+         end_loss = self.ce(end_logits_flat, end_labels_flat)
+         end_valid = end_labels_flat != -100
+
+         if end_valid.any():
+             end_loss = end_loss[end_valid].mean()
+         else:
+             end_loss = logits.sum() * 0.0
+
+         return {
+             "total": loss + start_loss + end_loss,
+             "span_loss": loss,
+             "start_loss": start_loss,
+             "end_loss": end_loss,
+         }
+
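The span branch of `CustomLoss` computes a per-span cross-entropy, drops spans labelled `-100`, and takes a weighted mean with the denominator clamped at `1e-8`. The following dependency-free sketch reproduces that arithmetic on plain Python lists so it can be checked by hand. It is an illustration, not the committed implementation, and the helper names `cross_entropy` and `weighted_masked_ce` are hypothetical.

```python
import math

IGNORE_INDEX = -100  # same sentinel the loss above passes as ignore_index

def cross_entropy(logits, label):
    """Negative log-softmax probability of `label` for one row of logits."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]

def weighted_masked_ce(logits, labels, weights, eps=1e-8):
    """Weighted mean of per-item CE over items whose label is not IGNORE_INDEX."""
    num, den = 0.0, 0.0
    for row, label, w in zip(logits, labels, weights):
        if label == IGNORE_INDEX:
            continue  # padded / ignored span: contributes to neither sum
        num += w * cross_entropy(row, label)
        den += w
    return num / max(den, eps)  # mirrors .clamp(min=1e-8) on the weight sum

# Three candidate spans, two classes; the third span is padding.
example_logits = [[0.0, 0.0], [2.0, 0.0], [5.0, 1.0]]
example_labels = [0, 1, IGNORE_INDEX]
example_weights = [1.0, 3.0, 10.0]
example_loss = weighted_masked_ce(example_logits, example_labels, example_weights)
```

Because the ignored span is skipped before weighting, its large weight of `10.0` has no effect on the result.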
+ # %% [code]
+ ## Write eval_fn here
+
+ # Put all eval functions and their weights here
+ class CustomEvalFn(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def compute_f1(self, tp, fp, fn):
+         precision = tp / (tp + fp + 1e-8)
+         recall = tp / (tp + fn + 1e-8)
+         f1 = 2 * precision * recall / (precision + recall + 1e-8)
+         return precision, recall, f1
+
+     def forward(self, pred, gold):
+         pred_set = set(pred)
+         gold_set = set(gold)
+
+         tp = len(pred_set & gold_set)
+         fp = len(pred_set - gold_set)
+         fn = len(gold_set - pred_set)
+
+         precision, recall, f1 = self.compute_f1(tp, fp, fn)
+
+         return {
+             "precision": precision,
+             "recall": recall,
+             "f1": f1,
+         }
+
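`CustomEvalFn` scores predictions by exact set membership, so a span whose boundary is off by one token counts as both a false positive and a false negative. The worked example below makes that concrete without PyTorch. The tuple layout `(sample_index, start, end, label)` is an assumption for illustration; the commit does not show what the elements of `pred` and `gold` look like, only that they must be hashable.

```python
def span_prf(pred, gold, eps=1e-8):
    """Exact-match precision, recall and F1 over sets of hashable span records."""
    pred_set, gold_set = set(pred), set(gold)
    tp = len(pred_set & gold_set)   # predicted and correct
    fp = len(pred_set - gold_set)   # predicted but not in gold
    fn = len(gold_set - pred_set)   # in gold but not predicted
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1

# Hypothetical records: (sample_index, start, end, label)
gold = [(0, 1, 3, "PER"), (0, 5, 6, "LOC"), (1, 2, 2, "ORG")]
pred = [(0, 1, 3, "PER"), (0, 5, 7, "LOC")]  # second span ends one token too late

precision, recall, f1 = span_prf(pred, gold)
```

Here one of two predictions is correct and one of three gold spans is found, giving precision 0.5, recall about 0.333, and F1 about 0.4.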
+ class SpanErrorAnalyzer:
+     def __init__(self, pad_token_id=0):
+         self.pad_token_id = pad_token_id
+
+     # ===== helper =====
+     def _to_set(self, data):
+         """
+         data: list of (b, tuple(ids))
+         -> dict[b] = set(tuple(ids))
+         """
+         res = defaultdict(set)
+         for b, ids in data:
+             ids = tuple([i for i in ids if i != self.pad_token_id])
+             if len(ids) > 0:
+                 res[b].add(ids)
+         return res
+
+     def _iou(self, a, b):
+         """
+         a, b: tuple(ids)
+         """
+         set_a, set_b = set(a), set(b)
+         inter = len(set_a & set_b)
+         union = len(set_a | set_b)
+         if union == 0:
+             return 0.0
+         return inter / union
+
+     def _boundary_error(self, pred, gold):
+         """
+         Measure boundary deviation based on prefix/suffix overlap
+         """
+         # left match
+         left = 0
+         for i in range(min(len(pred), len(gold))):
+             if pred[i] == gold[i]:
+                 left += 1
+             else:
+                 break
+
+         # right match
+         right = 0
+         for i in range(1, min(len(pred), len(gold)) + 1):
+             if pred[-i] == gold[-i]:
+                 right += 1
+             else:
+                 break
+
+         return {
+             "left_match": left,
+             "right_match": right,
+             "pred_len": len(pred),
+             "gold_len": len(gold),
+         }
+
+     # ===== main =====
+     def analyze(self, preds, golds):
+         pred_map = self._to_set(preds)
+         gold_map = self._to_set(golds)
+
+         all_batches = set(pred_map.keys()) | set(gold_map.keys())
+
+         stats = Counter()
+
+         detailed_errors = []
+
+         for b in all_batches:
+             pset = pred_map.get(b, set())
+             gset = gold_map.get(b, set())
+
+             matched_gold = set()
+
+             # ===== check predictions =====
+             for p in pset:
+                 if p in gset:
+                     stats["exact_match"] += 1
+                     matched_gold.add(p)
+                 else:
+                     # find the closest gold span
+                     best_iou = 0
+                     best_g = None
+
+                     for g in gset:
+                         iou = self._iou(p, g)
+                         if iou > best_iou:
+                             best_iou = iou
+                             best_g = g
+
+                     if best_iou > 0:
+                         stats["partial_match"] += 1
+
+                         boundary = self._boundary_error(p, best_g)
+
+                         detailed_errors.append({
+                             "type": "boundary_error",
+                             "batch": b,
+                             "pred": p,
+                             "gold": best_g,
+                             "iou": best_iou,
+                             **boundary
+                         })
+                     else:
+                         if b not in gold_map:
+                             stats["no_event_sample"] += 1
+                             err_type = "no_event_sample"
+                         else:
+                             stats["completely_wrong"] += 1
+                             err_type = "completely_wrong"
+
+                         detailed_errors.append({
+                             "type": err_type,
+                             "batch": b,
+                             "pred": p
+                         })
+
+             # ===== check missing =====
+             for g in gset:
+                 if g not in matched_gold:
+                     # check if any pred overlaps
+                     overlap = any(self._iou(p, g) > 0 for p in pset)
+
+                     if overlap:
+                         stats["miss_with_overlap"] += 1
+                     else:
+                         stats["miss"] += 1
+
+                     detailed_errors.append({
+                         "type": "miss",
+                         "batch": b,
+                         "gold": g
+                     })
+
+         # Guard against empty inputs so the ratios below cannot divide by zero
+         n_pred = max(len(preds), 1)
+         n_gold = max(len(golds), 1)
+
+         return {
+             "summary": {
+                 "exact_match": (stats["exact_match"], stats["exact_match"] / n_pred),
+                 "partial_match": (stats["partial_match"], stats["partial_match"] / n_pred),
+                 "no_event_sample": (stats["no_event_sample"], stats["no_event_sample"] / n_pred),
+                 "completely_wrong": (stats["completely_wrong"], stats["completely_wrong"] / n_pred),
+                 "miss": (stats["miss"], stats["miss"] / n_gold),
+                 "miss_with_overlap": (stats["miss_with_overlap"], stats["miss_with_overlap"] / n_gold),
+             },
+             "details": detailed_errors
+         }
+
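`SpanErrorAnalyzer` classifies a non-exact prediction by its token-id IoU with the closest gold span, then reports how many tokens agree from the left and from the right. A small standalone sketch of those two measures follows, using hypothetical function names and made-up token ids. One consequence of the set-based IoU is that a token id repeated inside a span is counted only once.

```python
def token_iou(a, b):
    """Intersection-over-union of two spans viewed as sets of token ids."""
    set_a, set_b = set(a), set(b)
    union = len(set_a | set_b)
    return len(set_a & set_b) / union if union else 0.0

def boundary_match(pred, gold):
    """Count tokens that agree from the left edge and from the right edge."""
    limit = min(len(pred), len(gold))
    left = 0
    while left < limit and pred[left] == gold[left]:
        left += 1
    right = 0
    while right < limit and pred[-1 - right] == gold[-1 - right]:
        right += 1
    return left, right

# Made-up token ids: the prediction drops the first token of the gold span.
gold_span = (11, 12, 13, 14)
pred_span = (12, 13, 14)

iou = token_iou(pred_span, gold_span)
left, right = boundary_match(pred_span, gold_span)
```

For this pair the IoU is 0.75, `left` is 0, and `right` is 3, which is the signature of a prediction that starts too late but ends correctly.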
+ # %% [code]
+ class DataParallelProxy(nn.DataParallel):
+
+     def __getattr__(self, name):
+         try:
+             return super().__getattr__(name)
+
+         except AttributeError:
+
+             attr = getattr(self.module, name)
+
+             if callable(attr):
+
+                 def wrapper(*args, **kwargs):
+                     return self._parallel_apply_method(
+                         name,
+                         *args,
+                         **kwargs
+                     )
+
+                 return wrapper
+
+             return attr
+
+     # =========================================================
+     # parallel custom method
+     # =========================================================
+
+     def _parallel_apply_method(self, method_name, *inputs, **kwargs):
+
+         if not self.device_ids:
+             return getattr(self.module, method_name)(*inputs, **kwargs)
+
+         inputs_scattered, kwargs_scattered = self.scatter(
+             inputs,
+             kwargs,
+             self.device_ids
+         )
+
+         replicas = self.replicate(
+             self.module,
+             self.device_ids[:len(inputs_scattered)]
+         )
+
+         outputs = self.parallel_apply(
+             [getattr(replica, method_name) for replica in replicas],
+             inputs_scattered,
+             kwargs_scattered
+         )
+
+         return self._custom_gather(
+             outputs,
+             self.output_device
+         )
+
+     # =========================================================
+     # OVERRIDE FORWARD GATHER
+     # =========================================================
+
+     def gather(self, outputs, output_device):
+
+         return self._custom_gather(
+             outputs,
+             output_device
+         )
+
+     # =========================================================
+     # recursive gather
+     # =========================================================
+
+     def _custom_gather(self, outputs, output_device):
+
+         first = outputs[0]
+
+         # =====================================================
+         # tensor
+         # =====================================================
+
+         if torch.is_tensor(first):
+
+             return self._gather_tensor(
+                 outputs,
+                 output_device
+             )
+
+         # =====================================================
+         # tuple
+         # =====================================================
+
+         if isinstance(first, tuple):
+
+             return tuple(
+                 self._custom_gather(
+                     list(items),
+                     output_device
+                 )
+                 for items in zip(*outputs)
+             )
+
+         # =====================================================
+         # list
+         # =====================================================
+
+         if isinstance(first, list):
+
+             # list[tensor]
+             if len(first) > 0 and torch.is_tensor(first[0]):
+
+                 return self._gather_tensor_list(
+                     outputs,
+                     output_device
+                 )
+
+             merged = []
+
+             for out in outputs:
+                 merged.extend(out)
+
+             return merged
+
+         # =====================================================
+         # dict
+         # =====================================================
+
+         if isinstance(first, dict):
+
+             return {
+                 k: self._custom_gather(
+                     [o[k] for o in outputs],
+                     output_device
+                 )
+                 for k in first.keys()
+             }
+
+         # =====================================================
+         # fallback
+         # =====================================================
+
+         return outputs
+
+     # =========================================================
+     # gather tensor with auto pad
+     # =========================================================
+
+     def _gather_tensor(self, tensors, output_device):
+
+         # move same device first
+         tensors = [
+             t.to(output_device)
+             for t in tensors
+         ]
+
+         # =====================================================
+         # fast path
+         # =====================================================
+
+         try:
+             return torch.cat(tensors, dim=0)
+
+         except RuntimeError:
+             pass
+
+         # =====================================================
+         # auto max shape
+         # =====================================================
+
+         max_shape = list(tensors[0].shape)
+
+         for t in tensors[1:]:
+
+             for d in range(len(max_shape)):
+
+                 max_shape[d] = max(
+                     max_shape[d],
+                     t.shape[d]
+                 )
+
+         # =====================================================
+         # pad tensors
+         # =====================================================
+
+         padded = []
+
+         for t in tensors:
+
+             pad = []
+
+             # reverse order for F.pad
+             for d in reversed(range(len(max_shape))):
+
+                 # never pad batch dim
+                 if d == 0:
+                     pad.extend([0, 0])
+                     continue
+
+                 diff = max_shape[d] - t.shape[d]
+
+                 pad.extend([0, diff])
+
+             t = F.pad(t, pad)
+
+             padded.append(t)
+
+         return torch.cat(padded, dim=0)
+
+     # =========================================================
+     # list[tensor]
+     # =========================================================
+
+     def _gather_tensor_list(self, outputs, output_device):
+
+         merged = []
+
+         for out in outputs:
+             merged.extend(out)
+
+         return self._gather_tensor(
+             merged,
+             output_device
+         )
+
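The fallback path of `_gather_tensor` pads every per-GPU output up to the largest shape before concatenating along the batch dimension. `torch.nn.functional.pad` expects its padding list ordered from the last dimension to the first, as `(left, right)` pairs, which is why the loop above walks the dimensions in reverse. The sketch below isolates just that bookkeeping with plain tuples, so it runs without PyTorch. `pad_spec` is a hypothetical helper name and the shapes are invented for illustration.

```python
def pad_spec(shape, max_shape):
    """Flat padding list in F.pad order: last dimension first, (left, right) per dim."""
    pad = []
    for d in reversed(range(len(max_shape))):
        if d == 0:
            # The batch dimension is concatenated, never padded.
            pad.extend([0, 0])
            continue
        pad.extend([0, max_shape[d] - shape[d]])  # pad only on the right side
    return pad

# Two replicas that selected a different number of candidate spans (dim 1).
replica_shapes = [(16, 40, 7), (16, 55, 7)]
max_shape = [max(dims) for dims in zip(*replica_shapes)]

specs = [pad_spec(shape, max_shape) for shape in replica_shapes]
```

With these shapes only the first replica needs padding, by 15 entries on the right of dimension 1. In the class above the padded entries are filled with zeros, which lines up with the `(0, 0)` padding spans that the span filtering step emits.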
+ # %% [code]
+ ## Define the model architecture here
+ def get_span_reprs(hidden, spans):
+     """
+     Args:
+         hidden: (B, L, H)
+         spans: (B, N, 2)
+
+     Return:
+         span_repr: (B, N, 4*H)
+     """
+
+     B, N, _ = spans.shape
+     H = hidden.size(-1)
+
+     batch_idx = torch.arange(B, device=hidden.device).unsqueeze(1)
+
+     start_idx = spans[..., 0] # (B, N)
+     end_idx = spans[..., 1] # (B, N)
+     start_h = hidden[batch_idx, start_idx]
+     end_h = hidden[batch_idx, end_idx]
+
+     span_repr = torch.cat(
+         [start_h, end_h, end_h - start_h, end_h * start_h],
+         dim=-1
+     )
+
+     return span_repr
+
+ def filter_spans(
+     start_logits, # (B, L, C)
+     end_logits, # (B, L, C)
+     attn_mask, # (B, L)
+     max_span_length=10,
+     topk_spans=20,
+     keep_neighbor=1
+ ):
+     """
+     Return:
+         spans: (B, N, 2)
+         N is the largest number of spans in the batch after neighbor expansion and duplicate removal.
+         Padded with (0, 0).
+     """
+
+     device = start_logits.device
+     B, L, C = start_logits.shape
+
+     # Detach from the graph: span selection is not differentiated
+     start_logits = start_logits.detach()
+     end_logits = end_logits.detach()
+
+     # Prob
+     prob_s = F.softmax(start_logits, dim=-1) # (B, L, C)
+     prob_e = F.softmax(end_logits, dim=-1) # (B, L, C)
+
+     all_batch_spans = []
+
+     for b in range(B):
+
+         valid_len = int(attn_mask[b].sum().item())
+
+         candidate_spans = []
+
+         # Enumerate all valid spans
+         for s in range(1, valid_len):
+
+             max_e = min(valid_len - 1, s + max_span_length - 1)
+
+             for e in range(s, max_e + 1):
+
+                 length = e - s + 1
+
+                 score = (
+                     (1.0 - prob_s[b, s, 0].item())
+                     + (1.0 - prob_e[b, e, 0].item())
+                     - (length / max_span_length)
+                 )
+                 candidate_spans.append((score, s, e))
+
+         # Top-k
+         candidate_spans = sorted(
+             candidate_spans,
+             key=lambda x: x[0],
+             reverse=True
+         )[:topk_spans]
+
+         # Expand neighbors
+         final_spans = set()
+
+         for _, s, e in candidate_spans:
+
+             for ds in range(-keep_neighbor, keep_neighbor + 1):
+                 for de in range(-keep_neighbor, keep_neighbor + 1):
+
+                     ns = s + ds
+                     ne = e + de
+
+                     if ns <= 0:
+                         continue
+
+                     if ne >= valid_len:
+                         continue
+
+                     if ns > ne:
+                         continue
+
+                     if (ne - ns + 1) > max_span_length:
+                         continue
+
+                     final_spans.add((ns, ne))
+
+         final_spans = sorted(list(final_spans))
+
+         if len(final_spans) == 0:
+             final_spans = [(0, 0)]
+
+         all_batch_spans.append(final_spans)
+
+     # Padding
+     max_num_spans = max(len(x) for x in all_batch_spans)
+
+     padded_spans = []
+
+     for spans in all_batch_spans:
+
+         pad_size = max_num_spans - len(spans)
+
+         spans = spans + [(0, 0)] * pad_size
+
+         padded_spans.append(spans)
+
+     spans = torch.tensor(
+         padded_spans,
+         dtype=torch.long,
+         device=device
+     ) # (B, N, 2)
+
+     return spans
+
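`filter_spans` scores every candidate span as (start probability) + (end probability) minus a length penalty, keeps the top-k, and then adds every span whose start and end lie within `keep_neighbor` tokens of a kept one. The single-sentence sketch below replays that procedure on plain lists. It is a simplified illustration rather than the committed function: `select_spans` is a hypothetical name, and `p_start[i]` and `p_end[i]` stand in for `1 - prob[..., 0]`, the probability that token `i` is not in the "no boundary" class.

```python
def select_spans(p_start, p_end, max_span_length=10, topk=1, keep_neighbor=1):
    """Return sorted (start, end) candidates for one sentence.

    Index 0 is treated as a special token and never used as a span start.
    """
    n = len(p_start)  # plays the role of valid_len
    scored = []
    for s in range(1, n):
        last_e = min(n - 1, s + max_span_length - 1)
        for e in range(s, last_e + 1):
            length = e - s + 1
            score = p_start[s] + p_end[e] - length / max_span_length
            scored.append((score, s, e))

    # Keep the k best-scoring spans.
    scored.sort(key=lambda item: item[0], reverse=True)

    # Expand each kept span by shifting both boundaries independently.
    spans = set()
    for _, s, e in scored[:topk]:
        for ds in range(-keep_neighbor, keep_neighbor + 1):
            for de in range(-keep_neighbor, keep_neighbor + 1):
                ns, ne = s + ds, e + de
                if ns <= 0 or ne >= n or ns > ne or ne - ns + 1 > max_span_length:
                    continue
                spans.add((ns, ne))
    return sorted(spans)

# Five tokens; token 1 looks like a start and token 2 looks like an end.
p_start = [0.0, 0.9, 0.1, 0.1, 0.1]
p_end = [0.0, 0.1, 0.8, 0.1, 0.1]
candidates = select_spans(p_start, p_end, topk=1, keep_neighbor=1)
```

The single best span `(1, 2)` expands into five candidates. With the committed settings (`topk_spans = 50`, `keep_neighbor = 2`) each kept span can contribute up to 25 neighbours before deduplication, so the span classifier sees considerably more candidates than the top-k value alone suggests.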
+ class MLP(nn.Module):
+     def __init__(self, in_size, hid_size, out_size):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(in_size, hid_size),
+             nn.ReLU(),
+             nn.Linear(hid_size, out_size)
+         )
+
+     def forward(self, x):
+         return self.mlp(x)
+
+ class IEModel(nn.Module):
+     def __init__(self, backbone_model_name, num_labels, max_span_len, topk_spans, keep_neighbor):
+         super().__init__()
+         self.max_span_len = max_span_len
+         self.topk_spans = topk_spans
+         self.keep_neighbor = keep_neighbor
+
+         self.encoder = AutoModel.from_pretrained(backbone_model_name)
+         hidden_size = self.encoder.config.hidden_size
+
+         self.start_classifier = MLP(hidden_size, hidden_size, num_labels)
+         self.end_classifier = MLP(hidden_size, hidden_size, num_labels)
+
+         self.span_classifier = MLP(4*hidden_size, hidden_size, num_labels)
+
+     def encode(self, input_ids, attention_mask):
+         B, n_parts, L = input_ids.shape
+         input_ids = input_ids.view(-1, L)
+         attention_mask = attention_mask.view(-1, L)
+
+         outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         hidden_states = outputs.last_hidden_state # B * n_parts, L, H
+
+         hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1) # B, n_parts*L, H
+         attention_mask = attention_mask.view(B, n_parts, L).reshape(B, n_parts*L) # B, n_parts*L
+         return hidden_states, attention_mask
+
+     def get_token_logits(self, hidden_states):
+         start_logits = self.start_classifier(hidden_states) # B, L, classes
+         end_logits = self.end_classifier(hidden_states) # B, L, classes
+         return start_logits, end_logits
+
+     def get_logits(self, span_reprs):
+         logits = self.span_classifier(span_reprs) # B, N, classes
+         return logits
+
+     def forward(self, input_ids, attention_mask, spans=None):
+         hidden_states, attention_mask = self.encode(input_ids, attention_mask)
+         start_logits, end_logits = self.get_token_logits(hidden_states)
+         if spans is None:
+             spans = filter_spans(start_logits, end_logits, attention_mask, self.max_span_len, self.topk_spans, self.keep_neighbor)
+         span_reprs = get_span_reprs(hidden_states, spans)
+         logits = self.get_logits(span_reprs)
+         return start_logits, end_logits, logits, spans
+
+ def test_model():
+     model = DataParallelProxy(IEModel(backbone_model_name, 7, 10, 100, 2)).to(device)
+     model.eval()
+     total_params = sum(p.numel() for p in model.parameters())
+     print(f"Total params: {total_params:,}")
+
+     vocab_size = model.module.encoder.config.vocab_size
+     max_len = model.module.encoder.config.max_position_embeddings
+
+     bz = 32
+     i = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
+     a = torch.ones(bz, 5, 10).to(device)
+     s = torch.ones(bz, 3, 2, dtype=torch.long).to(device)
+
+     with torch.no_grad():
+         r = model(i, a)
+
+     if isinstance(r, tuple):
+         print([x.shape if torch.is_tensor(x) else len(x) for x in r])
+     else:
+         print(r.shape)
+
+ test_model()
+
911
+ # %% [code]
912
+ def configure_optimizers(network, optim_params, scheduler_params):
913
+ try:
914
+ optim_params = copy.copy(optim_params)
915
+ scheduler_params = copy.copy(scheduler_params)
916
+
917
+ optim_name = optim_params.pop('name')
918
+ scheduler_name = scheduler_params.pop('name')
919
+
920
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
921
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
922
+
923
+ if optimizer_cls is None:
924
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
925
+
926
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
927
+
928
+ scheduler = None
929
+ if scheduler_params and scheduler_cls: # Chỉ tạo scheduler nếu có tham số
930
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
931
+
932
+ return optimizer, scheduler
933
+
934
+ except KeyError as e:
935
+ raise ValueError(f"Missing {e} in config!!")
936
+
937
+ def freeze(self, model):
938
+ model.eval()
939
+ for param in model.parameters():
940
+ param.requires_grad = False
941
+
942
+ def unfreeze(self, model):
943
+ model.train()
944
+ for param in model.parameters():
945
+ param.requires_grad = True
946
+
947
+ def reduce_batch_size(loader, ratio=0.5):
948
+ new_bs = max(1, int(loader.batch_size * ratio))
949
+
950
+ shuffle = isinstance(loader.sampler, RandomSampler)
951
+
952
+ new_loader = DataLoader(
953
+ dataset=loader.dataset,
954
+ batch_size=new_bs,
955
+ shuffle=shuffle,
956
+ sampler=None if shuffle else loader.sampler,
957
+ num_workers=loader.num_workers,
958
+ collate_fn=loader.collate_fn,
959
+ pin_memory=loader.pin_memory,
960
+ drop_last=loader.drop_last,
961
+ timeout=loader.timeout,
962
+ worker_init_fn=loader.worker_init_fn,
963
+ multiprocessing_context=loader.multiprocessing_context,
964
+ generator=loader.generator,
965
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
966
+ persistent_workers=loader.persistent_workers,
967
+ pin_memory_device=loader.pin_memory_device
968
+ )
969
+
970
+ return new_loader
971
+
972
+ def list_to_tuple(x):
973
+ if isinstance(x, (list, tuple)):
974
+ return tuple(list_to_tuple(i) for i in x)
975
+ return x
976
+
977
+ def fmt(x):
978
+ if isinstance(x, float):
979
+ return round(x, 5)
980
+ if isinstance(x, dict):
981
+ return {k: fmt(v) for k, v in x.items()}
982
+ if isinstance(x, list):
983
+ return [fmt(v) for v in x]
984
+ return x
985
+
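For illustration, the two helpers above can be exercised on toy data. This is a standalone sketch that restates `list_to_tuple` and `fmt` so it runs on its own; the sample entity list and log entry are made up and are not taken from the training run.

```python
def list_to_tuple(x):
    # Recursively turn lists/tuples into tuples so the result is hashable.
    if isinstance(x, (list, tuple)):
        return tuple(list_to_tuple(i) for i in x)
    return x


def fmt(x):
    # Round floats to 5 decimals, recursing into dicts and lists.
    if isinstance(x, float):
        return round(x, 5)
    if isinstance(x, dict):
        return {k: fmt(v) for k, v in x.items()}
    if isinstance(x, list):
        return [fmt(v) for v in x]
    return x


# Hypothetical entities in the [batch_idx, [token_ids, label]] layout.
entities = [[0, [[11, 12], "PER"]], [1, [[7], "LOC"]]]
hashable = list_to_tuple(entities)
entity_set = set(hashable)  # possible because every level is now a tuple

# Hypothetical per-epoch log entry.
log_entry = fmt({"lr": [0.000123456789], "train_loss": 1.23456789, "epoch": 3})
print(hashable)
print(log_entry)
```

Converting to nested tuples is what makes set-based comparison of predicted and gold entities possible; `fmt` only affects what is printed, not what is stored.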
986
+ class ModelEmaV3Proxy(ModelEmaV3):
987
+
988
+ def __getattr__(self, name):
989
+
990
+ try:
991
+ return super().__getattr__(name)
992
+
993
+ except AttributeError:
994
+
995
+ # avoid infinite recursion
996
+ module = object.__getattribute__(self, "module")
997
+
998
+ return getattr(module, name)
999
+
1000
+ def align(
1001
+ all_spans, # (B, N, 2)
1002
+ pred_spans, # (B, M, 2)
1003
+ obj, # (B, N)
1004
+ pad_value=0
1005
+ ):
1006
+ """
1007
+ Return:
1008
+ align_obj: (B, M)
1009
+ """
1010
+
1011
+ B, N, _ = all_spans.shape
1012
+ M = pred_spans.shape[1]
1013
+
1014
+ device = all_spans.device
1015
+
1016
+ # =========================================================
1017
+ # Compare spans
1018
+ # =========================================================
1019
+
1020
+ # (B, M, 1, 2)
1021
+ pred_expand = pred_spans.unsqueeze(2)
1022
+
1023
+ # (B, 1, N, 2)
1024
+ all_expand = all_spans.unsqueeze(1)
1025
+
1026
+ # (B, M, N)
1027
+ match = (pred_expand == all_expand).all(dim=-1)
1028
+
1029
+ # =========================================================
1030
+ # Gather obj
1031
+ # =========================================================
1032
+
1033
+ # default = pad_value
1034
+ align_obj = torch.full(
1035
+ (B, M),
1036
+ pad_value,
1037
+ dtype=obj.dtype,
1038
+ device=device
1039
+ )
1040
+
1041
+ # matched index
1042
+ has_match = match.any(dim=-1) # (B, M)
1043
+
1044
+ matched_idx = match.float().argmax(dim=-1) # (B, M)
1045
+
1046
+ gathered = torch.gather(
1047
+ obj,
1048
+ dim=1,
1049
+ index=matched_idx
1050
+ ) # (B, M)
1051
+
1052
+ align_obj[has_match] = gathered[has_match]
1053
+
1054
+ # =========================================================
1055
+ # Force padding spans -> pad_value
1056
+ # =========================================================
1057
+
1058
+ pad_mask = (pred_spans == 0).all(dim=-1)
1059
+
1060
+ align_obj[pad_mask] = pad_value
1061
+
1062
+ return align_obj
1063
+
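The matching rule in `align` can be restated in plain Python for a single batch element. This is a sketch for reading purposes, not the tensor implementation: `align_plain` is a hypothetical name, it works on lists of tuples instead of tensors, and the sample spans and labels are invented.

```python
def align_plain(all_spans, pred_spans, obj, pad_value=0):
    """Plain-Python restatement of the matching rule for one batch element.

    all_spans:  list of (start, end) candidate spans
    pred_spans: list of (start, end) spans proposed by the model
    obj:        one value per entry of all_spans (for example a label id)
    """
    aligned = []
    for span in pred_spans:
        if span == (0, 0):
            # (0, 0) is treated as padding and always maps to pad_value.
            aligned.append(pad_value)
        elif span in all_spans:
            # Copy the value of the first identical candidate span.
            aligned.append(obj[all_spans.index(span)])
        else:
            # No identical candidate span: fall back to pad_value.
            aligned.append(pad_value)
    return aligned


all_spans = [(1, 1), (1, 2), (2, 2)]
labels = [0, 3, 0]
pred_spans = [(1, 2), (2, 3), (0, 0)]
aligned = align_plain(all_spans, pred_spans, labels, pad_value=-100)
print(aligned)
```

With `pad_value=-100`, unmatched and padded predictions receive a label that a loss configured to ignore `-100` would skip, which appears to be the intent of the call in `_cal_loss`.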
1064
+ def extract_spans(
1065
+ all_spans, # (B, N, 2)
1066
+ all_label, # (B, N)
1067
+ pred_spans # (B, M, 2)
1068
+ ):
1069
+ """
1070
+ Return:
1071
+ pred_list:
1072
+ [(bidx, (s, e)), ...]
1073
+
1074
+ gold_list:
1075
+ [(bidx, (s, e)), ...]
1076
+ """
1077
+
1078
+ # =========================================================
1079
+ # Gold spans
1080
+ # =========================================================
1081
+
1082
+ gold_mask = all_label > 0 # (B, N)
1083
+
1084
+ gold_indices = gold_mask.nonzero(as_tuple=False)
1085
+
1086
+ gold_list = [
1087
+ (
1088
+ int(b),
1089
+ (
1090
+ int(all_spans[b, n, 0]),
1091
+ int(all_spans[b, n, 1])
1092
+ )
1093
+ )
1094
+ for b, n in gold_indices
1095
+ ]
1096
+
1097
+ # =========================================================
1098
+ # Pred spans
1099
+ # =========================================================
1100
+
1101
+ pred_mask = (pred_spans > 0).all(dim=-1) # (B, M)
1102
+
1103
+ pred_indices = pred_mask.nonzero(as_tuple=False)
1104
+
1105
+ pred_list = [
1106
+ (
1107
+ int(b),
1108
+ (
1109
+ int(pred_spans[b, m, 0]),
1110
+ int(pred_spans[b, m, 1])
1111
+ )
1112
+ )
1113
+ for b, m in pred_indices
1114
+ ]
1115
+
1116
+ return gold_list, pred_list
1117
+
1118
+ def extract_entities(
1119
+ input_ids, # (B, L)
1120
+ start_logits, # (B, L, C)
1121
+ end_logits, # (B, L, C)
1122
+ logits, # (B, N, C)
1123
+ pred_spans, # (B, N, 2)
1124
+ id2label,
1125
+ alpha=1.0
1126
+ ):
1127
+ """
1128
+ Return:
1129
+ [
1130
+ (batch_idx, ([token_ids], label_name)),
1131
+ ...
1132
+ ]
1133
+ """
1134
+
1135
+ # =========================================================
1136
+ # Log-softmax
1137
+ # =========================================================
1138
+
1139
+ start_logprob = F.log_softmax(start_logits, dim=-1) # (B, L, C)
1140
+ end_logprob = F.log_softmax(end_logits, dim=-1) # (B, L, C)
1141
+ span_logprob = F.log_softmax(logits, dim=-1) # (B, N, C)
1142
+
1143
+ # =========================================================
1144
+ # Gather start/end score for pred spans
1145
+ # =========================================================
1146
+
1147
+ start_idx = pred_spans[..., 0] # (B, N)
1148
+ end_idx = pred_spans[..., 1] # (B, N)
1149
+
1150
+ B, N = start_idx.shape
1151
+ C = logits.shape[-1]
1152
+
1153
+ # (B, N, C)
1154
+ start_score = torch.gather(
1155
+ start_logprob,
1156
+ dim=1,
1157
+ index=start_idx.unsqueeze(-1).expand(-1, -1, C)
1158
+ )
1159
+
1160
+ # (B, N, C)
1161
+ end_score = torch.gather(
1162
+ end_logprob,
1163
+ dim=1,
1164
+ index=end_idx.unsqueeze(-1).expand(-1, -1, C)
1165
+ )
1166
+
1167
+ # =========================================================
1168
+ # Ensemble score
1169
+ # =========================================================
1170
+
1171
+ score = (
1172
+ span_logprob
1173
+ + alpha * (start_score + end_score)
1174
+ ) # (B, N, C)
1175
+
1176
+ # =========================================================
1177
+ # Predict label
1178
+ # =========================================================
1179
+
1180
+ pred_labels = score.argmax(dim=-1) # (B, N)
1181
+
1182
+ keep = (
1183
+ (pred_labels > 0) &
1184
+ (start_idx > 0) &
1185
+ (end_idx > 0)
1186
+ )
1187
+
1188
+ # =========================================================
1189
+ # Extract entities
1190
+ # =========================================================
1191
+
1192
+ results = []
1193
+
1194
+ for bidx in range(B):
1195
+
1196
+ valid_idxes = keep[bidx].nonzero(as_tuple=False).squeeze(-1)
1197
+
1198
+ for idx in valid_idxes:
1199
+
1200
+ lb = pred_labels[bidx, idx]
1201
+
1202
+ s, e = pred_spans[bidx, idx].tolist()
1203
+
1204
+ token_ids = input_ids[bidx, s:e+1].tolist()
1205
+
1206
+ results.append(
1207
+ (
1208
+ bidx,
1209
+ (
1210
+ token_ids,
1211
+ id2label[lb.item()]
1212
+ )
1213
+ )
1214
+ )
1215
+
1216
+ return results
1217
+
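The scoring step of `extract_entities` adds the span classifier's log-probabilities to the boundary log-probabilities, weighted by `alpha`. The sketch below reproduces that arithmetic for one span in plain Python; `log_softmax` and `ensemble_label` are hypothetical helper names and the logits are invented.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax over a list of floats.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def ensemble_label(span_logits, start_logits, end_logits, alpha=1.0):
    # score[c] = span_logprob[c] + alpha * (start_logprob[c] + end_logprob[c])
    span_lp = log_softmax(span_logits)
    start_lp = log_softmax(start_logits)
    end_lp = log_softmax(end_logits)
    score = [sp + alpha * (st + en) for sp, st, en in zip(span_lp, start_lp, end_lp)]
    # Index of the highest-scoring class.
    return max(range(len(score)), key=score.__getitem__)


# The span head slightly prefers class 0; both boundary heads prefer class 1.
span_logits = [1.0, 0.8, 0.0]
start_logits = [0.0, 2.0, 0.0]
end_logits = [0.0, 2.0, 0.0]

label_span_only = ensemble_label(span_logits, start_logits, end_logits, alpha=0.0)
label_ensemble = ensemble_label(span_logits, start_logits, end_logits, alpha=0.5)
print(label_span_only, label_ensemble)
```

With `alpha=0.0` only the span head decides; raising `alpha` lets agreement between the start and end heads override a weak span-level preference.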
1218
+ class Trainer:
1219
+ def __init__(
1220
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
1221
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
1222
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
1223
+ ):
1224
+ self.ema_net = None
1225
+
1226
+ self.training_time = self._time_str_to_seconds(training_time)
1227
+ self.mode = eval_mode
1228
+ self.topk = topk
1229
+ self.device = device
1230
+ self.logging = logging if logging < epochs else 1
1231
+ self.logging_file = logging_file
1232
+ self.checkpoints_dir = checkpoints_dir
1233
+ self.early_stopping = early_stopping
1234
+ self.eval_from_ratio = eval_from_ratio
1235
+ self.eval_every = eval_every
1236
+ self.save_name = save_name
1237
+ self.save_best = save_best
1238
+ self.save_last = save_last
1239
+ self.return_best = return_best
1240
+ self.return_last = return_last
1241
+ self.max_grad_norm = max_grad_norm
1242
+ self.schedule_in_step = schedule_in_step
1243
+ self.use_ema = use_ema
1244
+ self.ema_from_ratio = ema_from_ratio
1245
+ self.ema_decay = ema_decay
1246
+
1247
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
1248
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
1249
+
1250
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
1251
+ if eval_fn is None:
1252
+ if self.mode == "max":
1253
+ eval_fn = lambda *x: -loss_fn(*x)
1254
+ else:
1255
+ eval_fn = lambda *x: loss_fn(*x)
1256
+
1257
+ if torch.cuda.device_count() > 1:
1258
+ network = DataParallelProxy(network)
1259
+ network = network.to(self.device)
1260
+
1261
+ if not start_training_time:
1262
+ start_training_time = time.time()
1263
+
1264
+ start_ema = int(epochs * self.ema_from_ratio)
1265
+ start_eval = int(epochs * self.eval_from_ratio)
1266
+
1267
+ if val_loader is None:
1268
+ print(f'[Trainer CallBack] 📢 No validation set; evaluation and early stopping are unavailable!')
1269
+ else:
1270
+ model_to_use_str = 'the EMA model' if self.use_ema else 'the base model'
1271
+ start_model_update_str = f'EMA updates start at epoch {start_epoch + start_ema}!' if self.use_ema else ''
1272
+ print(f'[Trainer CallBack] 📢 Evaluating with {model_to_use_str} from epoch {start_epoch + start_eval}!', start_model_update_str)
1273
+
1274
+ training_log = {}
1275
+ for epoch in range(start_epoch, epochs+start_epoch):
1276
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
1277
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1278
+
1279
+ try:
1280
+ teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
1281
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1282
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1283
+ logging_dict.update(train_loss_epoch_dict)
1284
+
1285
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
1286
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1287
+
1288
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
1289
+ update = self._update_best_network(eval_net, val_score, epoch)
1290
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
1291
+ logging_dict.update(val_score_dict)
1292
+ if not self.schedule_in_step and scheduler:
1293
+ scheduler.step()
1294
+
1295
+ except RuntimeError as e:
1296
+ if "out of memory" in str(e).lower():
1297
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
1298
+ torch.cuda.empty_cache()
1299
+ gc.collect()
1300
+ if torch.cuda.is_available():
1301
+ torch.cuda.synchronize()
1302
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
1303
+
1304
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
1305
+ if val_loader is not None:
1306
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
1307
+
1308
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
1309
+ else:
1310
+ raise
1311
+
1312
+ training_log[epoch] = logging_dict
1313
+ if self.is_early_stopping(epoch):
1314
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
1315
+ break
1316
+ if self.logging:
1317
+ if epoch % self.logging == 0:
1318
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
1319
+ else:
1320
+ print(f'{epoch}...', end=' ')
1321
+
1322
+ if self._at_time_limit(start_training_time):
1323
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: The training time limit is {self.training_time} seconds; time ran out at epoch {epoch}/{epochs}')
1324
+ break
1325
+
1326
+ if self.logging_file:
1327
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
1328
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
1329
+ f.write(json.dumps(training_log))
1330
+
1331
+ if self.use_ema and self.ema_net is not None:
1332
+ self._save_state_dict(self.ema_net.module)
1333
+ else:
1334
+ self._save_state_dict(network)
1335
+ print(f'[Trainer CallBack] 📢 Training finished.\n')
1336
+
1337
+ best_model, last_model = None, None
1338
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1339
+ if self.return_best :
1340
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
1341
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
1342
+ if self.return_last:
1343
+ last_model = eval_net.state_dict()
1344
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
1345
+
1346
+ del network
1347
+ torch.cuda.empty_cache()
1348
+ gc.collect()
1349
+ return training_log, best_model, last_model
1350
+
1351
+ def _time_str_to_seconds(self, time_str):
1352
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
1353
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
1354
+
1355
+ def _update_best_network(self, network, val_score, epoch):
1356
+ topk = max(1, self.topk)
1357
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
1358
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
1359
+ if val_score in [x[0] for x in self.best_stage]:
1360
+ return True
1361
+ return False
1362
+
1363
+ def is_early_stopping(self, epoch):
1364
+ if self.best_stage[0][1] is None:
1365
+ return False
1366
+ if not self.early_stopping:
1367
+ return False
1368
+ return epoch - self.best_stage[0][1] >= self.early_stopping
1369
+
1370
+ def _at_time_limit(self, start_training_time):
1371
+ return time.time() - start_training_time >= self.training_time
1372
+
1373
+ def _save_state_dict(self, network):
1374
+ if self.topk <= 0:
1375
+ return
1376
+
1377
+ if self.save_best:
1378
+ for r in range(self.topk):
1379
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
1380
+
1381
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
1382
+ if state_dict is None:
1383
+ continue
1384
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
1385
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
1386
+ if self.save_last:
1387
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
1388
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
1389
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
1390
+
1391
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate):
1392
+ network.train()
1393
+ total_loss = 0
1394
+ total_loss_dict = {}
1395
+ for batch_idx, batch in enumerate(train_loader):
1396
+ optimizer.zero_grad()
1397
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
1398
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate)
1399
+
1400
+ for k, v in loss_dict.items():
1401
+ t = total_loss_dict.get(k, 0)
1402
+ total_loss_dict[k] = t + v.detach()
1403
+ self.grad_scaler.scale(loss).backward()
1404
+ self.grad_scaler.unscale_(optimizer)
1405
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
1406
+ # print(grad_norm) # Uncomment this line to see which max_grad_norm value to choose...
1407
+ self.grad_scaler.step(optimizer)
1408
+ self.grad_scaler.update()
1409
+ if self.schedule_in_step and scheduler:
1410
+ scheduler.step()
1411
+ if self.use_ema and self.ema_net is not None:
1412
+ self.ema_net.update(network)
1413
+ total_loss += loss.detach()
1414
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
1415
+
1416
+ def _eval_epoch(self, network, val_loader, eval_fn, id2label):
1417
+ network.eval()
1418
+ total_score = 0.0
1419
+ total_score_dict = {}
1420
+ object_lists = None # initialized lazily below
1421
+
1422
+ with torch.no_grad():
1423
+ for batch_idx, batch in enumerate(val_loader):
1424
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
1425
+ total_score += score
1426
+
1427
+ for k, v in score_dict.items():
1428
+ t = total_score_dict.get(k, 0)
1429
+ total_score_dict[k] = t + v
1430
+
1431
+ if objects:
1432
+ if object_lists is None:
1433
+ object_lists = [[] for _ in range(len(objects))]
1434
+
1435
+ for i, obj in enumerate(objects):
1436
+ object_lists[i].append(obj.detach())
1437
+
1438
+ if object_lists is not None:
1439
+ object_arrays = [
1440
+ torch.concat(obj_list, dim=0).cpu().numpy()
1441
+ for obj_list in object_lists
1442
+ ]
1443
+ else:
1444
+ object_arrays = []
1445
+
1446
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
1447
+
1448
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate):
1449
+ # Override _cal_loss to compute the loss for your task
1450
+ input_ids = batch['input_ids'].to(self.device)
1451
+ attention_mask = batch['attention_mask'].to(self.device)
1452
+ all_spans = batch['all_spans'].to(self.device)
1453
+ all_labels = batch['all_labels'].to(self.device)
1454
+ all_weights = batch['all_weights'].to(self.device)
1455
+ start_labels = batch['start_labels'].to(self.device)
1456
+ end_labels = batch['end_labels'].to(self.device)
1457
+
1458
+ choice = random.random()
1459
+ if choice < teaching_rate:
1460
+ start_logits, end_logits, logits, pred_spans = network(input_ids, attention_mask, all_spans)
1461
+ else:
1462
+ start_logits, end_logits, logits, pred_spans = network(input_ids, attention_mask)
1463
+
1464
+ align_labels = align(all_spans, pred_spans, all_labels, -100)
1465
+ align_weights = align(all_spans, pred_spans, all_weights, 0)
1466
+
1467
+ loss_dict = loss_fn(
1468
+ logits, align_labels, align_weights,
1469
+ start_logits, start_labels,
1470
+ end_logits, end_labels,
1471
+ )
1472
+ return loss_dict['total'], loss_dict
1473
+
1474
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
1475
+ # Override _cal_val_score to compute the validation score; the trailing list can return targets or predictions (if needed)
1476
+ input_ids = batch['input_ids'].to(self.device)
1477
+ attention_mask = batch['attention_mask'].to(self.device)
1478
+ all_spans = batch['all_spans'].to(self.device)
1479
+ all_labels = batch['all_labels'].to(self.device)
1480
+ gold_entities = batch['gold_entities']
1481
+
1482
+ B, _, _ = input_ids.shape
1483
+
1484
+ start_logits, end_logits, logits, pred_spans = network(input_ids, attention_mask)
1485
+
1486
+ gold_list, pred_list = extract_spans(all_spans, all_labels, pred_spans)
1487
+ gold_list = list_to_tuple(gold_list)
1488
+ pred_list = list_to_tuple(pred_list)
1489
+ span_score = eval_fn(pred_list, gold_list)['recall']
1490
+
1491
+ pred_ids = extract_entities(input_ids.reshape(B, -1), start_logits, end_logits, logits, pred_spans, id2label, alpha=0.5)
1492
+ pred_ids = list_to_tuple(pred_ids)
1493
+ gold_ids = list_to_tuple(gold_entities)
1494
+ score_dict = eval_fn(pred_ids, gold_ids)
1495
+
1496
+ score_dict.update({'span_recall': span_score})
1497
+ return score_dict['f1'] + score_dict['span_recall'], score_dict, []
1498
+
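`Trainer.fit` computes a cosine `teaching_rate` per epoch, and `_cal_loss` passes `all_spans` to the network only when a random draw falls below it. The sketch below shows how that probability evolves, assuming the default `start_epoch=1`; `teaching_rate` and `passes_all_spans` are hypothetical helper names, and what the network does with the extra argument is not shown in this section.

```python
import math
import random


def teaching_rate(epoch, epochs):
    # Same formula as in Trainer.fit: close to 1 early on, 0 at the last epoch.
    return math.cos(math.pi / 2 * epoch / epochs)


def passes_all_spans(epoch, epochs, rng):
    # Mirrors the branch in _cal_loss: the batch's all_spans are handed to
    # the network with probability teaching_rate(epoch, epochs).
    return rng.random() < teaching_rate(epoch, epochs)


epochs = 10
rates = [round(teaching_rate(e, epochs), 3) for e in range(1, epochs + 1)]

rng = random.Random(0)  # fixed seed so the sketch is reproducible
early = sum(passes_all_spans(1, epochs, rng) for _ in range(1000)) / 1000
late = sum(passes_all_spans(9, epochs, rng) for _ in range(1000)) / 1000
print(rates)
print(early, late)
```

Early batches almost always receive `all_spans`, while late batches mostly rely on the spans the network proposes itself, so the training condition gradually approaches the evaluation condition in `_cal_val_score`.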
1499
+ # %% [code]
1500
+ class PhoBERTSpanAligner:
1501
+ def __init__(self, tokenizer, max_len):
1502
+ self.tokenizer = tokenizer
1503
+ self.max_len = max_len
1504
+
1505
+ # ===== 1. Extract discontinuous spans =====
1506
+ def extract_spans(self, sample):
1507
+ entity_spans = []
1508
+
1509
+ for event in sample["entities"]:
1510
+ entity_type = event["label"]
1511
+ spans = [tuple(event["offset"])]
1512
+ entity_spans.append({
1513
+ "spans": spans,
1514
+ "label": entity_type
1515
+ })
1516
+
1517
+ return entity_spans
1518
+
1519
+ # ===== 2. Word offsets =====
1520
+ def build_word_offsets(self, text, words):
1521
+ offsets = []
1522
+ pointer = 0
1523
+
1524
+ for word in words:
1525
+ start = text.find(word, pointer)
1526
+ end = start + len(word)
1527
+ offsets.append((start, end))
1528
+ pointer = end
1529
+
1530
+ return offsets
1531
+
1532
+ # ===== 3. Char → word =====
1533
+ def char_span_to_word_span(self, word_offsets, start, end):
1534
+ start_word = None
1535
+ end_word = None
1536
+
1537
+ for i, (w_start, w_end) in enumerate(word_offsets):
1538
+ if w_start <= start < w_end:
1539
+ start_word = i
1540
+ if w_start < end <= w_end:
1541
+ end_word = i
1542
+
1543
+ return start_word, end_word
1544
+
1545
+ # ===== 4. Word → subword =====
1546
+ def word_to_subword_map(self, words):
1547
+ mapping = []
1548
+ subword_index = 1 # <s>
1549
+
1550
+ for word in words:
1551
+ sub_tokens = self.tokenizer.tokenize(word)
1552
+ start = subword_index
1553
+ end = subword_index + len(sub_tokens) - 1
1554
+ mapping.append((start, end))
1555
+ subword_index += len(sub_tokens)
1556
+
1557
+ return mapping
1558
+
1559
+ # ===== 5. Span → subword =====
1560
+ def span_to_subword(self, word_offsets, word_subword_map, spans):
1561
+ sub_spans = []
1562
+
1563
+ for span_start, span_end in spans:
1564
+ w_start, w_end = self.char_span_to_word_span(
1565
+ word_offsets, span_start, span_end
1566
+ )
1567
+ if w_start is None or w_end is None:
1568
+ continue
1569
+
1570
+ sub_start = word_subword_map[w_start][0]
1571
+ sub_end = word_subword_map[w_end][1]
1572
+ sub_spans.append((sub_start, sub_end))
1573
+
1574
+ return sub_spans
1575
+
1576
+ def extract_valid_spans(self, sub_spans):
1577
+ valid_spans = []
1578
+ for s, e in sub_spans:
1579
+ if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
1580
+ continue
1581
+ valid_spans.append((s, e))
1582
+ return valid_spans
1583
+
1584
+ def encode(self, sample):
1585
+ text = sample["text"]
1586
+ entities = self.extract_spans(sample)
1587
+
1588
+ # ===== 1. Word tokenize =====
1589
+ words = word_tokenize(text)
1590
+ sentence = " ".join(words)
1591
+
1592
+ # ===== 2. Mapping =====
1593
+ word_offsets = self.build_word_offsets(text, words)
1594
+ word_subword_map = self.word_to_subword_map(words)
1595
+
1596
+ # ===== 3. Tokenize FULL =====
1597
+ encoding = self.tokenizer(
1598
+ sentence,
1599
+ max_length=self.max_len,
1600
+ truncation=True,
1601
+ padding="max_length",
1602
+ return_tensors="pt"
1603
+ )
1604
+ input_ids = encoding["input_ids"][0]
1605
+ attention_mask = encoding["attention_mask"][0]
1606
+
1607
+ # ===== 4. Convert spans =====
1608
+ entities_gold_spans = []
1609
+
1610
+ for ent in entities:
1611
+ label = ent["label"]
1612
+
1613
+ sub_spans = self.span_to_subword(
1614
+ word_offsets,
1615
+ word_subword_map,
1616
+ ent["spans"]
1617
+ )
1618
+ valid_spans = self.extract_valid_spans(sub_spans)
1619
+ if len(valid_spans) == 0:
1620
+ continue
1621
+ entities_gold_spans.append((tuple(valid_spans), label))
1622
+
1623
+ return {
1624
+ "input_ids": input_ids,
1625
+ "attention_mask": attention_mask,
1626
+ "entities_gold_spans": entities_gold_spans,
1627
+ }
1628
+
1629
+ def generate_spans(attention_mask, max_span_len):
1630
+ seq_len = attention_mask.sum().item() - 2
1631
+ spans = []
1632
+ for i in range(1, seq_len+1):
1633
+ for j in range(i, min(i+max_span_len, seq_len+1)):
1634
+ spans.append((i, j))
1635
+ return spans
1636
+
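The enumeration performed by `generate_spans` is easier to see on a small input. In this sketch `generate_spans_plain` is a hypothetical variant that takes the number of real tokens directly, which stands in for `attention_mask.sum().item() - 2` (the mask total minus the two special tokens).

```python
def generate_spans_plain(n_tokens, max_span_len):
    # Enumerate every inclusive (start, end) span of at most max_span_len
    # tokens. Positions start at 1 because index 0 holds the leading
    # special token.
    spans = []
    for i in range(1, n_tokens + 1):
        for j in range(i, min(i + max_span_len, n_tokens + 1)):
            spans.append((i, j))
    return spans


spans = generate_spans_plain(n_tokens=4, max_span_len=2)
print(spans)
```

The number of candidates grows roughly as `n_tokens * max_span_len`, which is why `max_span_len` is the main knob for the size of `all_spans`.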
1637
+ def match_gold_labels(
1638
+ gold_spans, # (N, 2)
1639
+ gold_labels, # (N,)
1640
+ pred_spans, # (M, 2)
1641
+ default_label=-100
1642
+ ):
1643
+ """
1644
+ Return:
1645
+ pred_labels: (M,)
1646
+ """
1647
+
1648
+ pred_labels = torch.full(
1649
+ (pred_spans.size(0),),
1650
+ default_label,
1651
+ dtype=gold_labels.dtype,
1652
+ device=gold_labels.device
1653
+ )
1654
+ if gold_spans.size(0) == 0:
1655
+ return pred_labels
1656
+
1657
+ # (M, N)
1658
+ matched = (pred_spans[:, None, :] == gold_spans[None, :, :]).all(dim=-1)
1659
+ has_match = matched.any(dim=1)
1660
+
1661
+ # take the index of the first matching gold span
1662
+ gold_idx = matched.float().argmax(dim=1)
1663
+
1664
+ pred_labels[has_match] = gold_labels[gold_idx[has_match]]
1665
+
1666
+ return pred_labels
1667
+
1668
+ class KLTNDataset(Dataset):
1669
+ def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts, max_span_len, weight_rate):
1670
+ super().__init__()
1671
+ self.tokenizer = tokenizer
1672
+ self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)
1673
+ self.all_data = all_data
1674
+ self.using_idxes = using_idxes
1675
+ self.label2id = label2id
1676
+ self.max_len = max_len
1677
+ self.max_n_parts = max_n_parts
1678
+ self.max_span_len = max_span_len
1679
+ self.weight_rate = weight_rate
1680
+
1681
+ def __len__(self):
1682
+ return len(self.using_idxes)
1683
+
1684
+ def span_iou(self, span1, span2):
1685
+ s1, e1 = span1
1686
+ s2, e2 = span2
1687
+
1688
+ # intersection
1689
+ inter_left = max(s1, s2)
1690
+ inter_right = min(e1, e2)
1691
+ intersection = max(0, inter_right - inter_left + 1)
1692
+
1693
+ # lengths
1694
+ len1 = e1 - s1 + 1
1695
+ len2 = e2 - s2 + 1
1696
+
1697
+ # union
1698
+ union = len1 + len2 - intersection
1699
+ if union == 0:
1700
+ return 0.0
1701
+
1702
+ return intersection / union
1703
+
1704
+ def get_weights(self, spans, pos_spans):
1705
+ # spans: (N, 2), pos_spans: (K, 2)
1706
+ N, K = spans.size(0), pos_spans.size(0)
1707
+ device = spans.device
1708
+
1709
+ # ===== edge case =====
1710
+ if K == 0:
1711
+ weights = torch.ones(N, device=device, dtype=torch.float)
1712
+ return weights
1713
+
1714
+ # ===== IoU =====
1715
+ s1, e1 = spans[:, None, 0], spans[:, None, 1]
1716
+ s2, e2 = pos_spans[None, :, 0], pos_spans[None, :, 1]
1717
+
1718
+ inter_s = torch.maximum(s1, s2)
1719
+ inter_e = torch.minimum(e1, e2)
1720
+ inter = (inter_e - inter_s + 1).clamp(min=0)
1721
+
1722
+ len1 = (e1 - s1 + 1)
1723
+ len2 = (e2 - s2 + 1)
1724
+ union = len1 + len2 - inter
1725
+
1726
+ iou = inter / (union + 1e-8) # (N, K)
1727
+
1728
+ # ===== weights: IoU=0 -> 1, else 1 + weight_rate * IoU =====
1729
+ max_iou, _ = iou.max(dim=1)
1730
+ if self.weight_rate is not None:
1731
+ weights = torch.where(max_iou > 0, 1 + self.weight_rate * max_iou, torch.ones_like(max_iou))
1732
+ else:
1733
+ weights = torch.ones_like(max_iou)
1734
+
1735
+ return weights
1736
+
1737
+ def to_span_label_tensors(self, data, label_map):
1738
+ if len(data) == 0:
1739
+ return (
1740
+ torch.zeros((0, 2), dtype=torch.long),
1741
+ torch.zeros((0,), dtype=torch.long)
1742
+ )
1743
+
1744
+ spans = torch.tensor([list(spans[0]) for spans, _ in data], dtype=torch.long)
1745
+ labels = torch.tensor([label_map[label] for _, label in data], dtype=torch.long)
1746
+ return spans, labels
1747
+
1748
+ def __getitem__(self, idx):
1749
+ ridx = self.using_idxes[idx]
1750
+ sample = self.all_data[ridx]
1751
+ result = self.aligner.encode(sample)
1752
+
1753
+ input_ids = result["input_ids"].squeeze(0)
1754
+ attention_mask = result["attention_mask"].squeeze(0)
1755
+ entities_gold_spans = result["entities_gold_spans"]
1756
+
1757
+ # Get all spans
1758
+ all_spans = torch.tensor(generate_spans(attention_mask, self.max_span_len))
1759
+ gold_spans = torch.tensor([spans[0] for spans, _ in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, 2, dtype=torch.long)
1760
+ gold_labels = torch.tensor([self.label2id[label] for _, label in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, dtype=torch.long)
1761
+ all_labels = match_gold_labels(
1762
+ gold_spans, # (N, 2)
1763
+ gold_labels, # (N,)
1764
+ all_spans, # (M, 2)
1765
+ default_label=0
1766
+ )
1767
+ all_weights = self.get_weights(all_spans, gold_spans)
1768
+
1769
+ # Get label
1770
+ gold_entities = []
1771
+ start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1772
+ end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1773
+ for spans, label in entities_gold_spans:
1774
+ s, e = spans[0]
1775
+
1776
+ start_labels[s] = self.label2id[f'{label}']
1777
+ end_labels[e] = self.label2id[f'{label}']
1778
+
1779
+ gold_entities.append((tuple(input_ids[s:e+1].tolist()), label))
1780
+
1781
+ input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
1782
+ attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)
1783
+
1784
+ n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
1785
+ input_ids = input_ids[:n_valid_parts]
1786
+ attention_mask = attention_mask[:n_valid_parts]
1787
+ start_labels = start_labels[:n_valid_parts*self.max_len]
1788
+ end_labels = end_labels[:n_valid_parts*self.max_len]
1789
+
1790
+ return {
1791
+ "input_ids": input_ids,
1792
+ "attention_mask": attention_mask,
1793
+ "all_spans": all_spans,
1794
+ "all_labels": all_labels,
1795
+ "all_weights": all_weights,
1796
+ "start_labels": start_labels,
1797
+ "end_labels": end_labels,
1798
+ "gold_entities": gold_entities,
1799
+ }
1800
+
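The span weighting used by `KLTNDataset.get_weights` can be checked by hand on a few spans. The sketch below is a scalar restatement, assuming inclusive token indices; `span_weight` is a hypothetical helper, and it omits the `1e-8` term that the tensor version adds to the union.

```python
def span_iou(span1, span2):
    # Intersection over union of two inclusive (start, end) token spans.
    s1, e1 = span1
    s2, e2 = span2
    intersection = max(0, min(e1, e2) - max(s1, s2) + 1)
    union = (e1 - s1 + 1) + (e2 - s2 + 1) - intersection
    return intersection / union if union else 0.0


def span_weight(span, gold_spans, weight_rate):
    # Weight 1 for spans that overlap no gold span (or when weighting is
    # disabled); otherwise 1 + weight_rate * (best IoU with any gold span).
    if not gold_spans or weight_rate is None:
        return 1.0
    max_iou = max(span_iou(span, g) for g in gold_spans)
    return 1.0 + weight_rate * max_iou if max_iou > 0 else 1.0


gold = [(3, 5)]
w_exact = span_weight((3, 5), gold, weight_rate=5)    # IoU = 1.0
w_partial = span_weight((4, 6), gold, weight_rate=5)  # IoU = 2 / 4
w_disjoint = span_weight((1, 2), gold, weight_rate=5) # IoU = 0
print(w_exact, w_partial, w_disjoint)
```

Spans that partially overlap a gold entity are the hard near-misses, so giving them a larger weight than disjoint spans focuses the loss on boundary errors.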
1801
+ def _pad_batch(tensor_list, pad_value=0):
1802
+ """
1803
+ tensor_list: list of tensors
1804
+ each tensor has shape: (Nk, n_parts_i, max_len_i)
1805
+
1806
+ return:
1807
+ padded tensor shape: (B, max_Nk, max_n_parts, max_len)
1808
+ """
1809
+
1810
+ # take the maximum of each dimension over the whole batch
1811
+ max_Nk = max(t.size(0) for t in tensor_list)
1812
+ max_n_parts = max(t.size(1) for t in tensor_list)
1813
+ max_len = max(t.size(2) for t in tensor_list)
1814
+
1815
+ padded = []
1816
+
1817
+ for t in tensor_list:
1818
+ Nk, n_parts_i, max_len_i = t.shape
1819
+
1820
+ # pad the n_parts and max_len dimensions first
1821
+ if n_parts_i < max_n_parts or max_len_i < max_len:
1822
+ new_t = t.new_full(
1823
+ (Nk, max_n_parts, max_len),
1824
+ pad_value
1825
+ )
1826
+ new_t[:, :n_parts_i, :max_len_i] = t
1827
+ t = new_t
1828
+
1829
+ # pad the Nk dimension
1830
+ if Nk < max_Nk:
1831
+ pad_tensor = t.new_full(
1832
+ (max_Nk - Nk, max_n_parts, max_len),
1833
+ pad_value
1834
+ )
1835
+ t = torch.cat([t, pad_tensor], dim=0)
1836
+
1837
+ padded.append(t)
1838
+
1839
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
1840
+
1841
+ def collate_fn(batch):
1842
+ gold_entities = []
1843
+ for bidx, b in enumerate(batch):
1844
+ for entity in b['gold_entities']:
1845
+ gold_entities.append([bidx, entity])
1846
+
1847
+ input_ids = [b["input_ids"].unsqueeze(-1) for b in batch]
1848
+ attention_mask = [b["attention_mask"].unsqueeze(-1) for b in batch]
1849
+ all_spans = [b["all_spans"].unsqueeze(-1) for b in batch]
1850
+ all_labels = [b["all_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1851
+ all_weights = [b["all_weights"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1852
+ start_labels = [b["start_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1853
+ end_labels = [b["end_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1854
+
1855
+ # pad along Nk
1856
+ input_ids = _pad_batch(input_ids, pad_value=0).squeeze(-1)
1857
+ attention_mask = _pad_batch(attention_mask, pad_value=0).squeeze(-1)
1858
+ all_spans = _pad_batch(all_spans, pad_value=0).squeeze(-1)
1859
+ all_labels = _pad_batch(all_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1860
+ all_weights = _pad_batch(all_weights, pad_value=0).squeeze(-1).squeeze(-1)
1861
+ start_labels = _pad_batch(start_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1862
+ end_labels = _pad_batch(end_labels, pad_value=-100).squeeze(-1).squeeze(-1)
1863
+
1864
+ return {
1865
+ "input_ids": input_ids,
1866
+ "attention_mask": attention_mask,
1867
+ "all_spans": all_spans,
1868
+ "all_labels": all_labels,
1869
+ "all_weights": all_weights,
1870
+ "start_labels": start_labels,
1871
+ "end_labels": end_labels,
1872
+ "gold_entities": gold_entities,
1873
+ }
1874
+
1875
+ # %% [code]
1876
+ def shift_bidx(spans, batch_idx):
1877
+ shifted = []
1878
+ for bidx, ent in spans:
1879
+ new_bidx = bidx + batch_idx * batch_size
1880
+ shifted.append((new_bidx, ent))
1881
+ return shifted
1882
+
1883
+ def refactor_entities(entities, save_dict):
1884
+ i, c = [], []
1885
+ for bidx, (ids, lb) in entities:
1886
+ if (bidx, ids) not in i:
1887
+ i.append((bidx, ids))
1888
+
1889
+ if (bidx, (ids, lb)) not in c:
1890
+ c.append((bidx, (ids, lb)))
1891
+
1892
+ save_dict['Ent-I'].extend(i)
1893
+ save_dict['Ent-C'].extend(c)
1894
+
1895
+ def test(
1896
+ network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer,
1897
+ alphas=(0.0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0), keep_neighbors=(0, 1, 2, 3)
1898
+ ):
1899
+ if torch.cuda.device_count() > 1:
1900
+ network = DataParallelProxy(network)
1901
+ network = network.to(device)
1902
+ network.eval()
1903
+
1904
+ eval_types = ['Ent-I', 'Ent-C']
1905
+
1906
+ all_pred = {(keep_neighbor, alpha): {eval_type: [] for eval_type in eval_types} for alpha in alphas for keep_neighbor in keep_neighbors}
1907
+ all_gold = {eval_type: [] for eval_type in eval_types}
1908
+
1909
+ list_input_ids = []
1910
+
1911
+ with torch.no_grad():
1912
+ for batch_idx, batch in enumerate(test_loader):
1913
+ input_ids = batch['input_ids'].to(device)
1914
+ attention_mask = batch['attention_mask'].to(device)
1915
+ all_spans = batch['all_spans'].to(device)
1916
+ gold_entities = batch['gold_entities']
1917
+
1918
+ B, _, _ = input_ids.shape
1919
+ list_input_ids.extend(input_ids.reshape(B, -1).tolist())
1920
+
1921
+ list_hidden_states = []
1922
+ list_start_logits = []
1923
+ list_end_logits = []
1924
+ for sd in state_dicts:
1925
+ if torch.cuda.device_count() > 1:
1926
+ network.module.load_state_dict(sd)
1927
+ else:
1928
+ network.load_state_dict(sd)
1929
+
1930
+ hidden_states, attention_mask = network.encode(input_ids, attention_mask)
1931
+ start_logits, end_logits = network.get_token_logits(hidden_states)
1932
+ list_hidden_states.append(hidden_states)
1933
+ list_start_logits.append(start_logits)
1934
+ list_end_logits.append(end_logits)
1935
+
1936
+ ensemble_start_logits = torch.stack(list_start_logits, dim=0).mean(dim=0)
1937
+ ensemble_end_logits = torch.stack(list_end_logits, dim=0).mean(dim=0)
1938
+
1939
+ for keep_neighbor in keep_neighbors:
1940
+ list_logits = []
1941
+ spans = filter_spans(ensemble_start_logits, ensemble_end_logits, attention_mask, network.max_span_len, network.topk_spans, keep_neighbor)
1942
+
1943
+ for sd, hidden_states in zip(state_dicts, list_hidden_states):
1944
+ if torch.cuda.device_count() > 1:
1945
+ network.module.load_state_dict(sd)
1946
+ else:
1947
+ network.load_state_dict(sd)
1948
+ span_reprs = get_span_reprs(hidden_states, spans)
1949
+ logits = network.get_logits(span_reprs)
1950
+ list_logits.append(logits)
1951
+
1952
+ ensemble_logits = torch.stack(list_logits, dim=0).mean(dim=0)
1953
+ for alpha in alphas:
1954
+ pred_entities = extract_entities(input_ids.reshape(B, -1), ensemble_start_logits, ensemble_end_logits, ensemble_logits, spans, id2label, alpha=alpha)
1955
+ pred_entities = shift_bidx(pred_entities, batch_idx)
1956
+ refactor_entities(pred_entities, all_pred[keep_neighbor, alpha])
1957
+
1958
+ gold_entities = shift_bidx(gold_entities, batch_idx)
1959
+ refactor_entities(gold_entities, all_gold)
1960
+
1961
+ # ===== GLOBAL EVAL =====
1962
+ final_score = {}
1963
+ max_key = -1
1964
+ max_scores = -1
1965
+ for key in all_pred.keys():
1966
+ final_score[key] = {}
1967
+ for eval_type in eval_types:
1968
+ score = eval_fn(list_to_tuple(all_pred[key][eval_type]), list_to_tuple(all_gold[eval_type]))
1969
+ final_score[key][eval_type] = score
1970
+ if max_scores < final_score[key]['Ent-C']['f1']:
1971
+ max_scores = final_score[key]['Ent-C']['f1']
1972
+ max_key = key
1973
+
1974
+ analyze_result = analyzer.analyze(list_to_tuple(all_pred[max_key]['Ent-I']), list_to_tuple(all_gold['Ent-I']))
1975
+
1976
+ # ===== PREDICT =====
1977
+ predictions = []
1978
+ for input_ids in list_input_ids:
1979
+ predictions.append([tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
1980
+ for bidx, (ids, lb) in all_pred[max_key]['Ent-C']:
1981
+ predictions[bidx].append((tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True), lb))
1982
+
1983
+ return final_score, analyze_result, predictions
1984
+
1985
+ # %% [code]
1986
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
1987
+ data_train = json.load(f)
1988
+
1989
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
1990
+ data_test = json.load(f)
1991
+
1992
+ print('Train:', len(data_train))
1993
+ print('Test:', len(data_test))
1994
+
1995
+ # %% [code]
1996
+ entity_types = ['O'] + sorted({e['label'] for d in data_train + data_test for e in d['entities']})
1997
+ # bio_entity_type = ['O'] + [f'{prefix}-{ent}' for ent in entity_types for prefix in ['B', 'I']]
1998
+ label2id = {l: i for i, l in enumerate(entity_types)}
1999
+ id2label = {i: l for l, i in label2id.items()}
2000
+
2001
+ # %% [code]
2002
+ zero_entities_idxes = []
2003
+ for idx, d in enumerate(data_train):
2004
+ if len(d['entities']) == 0:
2005
+ zero_entities_idxes.append(idx)
2006
+
2007
+ n_zero_entities_samples = len(zero_entities_idxes)
2008
+ n_has_entities_samples = len(data_train) - n_zero_entities_samples
2009
+
2010
+ random.seed(42)
2011
+ k = min(int(n_has_entities_samples * zero_entities_rate), len(zero_entities_idxes))
2012
+ sampled_zero_entities_idxes = random.sample(zero_entities_idxes, k)
2013
+
2014
+ new_data_train = []
2015
+ for idx, d in enumerate(data_train):
2016
+ if len(d['entities']) == 0:
2017
+ if idx in sampled_zero_entities_idxes:
2018
+ new_data_train.append(d)
2019
+ else:
2020
+ new_data_train.append(d)
2021
+ data_train = new_data_train
2022
+
2023
+ print('Train:', len(data_train))
2024
+
2025
+ # %% [code]
2026
+ if debug_only:
2027
+ data_train = data_train[:10]
2028
+ data_test = data_test[:10]
2029
+
2030
+ print('Train:', len(data_train))
2031
+ print('Test:', len(data_test))
2032
+
2033
+ # %% [code]
2034
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
2035
+
2036
+ # %% [code]
2037
+ print('Experiment name:', state_dict_save_name)
2038
+
2039
+ # %% [code]
2040
+ # trainset = KLTNDataset(data_train, np.array(range(len(data_train))), label2id, tokenizer, **train_memory_params)
2041
+ # train_loader = DataLoader(trainset, collate_fn=collate_fn, **train_loader_params)
2042
+ # for b in train_loader:
2043
+ # break
2044
+
2045
+ # %% [code]
2046
+ if not test_only:
2047
+ full_idxes = np.array(range(len(data_train)))
2048
+ training_logs, best_models, last_models = [], [], []
2049
+ start_training_time = time.time()
2050
+ for seed in SEEDS:
2051
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
2052
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
2053
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
2054
+ continue
2055
+ set_seed(seed)
2056
+
2057
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
2058
+
2059
+ trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
2060
+ valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)
2061
+
2062
+ generator = torch.Generator()
2063
+ generator.manual_seed(seed)
2064
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
2065
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2066
+
2067
+ my_model = IEModel(
2068
+ num_labels=len(label2id),
2069
+ **model_params
2070
+ )
2071
+ total_params = sum(p.numel() for p in my_model.parameters())
2072
+ print(f"Total params: {total_params:,}")
2073
+
2074
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
2075
+ encoder_params = set(map(id, my_model.encoder.parameters()))
2076
+ other_params = [
2077
+ p for p in my_model.parameters()
2078
+ if id(p) not in encoder_params
2079
+ ]
2080
+ optimizer = optim.AdamW([
2081
+ {"params": my_model.encoder.parameters(), "lr": 2e-5},
2082
+ {"params": other_params}
2083
+ ], lr=5e-4)
2084
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
2085
+
2086
+ loss_fn = CustomLoss(
2087
+ **loss_func_params
2088
+ )
2089
+ eval_fn = CustomEvalFn(**eval_func_params)
2090
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
2091
+ trainer = Trainer(**trainer_params)
2092
+
2093
+ print(f'Start Training Fold {fold_idx}...')
2094
+ training_log, best_model, last_model = trainer.fit(
2095
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
2096
+ start_epoch=1, start_training_time=start_training_time, id2label=id2label
2097
+ )
2098
+
2099
+ training_logs.append(training_log)
2100
+ best_models.append(best_model)
2101
+ last_models.append(last_model)
2102
+
2103
+ # %% [code]
2104
+ def load_all_state_dicts(folder):
2105
+ files = []
2106
+
2107
+ for file in os.listdir(folder):
2108
+ if file.endswith(".pt") or file.endswith(".pth"):
2109
+ m = re.search(r"f(\d+)", file) # find f<number>
2110
+ if m:
2111
+ fold = int(m.group(1))
2112
+ files.append((fold, file))
2113
+
2114
+ # sort by fold
2115
+ files.sort(key=lambda x: x[0])
2116
+
2117
+ state_dicts = []
2118
+ for fold, file in files:
2119
+ path = os.path.join(folder, file)
2120
+ print(f"Loading fold {fold}: {file}")
2121
+ state_dict = torch.load(path, map_location="cpu")
2122
+ state_dicts.append(state_dict)
2123
+
2124
+ return state_dicts
2125
+
2126
+ if test_only:
2127
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
2128
+ get_ipython().system('rm -rf .cache .gitattributes')
2129
+
2130
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
2131
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
2132
+
2133
+ # %% [code]
2134
+ def dict_to_df(data, row_names):
2135
+ """
2136
+ Input:
2137
+ {
2138
+ model_name: {
2139
+ (level1, ..., level(n-1)): {
2140
+ leveln: {
2141
+ metric: value
2142
+ }
2143
+ }
2144
+ }
2145
+ }
2146
+
2147
+ Output:
2148
+ - level1 -> level(n-1): regular columns
2149
+ - MultiIndex columns:
2150
+ (model_name, leveln, metric)
2151
+ """
2152
+
2153
+ rows = {}
2154
+
2155
+ for model_name, model_data in data.items():
2156
+
2157
+ for upper_levels, last_level_dict in model_data.items():
2158
+
2159
+ # ensure upper_levels is a tuple
2160
+ if not isinstance(upper_levels, tuple):
2161
+ upper_levels = (upper_levels,)
2162
+
2163
+ # ===== build row key =====
2164
+ row_key = upper_levels
2165
+
2166
+ if row_key not in rows:
2167
+
2168
+ row = {}
2169
+
2170
+ for i, lv in enumerate(upper_levels):
2171
+ row[row_names[i]] = lv
2172
+
2173
+ rows[row_key] = row
2174
+
2175
+ # ===== add metrics =====
2176
+ for last_level, metrics in last_level_dict.items():
2177
+
2178
+ for metric, value in metrics.items():
2179
+
2180
+ rows[row_key][
2181
+ (model_name, last_level, metric)
2182
+ ] = value
2183
+
2184
+ # ===== dataframe =====
2185
+ df = pd.DataFrame(rows.values())
2186
+
2187
+ # ===== split columns =====
2188
+ normal_cols = [
2189
+ c for c in df.columns
2190
+ if not isinstance(c, tuple)
2191
+ ]
2192
+
2193
+ metric_cols = [
2194
+ c for c in df.columns
2195
+ if isinstance(c, tuple)
2196
+ ]
2197
+
2198
+ df = df[normal_cols + metric_cols]
2199
+
2200
+ # ===== build multi columns =====
2201
+ df.columns = pd.MultiIndex.from_tuples([
2202
+ ("", "", c) if not isinstance(c, tuple) else c
2203
+ for c in df.columns
2204
+ ])
2205
+
2206
+ return df
2207
+
2208
+ def dict_to_records(data):
2209
+ """
2210
+ Input:
2211
+ {
2212
+ model_name: {
2213
+ (level1, level2, ..., leveln): {
2214
+ metric: value
2215
+ }
2216
+ }
2217
+ }
2218
+
2219
+ Output:
2220
+ [
2221
+ {
2222
+ "model": model_name,
2223
+ "levels": [...],
2224
+ "metrics": {...}
2225
+ },
2226
+ ...
2227
+ ]
2228
+ """
2229
+
2230
+ records = []
2231
+
2232
+ for model_name, model_data in data.items():
2233
+
2234
+ for levels, metrics in model_data.items():
2235
+
2236
+ # ensure tuple/list compatible
2237
+ if not isinstance(levels, (tuple, list)):
2238
+ levels = [levels]
2239
+
2240
+ records.append(
2241
+ {
2242
+ "model": model_name,
2243
+ "levels": list(levels),
2244
+ "metrics": metrics
2245
+ }
2246
+ )
2247
+
2248
+ return records
2249
+
2250
+ # %% [code]
2251
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
2252
+ testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
2253
+ generator = torch.Generator()
2254
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2255
+ eval_fn = CustomEvalFn(**eval_func_params)
2256
+ analyzer = SpanErrorAnalyzer()
2257
+ my_model = IEModel(
2258
+ num_labels=len(label2id),
2259
+ **model_params
2260
+ )
2261
+ total_params = sum(p.numel() for p in my_model.parameters())
2262
+ print(f"Total params: {total_params:,}")
2263
+
2264
+ # %% [code]
2265
+ start_time = time.time()
2266
+
2267
+ best_score, best_analyze_result, best_pred_test = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2268
+ last_score, last_analyze_result, last_pred_test = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2269
+
2270
+ result_test = {"Best model": best_score, "Last model": last_score}
2271
+ analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
2272
+ analyze_result_sumary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']}
2273
+ pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}
2274
+
2275
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test_{my_model.keep_neighbor}.json", "w", encoding="utf-8") as f:
2276
+ json.dump(dict_to_records(result_test), f, ensure_ascii=False, indent=2)
2277
+
2278
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result_{my_model.keep_neighbor}.json", "w", encoding="utf-8") as f:
2279
+ json.dump(analyze_result, f, ensure_ascii=False, indent=2)
2280
+
2281
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test_{my_model.keep_neighbor}.json", "w", encoding="utf-8") as f:
2282
+ json.dump(pred_test, f, ensure_ascii=False, indent=2)
2283
+
2284
+ print('Test:', time.time() - start_time, 's --> Done!')
2285
+ print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))
2286
+
2287
+ # %% [code]
2288
+ result_test_df = dict_to_df(result_test, row_names=['keep_neighbor', 'alpha'])
2289
+
2290
+ col = ("Best model", "Ent-C", "f1")
2291
+ sorted_df = result_test_df.sort_values(
2292
+ by=col,
2293
+ ascending=False
2294
+ ).reset_index(drop=True)
2295
+ sorted_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_.xlsx")
2296
+
2297
+ sorted_df
2298
+
2299
+ # %% [code]
2300
+ def get_avg_best_score(logs):
2301
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
2302
+
2303
+ def get_avg_log(logs, epochs):
2304
+ avg_log = {}
2305
+
2306
+ for epoch in range(1, epochs + 1):
2307
+ val_score = 0.0
2308
+ train_loss = 0.0
2309
+ n_eval = 0
2310
+
2311
+ for idx in range(len(logs)):
2312
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
2313
+ if log is None:
2314
+ continue
2315
+
2316
+ val_score += log.get('val_score', 0.0)
2317
+ train_loss += log.get('train_loss', 0.0)
2318
+ n_eval += 1
2319
+
2320
+ if n_eval == 0:
2321
+ continue
2322
+
2323
+ avg_log[epoch] = {
2324
+ 'train_loss': train_loss / n_eval,
2325
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
2326
+ }
2327
+
2328
+ return avg_log
2329
+
2330
+ def parse_label_key(label: str):
2331
+ try:
2332
+ first = float(label.split('_', 1)[0]) # first number: the part before the first underscore
2333
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
2334
+ return first, last
2335
+ except Exception:
2336
+ return (0, 0)
2337
+
2338
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
2339
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
2340
+
2341
+ # ===== Plot Train Loss =====
2342
+ for name, log in logs_dict.items():
2343
+ epochs = sorted(log.keys())
2344
+ train_loss = [log[e]['train_loss'] for e in epochs]
2345
+ axes[0].plot(epochs, train_loss, label=name)
2346
+
2347
+ axes[0].set_xlabel('Epoch')
2348
+ axes[0].set_ylabel('Train Loss')
2349
+ axes[0].set_title('Training Loss')
2350
+ axes[0].grid(True)
2351
+
2352
+ # ===== Plot Validation Score =====
2353
+ for name, log in logs_dict.items():
2354
+ epochs = sorted(log.keys())
2355
+ val_score = [log[e]['val_score'] for e in epochs]
2356
+ axes[1].plot(epochs, val_score, label=name)
2357
+
2358
+ axes[1].set_xlabel('Epoch')
2359
+ axes[1].set_ylabel('Validation Score')
2360
+ axes[1].set_title('Validation Score')
2361
+ axes[1].grid(True)
2362
+
2363
+ # ===== Shared Legend =====
2364
+ handles, labels = axes[0].get_legend_handles_labels()
2365
+ pairs = list(zip(handles, labels))
2366
+ pairs_sorted = sorted(
2367
+ pairs,
2368
+ key=lambda x: parse_label_key(x[1])
2369
+ )
2370
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
2371
+
2372
+ axes[0].legend(
2373
+ handles_sorted,
2374
+ labels_sorted,
2375
+ loc='center left',
2376
+ bbox_to_anchor=(1.01, 0.5),
2377
+ borderaxespad=0.
2378
+ )
2379
+
2380
+ plt.tight_layout(rect=[0, 0, 1, 1])
2381
+
2382
+ if save_path is not None:
2383
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
2384
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
2385
+
2386
+ plt.show()
2387
+
2388
+ # %% [code]
2389
+ if not test_only:
2390
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*ent*.json"], ignore_patterns=["**/*crf*.json"])
2391
+ get_ipython().system('rm -rf .cache .gitattributes')
2392
+
2393
+ # %% [code]
2394
+ if not test_only:
2395
+ experiments = {}
2396
+ for experiment in os.listdir(pretrained_dir):
2397
+ if '.virtual_documents' in experiment:
2398
+ continue
2399
+ experiment_logs = []
2400
+ try:
2401
+ for seed in SEEDS:
2402
+ for fold_idx in range(nfolds):
2403
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
2404
+ experiment_log = json.load(f)
2405
+ experiment_logs.append(experiment_log)
2406
+ except Exception:
2407
+ pass
2408
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
2409
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
2410
+
2411
+ # %% [code]
2412
+ if not test_only:
2413
+ score = get_avg_best_score(training_logs)
2414
+ state_dict_save_name, score
2415
+
2416
+ # %% [code]
2417
+ if not test_only:
2418
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
2419
+
20_entities_top_50_25/lasts/20_entities_top_50_25_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4effa0c1eb04620dd7198437271b736dbf113d8f8beb330bf325f3a1ef1d9fbc
3
+ size 554299549
20_entities_top_50_25/logs/20_entities_top_50_25_log_plot.jpg ADDED

Git LFS Details

  • SHA256: 4940477d7c2566023524e891c97a4fd7d0852e908c6241733912c2086d051988
  • Pointer size: 131 Bytes
  • Size of remote file: 563 kB
20_entities_top_50_25/logs/20_entities_top_50_25_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 0.15836021304130554, "total": 0.15836021632305933, "span_loss": 0.013298947295374947, "start_loss": 0.06792396643496279, "end_loss": 0.0771371789189142}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 0.1167537197470665, "total": 0.116753720517529, "span_loss": 0.00795133610955159, "start_loss": 0.05067283071706653, "end_loss": 0.05812967469933447}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 0.1098012775182724, "total": 0.10980128229097523, "span_loss": 0.007565838455021082, "start_loss": 0.04854550476244803, "end_loss": 0.05368998971579793}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 0.10013656318187714, "total": 0.10013656658857079, "span_loss": 0.006966773726674949, "start_loss": 0.04450445884162984, "end_loss": 0.04866513571571278}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 0.09494887292385101, "total": 0.09494887683030412, "span_loss": 0.006612358178553999, "start_loss": 0.04236734133315793, "end_loss": 0.04596924288703317}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 0.09057747572660446, "total": 0.09057747684689858, "span_loss": 0.006500215887556934, "start_loss": 0.040513803600398704, "end_loss": 0.043563458958173214}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 0.08300333470106125, "total": 0.08300333465524996, "span_loss": 0.006016988738280158, "start_loss": 0.03738455775065286, "end_loss": 0.03960186279706278}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 0.07877770066261292, "total": 0.0787776991799932, "span_loss": 0.0058678082742792325, "start_loss": 0.03547269715617794, "end_loss": 0.037437170294158745, "val_score": 1.6879098478651393, "best_score": 1.6879098478651393, "new_best_model": true, "precision": 0.6927880624134277, "recall": 0.7066723232711557, "f1": 0.698129392068663, 
"span_recall": 0.9897804557964766}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 0.07042290270328522, "total": 0.07042290659307968, "span_loss": 0.005389453446551899, "start_loss": 0.03172466646965287, "end_loss": 0.033308669933065346, "val_score": 1.6880608116033966, "best_score": 1.6880608116033966, "new_best_model": true, "precision": 0.6926097764085044, "recall": 0.7085987154401036, "f1": 0.6989633038308182, "span_recall": 0.9890975077725791}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 0.06412826478481293, "total": 0.06412826802492008, "span_loss": 0.005015367366136153, "start_loss": 0.02899997043769542, "end_loss": 0.030112998454913272, "val_score": 1.6844831534157307, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6875223989713674, "recall": 0.7089024167935534, "f1": 0.6965351773006304, "span_recall": 0.9879479761150991}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 0.05869597569108009, "total": 0.05869597706333628, "span_loss": 0.004821416450786218, "start_loss": 0.026491579894848946, "end_loss": 0.027382992978466497, "val_score": 1.6814694558381817, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6850706713058013, "recall": 0.7074393355117581, "f1": 0.6945111173782986, "span_recall": 0.9869583384598831}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 0.052141811698675156, "total": 0.05214181114269276, "span_loss": 0.004431201333503872, "start_loss": 0.023577266328490734, "end_loss": 0.024133408249524, "val_score": 1.6788781985189072, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6824626256370582, "recall": 0.7058162357753008, "f1": 0.6924016210918472, "span_recall": 0.9864765774270605}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 0.04715990647673607, "total": 0.04715990607068151, "span_loss": 0.004110376017929256, "start_loss": 
0.021314065505699982, "end_loss": 0.021735484004353864, "val_score": 1.6761510530355455, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6787482509223736, "recall": 0.7043477681110263, "f1": 0.6897907039476386, "span_recall": 0.9863603490879074}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 0.043062418699264526, "total": 0.043062417982942636, "span_loss": 0.0039032074644153782, "start_loss": 0.019475716195871472, "end_loss": 0.019683522309185474, "val_score": 1.6738327289989086, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6776761968384113, "recall": 0.7027311676897512, "f1": 0.6884340121848059, "span_recall": 0.985398716814104}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 0.039497774094343185, "total": 0.03949777636200173, "span_loss": 0.0037452027143347255, "start_loss": 0.017812720352607843, "end_loss": 0.01793987088659211, "val_score": 1.6726570404456254, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6782996836928802, "recall": 0.7014953495122939, "f1": 0.6881792029631572, "span_recall": 0.9844778374824673}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 0.036539915949106216, "total": 0.03653991682160294, "span_loss": 0.0035204935367044755, "start_loss": 0.016551774197147172, "end_loss": 0.016467658949671275, "val_score": 1.6698516693773606, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6760791197776373, "recall": 0.6998378552184006, "f1": 0.6861856178380178, "span_recall": 0.9836660515393411}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 0.03419886529445648, "total": 0.03419886546937229, "span_loss": 0.0034353670426225317, "start_loss": 0.015404574893718995, "end_loss": 0.015358882752658936, "val_score": 1.669118306472923, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6767968954388489, "recall": 
0.699000197096468, "f1": 0.6862313053001838, "span_recall": 0.9828870011727384}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 0.03241745010018349, "total": 0.03241744875707994, "span_loss": 0.003389500052386519, "start_loss": 0.0145352450477937, "end_loss": 0.014492702324207832, "val_score": 1.6678282196605587, "best_score": 1.6880608116033966, "new_best_model": false, "precision": 0.6749447392504405, "recall": 0.6987369641707758, "f1": 0.6851534395678112, "span_recall": 0.9826747800927468}}
20_entities_top_50_25/r1s/20_entities_top_50_25_s26092004_f0_r1_vs1.68806_ema.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60048325b4ca09fe19685dc199077bc00fc20ae249b7f5df9d9f87866f5a3da9
3
+ size 554301285
20_entities_top_50_25/results/20_entities_top_50_25_error_analyze_result_2.json ADDED
The diff for this file is too large to render. See raw diff
 
20_entities_top_50_25/results/20_entities_top_50_25_pred_test_2.json ADDED
The diff for this file is too large to render. See raw diff
 
20_entities_top_50_25/results/20_entities_top_50_25_test_2.json ADDED
@@ -0,0 +1,1066 @@
1
+ [
2
+ {
3
+ "model": "Best model",
4
+ "levels": [
5
+ 0,
6
+ 0.0
7
+ ],
8
+ "metrics": {
9
+ "Ent-I": {
10
+ "precision": 0.7274748732538223,
11
+ "recall": 0.7602007621519098,
12
+ "f1": 0.7434778606503211
13
+ },
14
+ "Ent-C": {
15
+ "precision": 0.6501823356749531,
16
+ "recall": 0.6788002600049412,
17
+ "f1": 0.6641831678165434
18
+ }
19
+ }
20
+ },
21
+ {
22
+ "model": "Best model",
23
+ "levels": [
24
+ 1,
25
+ 0.0
26
+ ],
27
+ "metrics": {
28
+ "Ent-I": {
29
+ "precision": 0.7274748732538223,
30
+ "recall": 0.7602007621519098,
31
+ "f1": 0.7434778606503211
32
+ },
33
+ "Ent-C": {
34
+ "precision": 0.6501823356749531,
35
+ "recall": 0.6788002600049412,
36
+ "f1": 0.6641831678165434
37
+ }
38
+ }
39
+ },
40
+ {
41
+ "model": "Best model",
42
+ "levels": [
43
+ 2,
44
+ 0.0
45
+ ],
46
+ "metrics": {
47
+ "Ent-I": {
48
+ "precision": 0.7274748732538223,
49
+ "recall": 0.7602007621519098,
50
+ "f1": 0.7434778606503211
51
+ },
52
+ "Ent-C": {
53
+ "precision": 0.6501823356749531,
54
+ "recall": 0.6788002600049412,
55
+ "f1": 0.6641831678165434
56
+ }
57
+ }
58
+ },
59
+ {
60
+ "model": "Best model",
61
+ "levels": [
62
+ 3,
63
+ 0.0
64
+ ],
65
+ "metrics": {
66
+ "Ent-I": {
67
+ "precision": 0.7274748732538223,
68
+ "recall": 0.7602007621519098,
69
+ "f1": 0.7434778606503211
70
+ },
71
+ "Ent-C": {
72
+ "precision": 0.6501823356749531,
73
+ "recall": 0.6788002600049412,
74
+ "f1": 0.6641831678165434
75
+ }
76
+ }
77
+ },
78
+ {
79
+ "model": "Best model",
80
+ "levels": [
81
+ 0,
82
+ 0.2
83
+ ],
84
+ "metrics": {
85
+ "Ent-I": {
86
+ "precision": 0.707413647851131,
87
+ "recall": 0.7804628682955846,
88
+ "f1": 0.7421450301433067
89
+ },
90
+ "Ent-C": {
91
+ "precision": 0.6323504633524579,
92
+ "recall": 0.6970006500132816,
93
+ "f1": 0.6631034889818218
94
+ }
95
+ }
96
+ },
97
+ {
98
+ "model": "Best model",
99
+ "levels": [
100
+ 1,
101
+ 0.2
102
+ ],
103
+ "metrics": {
104
+ "Ent-I": {
105
+ "precision": 0.707413647851131,
106
+ "recall": 0.7804628682955846,
107
+ "f1": 0.7421450301433067
108
+ },
109
+ "Ent-C": {
110
+ "precision": 0.6323504633524579,
111
+ "recall": 0.6970006500132816,
112
+ "f1": 0.6631034889818218
113
+ }
114
+ }
115
+ },
116
+ {
117
+ "model": "Best model",
118
+ "levels": [
119
+ 2,
120
+ 0.2
121
+ ],
122
+ "metrics": {
123
+ "Ent-I": {
124
+ "precision": 0.707413647851131,
125
+ "recall": 0.7804628682955846,
126
+ "f1": 0.7421450301433067
127
+ },
128
+ "Ent-C": {
129
+ "precision": 0.6323504633524579,
130
+ "recall": 0.6970006500132816,
131
+ "f1": 0.6631034889818218
132
+ }
133
+ }
134
+ },
135
+ {
136
+ "model": "Best model",
137
+ "levels": [
138
+ 3,
139
+ 0.2
140
+ ],
141
+ "metrics": {
142
+ "Ent-I": {
143
+ "precision": 0.707413647851131,
144
+ "recall": 0.7804628682955846,
145
+ "f1": 0.7421450301433067
146
+ },
147
+ "Ent-C": {
148
+ "precision": 0.6323504633524579,
149
+ "recall": 0.6970006500132816,
150
+ "f1": 0.6631034889818218
151
+ }
152
+ }
153
+ },
154
+ {
155
+ "model": "Best model",
156
+ "levels": [
157
+ 0,
158
+ 0.4
159
+ ],
160
+ "metrics": {
161
+ "Ent-I": {
162
+ "precision": 0.6903215329646879,
163
+ "recall": 0.7902221396033179,
164
+ "f1": 0.7369014034728304
165
+ },
166
+ "Ent-C": {
167
+ "precision": 0.6169210782716653,
168
+ "recall": 0.7055436902212782,
169
+ "f1": 0.6582629363257605
170
+ }
171
+ }
172
+ },
173
+ {
174
+ "model": "Best model",
175
+ "levels": [
176
+ 1,
177
+ 0.4
178
+ ],
179
+ "metrics": {
180
+ "Ent-I": {
181
+ "precision": 0.6903215329646879,
182
+ "recall": 0.7902221396033179,
183
+ "f1": 0.7369014034728304
184
+ },
185
+ "Ent-C": {
186
+ "precision": 0.6169210782716653,
187
+ "recall": 0.7055436902212782,
188
+ "f1": 0.6582629363257605
189
+ }
190
+ }
191
+ },
192
+ {
193
+ "model": "Best model",
194
+ "levels": [
195
+ 2,
196
+ 0.4
197
+ ],
198
+ "metrics": {
199
+ "Ent-I": {
200
+ "precision": 0.6903215329646879,
201
+ "recall": 0.7902221396033179,
202
+ "f1": 0.7369014034728304
203
+ },
204
+ "Ent-C": {
205
+ "precision": 0.6169210782716653,
206
+ "recall": 0.7055436902212782,
207
+ "f1": 0.6582629363257605
208
+ }
209
+ }
210
+ },
211
+ {
212
+ "model": "Best model",
213
+ "levels": [
214
+ 3,
215
+ 0.4
216
+ ],
217
+ "metrics": {
218
+ "Ent-I": {
219
+ "precision": 0.6903215329646879,
220
+ "recall": 0.7902221396033179,
221
+ "f1": 0.7369014034728304
222
+ },
223
+ "Ent-C": {
224
+ "precision": 0.6169210782716653,
225
+ "recall": 0.7055436902212782,
226
+ "f1": 0.6582629363257605
227
+ }
228
+ }
229
+ },
230
+ {
231
+ "model": "Best model",
232
+ "levels": [
233
+ 0,
234
+ 0.5
235
+ ],
236
+ "metrics": {
237
+ "Ent-I": {
238
+ "precision": 0.6829639113381748,
239
+ "recall": 0.7932893391571769,
240
+ "f1": 0.7340041229942705
241
+ },
242
+ "Ent-C": {
243
+ "precision": 0.6105465311669916,
244
+ "recall": 0.7085151824675378,
245
+ "f1": 0.6558927140161004
246
+ }
247
+ }
248
+ },
249
+ {
250
+ "model": "Best model",
251
+ "levels": [
252
+ 1,
253
+ 0.5
254
+ ],
255
+ "metrics": {
256
+ "Ent-I": {
257
+ "precision": 0.6829092654819308,
258
+ "recall": 0.7932893391571769,
259
+ "f1": 0.7339725624263866
260
+ },
261
+ "Ent-C": {
262
+ "precision": 0.6104976796282521,
263
+ "recall": 0.7085151824675378,
264
+ "f1": 0.6558645241885204
265
+ }
266
+ }
267
+ },
268
+ {
269
+ "model": "Best model",
270
+ "levels": [
271
+ 2,
272
+ 0.5
273
+ ],
274
+ "metrics": {
275
+ "Ent-I": {
276
+ "precision": 0.6829092654819308,
277
+ "recall": 0.7932893391571769,
278
+ "f1": 0.7339725624263866
279
+ },
280
+ "Ent-C": {
281
+ "precision": 0.6104976796282521,
282
+ "recall": 0.7085151824675378,
283
+ "f1": 0.6558645241885204
284
+ }
285
+ }
286
+ },
287
+ {
288
+ "model": "Best model",
289
+ "levels": [
290
+ 3,
291
+ 0.5
292
+ ],
293
+ "metrics": {
294
+ "Ent-I": {
295
+ "precision": 0.6829092654819308,
296
+ "recall": 0.7932893391571769,
297
+ "f1": 0.7339725624263866
298
+ },
299
+ "Ent-C": {
300
+ "precision": 0.6104976796282521,
301
+ "recall": 0.7085151824675378,
302
+ "f1": 0.6558645241885204
303
+ }
304
+ }
305
+ },
306
+ {
307
+ "model": "Best model",
308
+ "levels": [
309
+ 0,
310
+ 0.6
311
+ ],
312
+ "metrics": {
313
+ "Ent-I": {
314
+ "precision": 0.6757482429118883,
315
+ "recall": 0.7953341388597496,
316
+ "f1": 0.7306805517739423
317
+ },
318
+ "Ent-C": {
319
+ "precision": 0.6041222459128136,
320
+ "recall": 0.7103723651214501,
321
+ "f1": 0.6529532213892284
322
+ }
323
+ }
324
+ },
325
+ {
326
+ "model": "Best model",
327
+ "levels": [
328
+ 1,
329
+ 0.6
330
+ ],
331
+ "metrics": {
332
+ "Ent-I": {
333
+ "precision": 0.6756948831327576,
334
+ "recall": 0.7953341388597496,
335
+ "f1": 0.7306493567709114
336
+ },
337
+ "Ent-C": {
338
+ "precision": 0.604074542008367,
339
+ "recall": 0.7103723651214501,
340
+ "f1": 0.6529253567015821
341
+ }
342
+ }
343
+ },
344
+ {
345
+ "model": "Best model",
346
+ "levels": [
347
+ 2,
348
+ 0.6
349
+ ],
350
+ "metrics": {
351
+ "Ent-I": {
352
+ "precision": 0.6756948831327576,
353
+ "recall": 0.7953341388597496,
354
+ "f1": 0.7306493567709114
355
+ },
356
+ "Ent-C": {
357
+ "precision": 0.604074542008367,
358
+ "recall": 0.7103723651214501,
359
+ "f1": 0.6529253567015821
360
+ }
361
+ }
362
+ },
363
+ {
364
+ "model": "Best model",
365
+ "levels": [
366
+ 3,
367
+ 0.6
368
+ ],
369
+ "metrics": {
370
+ "Ent-I": {
371
+ "precision": 0.6756948831327576,
372
+ "recall": 0.7953341388597496,
373
+ "f1": 0.7306493567709114
374
+ },
375
+ "Ent-C": {
376
+ "precision": 0.604074542008367,
377
+ "recall": 0.7103723651214501,
378
+ "f1": 0.6529253567015821
379
+ }
380
+ }
381
+ },
382
+ {
383
+ "model": "Best model",
384
+ "levels": [
385
+ 0,
386
+ 0.8
387
+ ],
388
+ "metrics": {
389
+ "Ent-I": {
390
+ "precision": 0.662070559235355,
391
+ "recall": 0.7988660656187387,
392
+ "f1": 0.7240638508201541
393
+ },
394
+ "Ent-C": {
395
+ "precision": 0.5918194423042737,
396
+ "recall": 0.7134367165004053,
397
+ "f1": 0.6469622282119929
398
+ }
399
+ }
400
+ },
401
+ {
402
+ "model": "Best model",
403
+ "levels": [
404
+ 1,
405
+ 0.8
406
+ ],
407
+ "metrics": {
408
+ "Ent-I": {
409
+ "precision": 0.6619685767092868,
410
+ "recall": 0.7988660656187387,
411
+ "f1": 0.7240028590453211
412
+ },
413
+ "Ent-C": {
414
+ "precision": 0.5917282809607273,
415
+ "recall": 0.7134367165004053,
416
+ "f1": 0.6469077540628477
417
+ }
418
+ }
419
+ },
420
+ {
421
+ "model": "Best model",
422
+ "levels": [
423
+ 2,
424
+ 0.8
425
+ ],
426
+ "metrics": {
427
+ "Ent-I": {
428
+ "precision": 0.6619175972270605,
429
+ "recall": 0.7988660656187387,
430
+ "f1": 0.7239723670109903
431
+ },
432
+ "Ent-C": {
433
+ "precision": 0.5916827108197215,
434
+ "recall": 0.7134367165004053,
435
+ "f1": 0.6468805204281683
436
+ }
437
+ }
438
+ },
439
+ {
440
+ "model": "Best model",
441
+ "levels": [
442
+ 3,
443
+ 0.8
444
+ ],
445
+ "metrics": {
446
+ "Ent-I": {
447
+ "precision": 0.6619175972270605,
448
+ "recall": 0.7988660656187387,
449
+ "f1": 0.7239723670109903
450
+ },
451
+ "Ent-C": {
452
+ "precision": 0.5916827108197215,
453
+ "recall": 0.7134367165004053,
454
+ "f1": 0.6468805204281683
455
+ }
456
+ }
457
+ },
458
+ {
459
+ "model": "Best model",
460
+ "levels": [
461
+ 0,
462
+ 1.0
463
+ ],
464
+ "metrics": {
465
+ "Ent-I": {
466
+ "precision": 0.6443930117958455,
467
+ "recall": 0.8022121014956759,
468
+ "f1": 0.714693821912725
469
+ },
470
+ "Ent-C": {
471
+ "precision": 0.5754815589065436,
472
+ "recall": 0.7157581948177957,
473
+ "f1": 0.6380002433720193
474
+ }
475
+ }
476
+ },
477
+ {
478
+ "model": "Best model",
479
+ "levels": [
480
+ 1,
481
+ 1.0
482
+ ],
483
+ "metrics": {
484
+ "Ent-I": {
485
+ "precision": 0.6443233559747373,
486
+ "recall": 0.8023050469367019,
487
+ "f1": 0.7146878572878846
488
+ },
489
+ "Ent-C": {
490
+ "precision": 0.5754273344774387,
491
+ "recall": 0.7158510539504913,
492
+ "f1": 0.6380038020601749
493
+ }
494
+ }
495
+ },
496
+ {
497
+ "model": "Best model",
498
+ "levels": [
499
+ 2,
500
+ 1.0
501
+ ],
502
+ "metrics": {
503
+ "Ent-I": {
504
+ "precision": 0.6442752649644392,
505
+ "recall": 0.8023050469367019,
506
+ "f1": 0.7146582721630242
507
+ },
508
+ "Ent-C": {
509
+ "precision": 0.5753843857287838,
510
+ "recall": 0.7158510539504913,
511
+ "f1": 0.6379774022668191
512
+ }
513
+ }
514
+ },
515
+ {
516
+ "model": "Best model",
517
+ "levels": [
518
+ 3,
519
+ 1.0
520
+ ],
521
+ "metrics": {
522
+ "Ent-I": {
523
+ "precision": 0.6443018135676958,
524
+ "recall": 0.802397992377728,
525
+ "f1": 0.7147114777977568
526
+ },
527
+ "Ent-C": {
528
+ "precision": 0.5754160758261248,
529
+ "recall": 0.715943913083187,
530
+ "f1": 0.6380337587131018
531
+ }
532
+ }
533
+ },
534
+ {
535
+ "model": "Last model",
536
+ "levels": [
537
+ 0,
538
+ 0.0
539
+ ],
540
+ "metrics": {
541
+ "Ent-I": {
542
+ "precision": 0.7241286863264306,
543
+ "recall": 0.7531369086339315,
544
+ "f1": 0.7383479835200943
545
+ },
546
+ "Ent-C": {
547
+ "precision": 0.6447145027248282,
548
+ "recall": 0.6699786423988578,
549
+ "f1": 0.6571038201378601
550
+ }
551
+ }
552
+ },
553
+ {
554
+ "model": "Last model",
555
+ "levels": [
556
+ 1,
557
+ 0.0
558
+ ],
559
+ "metrics": {
560
+ "Ent-I": {
561
+ "precision": 0.7241286863264306,
562
+ "recall": 0.7531369086339315,
563
+ "f1": 0.7383479835200943
564
+ },
565
+ "Ent-C": {
566
+ "precision": 0.6447145027248282,
567
+ "recall": 0.6699786423988578,
568
+ "f1": 0.6571038201378601
569
+ }
570
+ }
571
+ },
572
+ {
573
+ "model": "Last model",
574
+ "levels": [
575
+ 2,
576
+ 0.0
577
+ ],
578
+ "metrics": {
579
+ "Ent-I": {
580
+ "precision": 0.7240639799832685,
581
+ "recall": 0.7531369086339315,
582
+ "f1": 0.7383143457985305
583
+ },
584
+ "Ent-C": {
585
+ "precision": 0.6446568977835555,
586
+ "recall": 0.6699786423988578,
587
+ "f1": 0.657073898739702
588
+ }
589
+ }
590
+ },
591
+ {
592
+ "model": "Last model",
593
+ "levels": [
594
+ 3,
595
+ 0.0
596
+ ],
597
+ "metrics": {
598
+ "Ent-I": {
599
+ "precision": 0.72399928520307,
600
+ "recall": 0.7531369086339315,
601
+ "f1": 0.7382807111417676
602
+ },
603
+ "Ent-C": {
604
+ "precision": 0.6445993031353126,
605
+ "recall": 0.6699786423988578,
606
+ "f1": 0.6570439800663778
607
+ }
608
+ }
609
+ },
610
+ {
611
+ "model": "Last model",
612
+ "levels": [
613
+ 0,
614
+ 0.2
615
+ ],
616
+ "metrics": {
617
+ "Ent-I": {
618
+ "precision": 0.6997393424697723,
619
+ "recall": 0.7734919602186323,
620
+ "f1": 0.734769551783911
621
+ },
622
+ "Ent-C": {
623
+ "precision": 0.6235076509159042,
624
+ "recall": 0.6886433280706763,
625
+ "f1": 0.6544588045250469
626
+ }
627
+ }
628
+ },
629
+ {
630
+ "model": "Last model",
631
+ "levels": [
632
+ 1,
633
+ 0.2
634
+ ],
635
+ "metrics": {
636
+ "Ent-I": {
637
+ "precision": 0.6996805111815203,
638
+ "recall": 0.7734919602186323,
639
+ "f1": 0.7347371159230409
640
+ },
641
+ "Ent-C": {
642
+ "precision": 0.6234552332907747,
643
+ "recall": 0.6886433280706763,
644
+ "f1": 0.6544299279450505
645
+ }
646
+ }
647
+ },
648
+ {
649
+ "model": "Last model",
650
+ "levels": [
651
+ 2,
652
+ 0.2
653
+ ],
654
+ "metrics": {
655
+ "Ent-I": {
656
+ "precision": 0.6996805111815203,
657
+ "recall": 0.7734919602186323,
658
+ "f1": 0.7347371159230409
659
+ },
660
+ "Ent-C": {
661
+ "precision": 0.6234552332907747,
662
+ "recall": 0.6886433280706763,
663
+ "f1": 0.6544299279450505
664
+ }
665
+ }
666
+ },
667
+ {
668
+ "model": "Last model",
669
+ "levels": [
670
+ 3,
671
+ 0.2
672
+ ],
673
+ "metrics": {
674
+ "Ent-I": {
675
+ "precision": 0.699621689785036,
676
+ "recall": 0.7734919602186323,
677
+ "f1": 0.7347046829257585
678
+ },
679
+ "Ent-C": {
680
+ "precision": 0.6234028244782923,
681
+ "recall": 0.6886433280706763,
682
+ "f1": 0.6544010539131748
683
+ }
684
+ }
685
+ },
686
+ {
687
+ "model": "Last model",
688
+ "levels": [
689
+ 0,
690
+ 0.4
691
+ ],
692
+ "metrics": {
693
+ "Ent-I": {
694
+ "precision": 0.6784908094157702,
695
+ "recall": 0.7822288316750792,
696
+ "f1": 0.7266761595974007
697
+ },
698
+ "Ent-C": {
699
+ "precision": 0.604272470777425,
700
+ "recall": 0.6960720586863255,
701
+ "f1": 0.646931901471241
702
+ }
703
+ }
704
+ },
705
+ {
706
+ "model": "Last model",
707
+ "levels": [
708
+ 1,
709
+ 0.4
710
+ ],
711
+ "metrics": {
712
+ "Ent-I": {
713
+ "precision": 0.678272082526855,
714
+ "recall": 0.7822288316750792,
715
+ "f1": 0.7265506921369767
716
+ },
717
+ "Ent-C": {
718
+ "precision": 0.604077685550323,
719
+ "recall": 0.6960720586863255,
720
+ "f1": 0.6468202556164152
721
+ }
722
+ }
723
+ },
724
+ {
725
+ "model": "Last model",
726
+ "levels": [
727
+ 2,
728
+ 0.4
729
+ ],
730
+ "metrics": {
731
+ "Ent-I": {
732
+ "precision": 0.678272082526855,
733
+ "recall": 0.7822288316750792,
734
+ "f1": 0.7265506921369767
735
+ },
736
+ "Ent-C": {
737
+ "precision": 0.604077685550323,
738
+ "recall": 0.6960720586863255,
739
+ "f1": 0.6468202556164152
740
+ }
741
+ }
742
+ },
743
+ {
744
+ "model": "Last model",
745
+ "levels": [
746
+ 3,
747
+ 0.4
748
+ ],
749
+ "metrics": {
750
+ "Ent-I": {
751
+ "precision": 0.678272082526855,
752
+ "recall": 0.7822288316750792,
753
+ "f1": 0.7265506921369767
754
+ },
755
+ "Ent-C": {
756
+ "precision": 0.604077685550323,
757
+ "recall": 0.6960720586863255,
758
+ "f1": 0.6468202556164152
759
+ }
760
+ }
761
+ },
762
+ {
763
+ "model": "Last model",
764
+ "levels": [
765
+ 0,
766
+ 0.5
767
+ ],
768
+ "metrics": {
769
+ "Ent-I": {
770
+ "precision": 0.669253920481016,
771
+ "recall": 0.7853889766699642,
772
+ "f1": 0.7226854771779272
773
+ },
774
+ "Ent-C": {
775
+ "precision": 0.5957868060500547,
776
+ "recall": 0.698579255269107,
777
+ "f1": 0.6431013798830957
778
+ }
779
+ }
780
+ },
781
+ {
782
+ "model": "Last model",
783
+ "levels": [
784
+ 1,
785
+ 0.5
786
+ ],
787
+ "metrics": {
788
+ "Ent-I": {
789
+ "precision": 0.6689889953284229,
790
+ "recall": 0.7853889766699642,
791
+ "f1": 0.7225309911836113
792
+ },
793
+ "Ent-C": {
794
+ "precision": 0.5955509816334741,
795
+ "recall": 0.698579255269107,
796
+ "f1": 0.6429639709296068
797
+ }
798
+ }
799
+ },
800
+ {
801
+ "model": "Last model",
802
+ "levels": [
803
+ 2,
804
+ 0.5
805
+ ],
806
+ "metrics": {
807
+ "Ent-I": {
808
+ "precision": 0.6690151994928205,
809
+ "recall": 0.7854819221109903,
810
+ "f1": 0.7225856048813833
811
+ },
812
+ "Ent-C": {
813
+ "precision": 0.5955829969123758,
814
+ "recall": 0.6986721144018027,
815
+ "f1": 0.6430219589658785
816
+ }
817
+ }
818
+ },
819
+ {
820
+ "model": "Last model",
821
+ "levels": [
822
+ 3,
823
+ 0.5
824
+ ],
825
+ "metrics": {
826
+ "Ent-I": {
827
+ "precision": 0.6689622417472738,
828
+ "recall": 0.7854819221109903,
829
+ "f1": 0.7225547145937161
830
+ },
831
+ "Ent-C": {
832
+ "precision": 0.5955358556271999,
833
+ "recall": 0.6986721144018027,
834
+ "f1": 0.6429944829173719
835
+ }
836
+ }
837
+ },
838
+ {
839
+ "model": "Last model",
840
+ "levels": [
841
+ 0,
842
+ 0.6
843
+ ],
844
+ "metrics": {
845
+ "Ent-I": {
846
+ "precision": 0.6573486139068783,
847
+ "recall": 0.7890138488699795,
848
+ "f1": 0.7171883530858584
849
+ },
850
+ "Ent-C": {
851
+ "precision": 0.5843592721636978,
852
+ "recall": 0.7008078744538018,
853
+ "f1": 0.637307882221732
854
+ }
855
+ }
856
+ },
857
+ {
858
+ "model": "Last model",
859
+ "levels": [
860
+ 1,
861
+ 0.6
862
+ ],
863
+ "metrics": {
864
+ "Ent-I": {
865
+ "precision": 0.6570942023371336,
866
+ "recall": 0.7890138488699795,
867
+ "f1": 0.7170369069423463
868
+ },
869
+ "Ent-C": {
870
+ "precision": 0.5841331269345323,
871
+ "recall": 0.7008078744538018,
872
+ "f1": 0.6371733666477617
873
+ }
874
+ }
875
+ },
876
+ {
877
+ "model": "Last model",
878
+ "levels": [
879
+ 2,
880
+ 0.6
881
+ ],
882
+ "metrics": {
883
+ "Ent-I": {
884
+ "precision": 0.6570698862312072,
885
+ "recall": 0.7891067943110055,
886
+ "f1": 0.7170608058518843
887
+ },
888
+ "Ent-C": {
889
+ "precision": 0.5841201052464138,
890
+ "recall": 0.7009007335864974,
891
+ "f1": 0.6372039965603209
892
+ }
893
+ }
894
+ },
895
+ {
896
+ "model": "Last model",
897
+ "levels": [
898
+ 3,
899
+ 0.6
900
+ ],
901
+ "metrics": {
902
+ "Ent-I": {
903
+ "precision": 0.6570190373002189,
904
+ "recall": 0.7891067943110055,
905
+ "f1": 0.7170305258463955
906
+ },
907
+ "Ent-C": {
908
+ "precision": 0.5840749052073171,
909
+ "recall": 0.7009007335864974,
910
+ "f1": 0.6371771012369758
911
+ }
912
+ }
913
+ },
914
+ {
915
+ "model": "Last model",
916
+ "levels": [
917
+ 0,
918
+ 0.8
919
+ ],
920
+ "metrics": {
921
+ "Ent-I": {
922
+ "precision": 0.6334001633618226,
923
+ "recall": 0.7928246119520468,
924
+ "f1": 0.7042020919825571
925
+ },
926
+ "Ent-C": {
927
+ "precision": 0.5627737768204301,
928
+ "recall": 0.7038722258327571,
929
+ "f1": 0.6254641422683939
930
+ }
931
+ }
932
+ },
933
+ {
934
+ "model": "Last model",
935
+ "levels": [
936
+ 1,
937
+ 0.8
938
+ ],
939
+ "metrics": {
940
+ "Ent-I": {
941
+ "precision": 0.6331180880274377,
942
+ "recall": 0.7928246119520468,
943
+ "f1": 0.7040277269868673
944
+ },
945
+ "Ent-C": {
946
+ "precision": 0.5625231910942021,
947
+ "recall": 0.7038722258327571,
948
+ "f1": 0.6253093499537059
949
+ }
950
+ }
951
+ },
952
+ {
953
+ "model": "Last model",
954
+ "levels": [
955
+ 2,
956
+ 0.8
957
+ ],
958
+ "metrics": {
959
+ "Ent-I": {
960
+ "precision": 0.6330513505486546,
961
+ "recall": 0.7929175573930728,
962
+ "f1": 0.7040231021388053
963
+ },
964
+ "Ent-C": {
965
+ "precision": 0.5624721768804255,
966
+ "recall": 0.7039650849654526,
967
+ "f1": 0.6253144669553269
968
+ }
969
+ }
970
+ },
971
+ {
972
+ "model": "Last model",
973
+ "levels": [
974
+ 3,
975
+ 0.8
976
+ ],
977
+ "metrics": {
978
+ "Ent-I": {
979
+ "precision": 0.6330043778284239,
980
+ "recall": 0.7929175573930728,
981
+ "f1": 0.7039940534877868
982
+ },
983
+ "Ent-C": {
984
+ "precision": 0.5624304473621466,
985
+ "recall": 0.7039650849654526,
986
+ "f1": 0.625288678664709
987
+ }
988
+ }
989
+ },
990
+ {
991
+ "model": "Last model",
992
+ "levels": [
993
+ 0,
994
+ 1.0
995
+ ],
996
+ "metrics": {
997
+ "Ent-I": {
998
+ "precision": 0.5998878845206363,
999
+ "recall": 0.7957059206238537,
1000
+ "f1": 0.684059124143036
1001
+ },
1002
+ "Ent-C": {
1003
+ "precision": 0.5324038394167081,
1004
+ "recall": 0.7056365493539737,
1005
+ "f1": 0.6069004024131198
1006
+ }
1007
+ }
1008
+ },
1009
+ {
1010
+ "model": "Last model",
1011
+ "levels": [
1012
+ 1,
1013
+ 1.0
1014
+ ],
1015
+ "metrics": {
1016
+ "Ent-I": {
1017
+ "precision": 0.5995517893405703,
1018
+ "recall": 0.7957059206238537,
1019
+ "f1": 0.6838405574435115
1020
+ },
1021
+ "Ent-C": {
1022
+ "precision": 0.5321055948459267,
1023
+ "recall": 0.7056365493539737,
1024
+ "f1": 0.6067065819241425
1025
+ }
1026
+ }
1027
+ },
1028
+ {
1029
+ "model": "Last model",
1030
+ "levels": [
1031
+ 2,
1032
+ 1.0
1033
+ ],
1034
+ "metrics": {
1035
+ "Ent-I": {
1036
+ "precision": 0.599495868925501,
1037
+ "recall": 0.7957988660648798,
1038
+ "f1": 0.6838384999504639
1039
+ },
1040
+ "Ent-C": {
1041
+ "precision": 0.5320638476613468,
1042
+ "recall": 0.7057294084866693,
1043
+ "f1": 0.6067137619122933
1044
+ }
1045
+ }
1046
+ },
1047
+ {
1048
+ "model": "Last model",
1049
+ "levels": [
1050
+ 3,
1051
+ 1.0
1052
+ ],
1053
+ "metrics": {
1054
+ "Ent-I": {
1055
+ "precision": 0.5994538962398659,
1056
+ "recall": 0.7957988660648798,
1057
+ "f1": 0.6838111922871924
1058
+ },
1059
+ "Ent-C": {
1060
+ "precision": 0.532026601329694,
1061
+ "recall": 0.7057294084866693,
1062
+ "f1": 0.6066895456687551
1063
+ }
1064
+ }
1065
+ }
1066
+ ]
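A minimal sketch of how a results list like the one above could be inspected, assuming each entry keeps the `model`, `levels`, and `metrics` keys shown in the diff. The two sample entries are copied from the "Best model" rows for `levels` `[3, 0.0]` and `[0, 0.2]`. The helper names `f1_from` and `best_entry` are hypothetical, and the meaning of the two `levels` values is not stated in the file, so the code treats them as opaque labels. The check that F1 is the harmonic mean of precision and recall is an assumption about how the scores were produced, not something the commit confirms.

```python
# Two entries copied from the results list above (Ent-C omitted for brevity).
SAMPLE = [
    {
        "model": "Best model",
        "levels": [3, 0.0],
        "metrics": {
            "Ent-I": {
                "precision": 0.7274748732538223,
                "recall": 0.7602007621519098,
                "f1": 0.7434778606503211,
            }
        },
    },
    {
        "model": "Best model",
        "levels": [0, 0.2],
        "metrics": {
            "Ent-I": {
                "precision": 0.707413647851131,
                "recall": 0.7804628682955846,
                "f1": 0.7421450301433067,
            }
        },
    },
]


def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    total = precision + recall
    return 0.0 if total == 0 else 2 * precision * recall / total


def best_entry(entries: list, metric_group: str) -> dict:
    """Return the entry with the highest stored F1 for one metric group."""
    return max(entries, key=lambda e: e["metrics"][metric_group]["f1"])


if __name__ == "__main__":
    # Compare each stored F1 against the value recomputed from P and R.
    for entry in SAMPLE:
        scores = entry["metrics"]["Ent-I"]
        recomputed = f1_from(scores["precision"], scores["recall"])
        print(entry["model"], entry["levels"], scores["f1"], recomputed)

    top = best_entry(SAMPLE, "Ent-I")
    print("highest Ent-I F1 in sample:", top["model"], top["levels"])
```

To run this over the full file, the list could be read with `json.load` and passed to `best_entry` in place of `SAMPLE`; the file path is left out here because only the diff, not the file name, is visible in this part of the commit.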
20_entities_top_50_25/results/20_entities_top_50_25_test_df_.xlsx ADDED
Binary file (8.8 kB).