SS3M commited on
Commit
df31aca
·
verified ·
1 Parent(s): 16b4b69

Upload 4.2_add_span_rerank_branch_4.5's state dict

Browse files
.gitattributes CHANGED
@@ -93,3 +93,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
93
  4_margin_loss_C_4.1/logs/4_margin_loss_C_4.1_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
94
  4.1_entities_no_ce_4.2/logs/4.1_entities_no_ce_4.2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
95
  4.2_add_span_rerank_branch_4.3/logs/4.2_add_span_rerank_branch_4.3_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
 
 
93
  4_margin_loss_C_4.1/logs/4_margin_loss_C_4.1_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
94
  4.1_entities_no_ce_4.2/logs/4.1_entities_no_ce_4.2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
95
  4.2_add_span_rerank_branch_4.3/logs/4.2_add_span_rerank_branch_4.3_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
96
+ 4.2_add_span_rerank_branch_4.5/logs/4.2_add_span_rerank_branch_4.5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
4.2_add_span_rerank_branch_4.5/4.2_add_span_rerank_branch_4.5.py ADDED
@@ -0,0 +1,2751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch] pytorch-crf')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+ from torchcrf import CRF
19
+
20
+ from sklearn.metrics import f1_score
21
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
22
+ from scipy.spatial.transform import Rotation as R
23
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
24
+ from sklearn.metrics import precision_recall_fscore_support
25
+ from timm.utils import ModelEmaV3
26
+ import timm
27
+
28
+ import os
29
+ import gc
30
+ import json
31
+ from pathlib import Path
32
+ import pickle
33
+ from tqdm.auto import tqdm
34
+ import copy
35
+ import numpy as np
36
+ import pandas as pd
37
+ import polars as pl
38
+ from PIL import Image
39
+ import time
40
+ from tqdm import tqdm
41
+ from matplotlib import pyplot as plt
42
+ import seaborn as sns
43
+ from multiprocessing import Manager as MemoryManager
44
+ from functools import lru_cache
45
+ import shutil
46
+ import glob
47
+ import cv2
48
+ import random
49
+ import re
50
+ import joblib
51
+ import math
52
+ from huggingface_hub import HfApi, snapshot_download
53
+ import evaluate
54
+ from underthesea import word_tokenize as vi_tokenize_tool
55
+ import spacy
56
+ en_tokenize_tool = spacy.load("en_core_web_sm")
57
+ from collections import defaultdict, Counter
58
+
59
+ # %% [code]
60
+ # Global config
61
+ SEEDS = [26092004]
62
+ topk = 1
63
+ nfolds = 5
64
+ only_fold_idx = 0
65
+ test_only = 0
66
+ debug_only = 0
67
+
68
+ # Config thư mục
69
+ dataset = 'kltn/only_entities' # conll003, ontonotes, phoner, vietbio, vietmed, vimed, kltn/only_entities, kltn/raw
70
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Thư mục chứa file train, val, test
71
+ train_dir = f'{root_dir}'
72
+ # val_dir = f'{root_dir}/val'
73
+ test_dir = f'{root_dir}'
74
+
75
+ # Config checkpoints
76
+
77
+ # Config training
78
+ epochs = 18 if not debug_only else 2
79
+ batch_size = 32
80
+ device = "cuda" if torch.cuda.is_available() else "cpu"
81
+ # # Thêm biến toàn cục nào đó vào đây
82
+ repo_name = 'SS3M/kltn-experiments'
83
+ state_dict_save_name = "4.2_add_span_rerank_branch_4.5"
84
+ checkpoints_dir = state_dict_save_name
85
+ pretrained_dir = "/kaggle/working"
86
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
87
+
88
+ backbone_model_name = "bert-base-uncased" if dataset in ["conll003", "ontonotes"] else "vinai/phobert-base"
89
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset == dataset in ["conll003", "ontonotes"] else vi_tokenize_tool(text)
90
+ max_len_dict = {
91
+ 'kltn/raw': 256,
92
+ 'kltn/only_entities': 68,
93
+ 'conll003': 46,
94
+ 'ontonotes': 61,
95
+ 'phoner': 68,
96
+ 'vietbio': 125,
97
+ 'vietmed': 36,
98
+ 'vimed': 100,
99
+ }
100
+ zero_entities_rate_dict = {
101
+ 'kltn/raw': 1000,
102
+ 'kltn/only_entities': 0.2,
103
+ 'conll003': 1000, # mean keep all zero-entities samples
104
+ 'ontonotes': 1000,
105
+ 'phoner': 1000,
106
+ 'vietbio': 1000,
107
+ 'vietmed': 1000,
108
+ 'vimed': 1000,
109
+ }
110
+
111
+ max_len = max_len_dict[dataset]
112
+ max_n_parts = 1
113
+ max_span_len = 10
114
+ zero_entities_rate = zero_entities_rate_dict[dataset]
115
+
116
+ # Trainer
117
+ trainer_params = {
118
+ "training_time": "00:11:30:00",
119
+ "eval_mode": "max",
120
+ "topk": topk,
121
+ "save_name": state_dict_save_name,
122
+ "save_best": True,
123
+ "save_last": True,
124
+ "device": device,
125
+ "logging": True,
126
+ "logging_file": True,
127
+ "checkpoints_dir": checkpoints_dir,
128
+ "early_stopping": 30,
129
+ "eval_from_ratio": 0.4,
130
+ "eval_every": 1,
131
+ "schedule_in_step": False,
132
+ "use_ema": True,
133
+ "ema_from_ratio": 0.3,
134
+ "ema_decay": 0.9995,
135
+ "max_grad_norm": 200.0,
136
+ "return_best": True,
137
+ "return_last": True,
138
+ }
139
+
140
+ # Memory
141
+ train_memory_params = {
142
+ 'max_len': max_len,
143
+ 'max_n_parts': max_n_parts,
144
+ }
145
+ val_memory_params = {
146
+ 'max_len': max_len,
147
+ 'max_n_parts': max_n_parts,
148
+ }
149
+
150
+ # Data Loader
151
+ def seed_worker(worker_id):
152
+ worker_seed = torch.initial_seed() % 2**32
153
+ np.random.seed(worker_seed)
154
+ random.seed(worker_seed)
155
+
156
+ train_loader_params = {
157
+ 'batch_size': batch_size,
158
+ 'shuffle': True,
159
+ 'pin_memory':True,
160
+ 'num_workers': 2,
161
+ 'drop_last': False,
162
+ 'worker_init_fn': seed_worker,
163
+ 'persistent_workers': False,
164
+ }
165
+ val_loader_params = {
166
+ 'batch_size': batch_size,
167
+ 'shuffle': False,
168
+ 'pin_memory':True,
169
+ 'num_workers': 1,
170
+ 'drop_last': False,
171
+ 'worker_init_fn': seed_worker,
172
+ 'persistent_workers': False,
173
+ }
174
+
175
+ # Model
176
+ model_params = {
177
+ 'backbone_model_name': backbone_model_name,
178
+ 'max_r': 6,
179
+ }
180
+
181
+ # Loss Func
182
+ loss_func_params = {
183
+ 'lambda_span': 1.0,
184
+ 'lambda_margin': 1.0,
185
+ 'margin': 0.5,
186
+ }
187
+ eval_func_params = {}
188
+
189
+ # Optim
190
+ optim_params = {
191
+ 'name': 'AdamW',
192
+ 'lr': 1e-4,
193
+ 'weight_decay': 1e-4,
194
+ }
195
+ scheduler_params = {
196
+ 'name': 'CosineAnnealingLR',
197
+ 'T_max': 20, # Số epoch để hoàn thành một chu kỳ giảm LR
198
+ 'eta_min': 1e-6 # Learning rate nhỏ nhất trong chu kỳ
199
+ }
200
+
201
+ # %% [code]
202
+ def set_seed(seed=42):
203
+ random.seed(seed)
204
+ np.random.seed(seed)
205
+ torch.manual_seed(seed)
206
+ torch.cuda.manual_seed(seed)
207
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
208
+ torch.use_deterministic_algorithms(False)
209
+ torch.backends.cudnn.deterministic = True
210
+ torch.backends.cudnn.benchmark = False
211
+ os.environ['PYTHONHASHSEED'] = str(seed)
212
+
213
+ # %% [code]
214
+ class CustomLoss(nn.Module):
215
+ def __init__(
216
+ self,
217
+ lambda_margin=1.0,
218
+ lambda_span=1.0,
219
+ margin=1.0
220
+ ):
221
+ super().__init__()
222
+
223
+ self.lambda_margin = lambda_margin
224
+ self.lambda_span = lambda_span
225
+ self.margin = margin
226
+
227
+ # =========================================================
228
+ # token margin
229
+ # =========================================================
230
+
231
+ def margin_loss(self, logits, labels):
232
+ """
233
+ logits: (N, C)
234
+ labels: (N,)
235
+ """
236
+
237
+ valid_mask = labels != -100
238
+
239
+ logits = logits[valid_mask]
240
+ labels = labels[valid_mask]
241
+
242
+ if len(labels) == 0:
243
+ return logits.new_tensor(0.0)
244
+
245
+ # positive logit
246
+ pos_logit = logits.gather(
247
+ dim=1,
248
+ index=labels.unsqueeze(-1)
249
+ ).squeeze(-1)
250
+
251
+ # hardest negative
252
+ neg_logits = logits.clone()
253
+
254
+ neg_logits.scatter_(
255
+ 1,
256
+ labels.unsqueeze(-1),
257
+ float("-inf")
258
+ )
259
+
260
+ hardest_neg = neg_logits.max(dim=-1).values
261
+
262
+ # margin ranking
263
+ loss = F.relu(
264
+ self.margin - pos_logit + hardest_neg
265
+ )
266
+
267
+ return loss.mean()
268
+
269
+ # =========================================================
270
+ # span listwise margin
271
+ # =========================================================
272
+
273
+ def span_loss(self, span_scores, labels):
274
+ """
275
+ span_scores: (B, N, M)
276
+ labels: (B, N, M)
277
+
278
+ labels:
279
+ -100 -> ignore
280
+ 0 -> negative
281
+ >0 -> positive
282
+ """
283
+
284
+ device = span_scores.device
285
+
286
+ B, N, M = span_scores.shape
287
+ if N == 0 or M == 0:
288
+ return span_scores.new_tensor(0.0)
289
+
290
+ # =====================================================
291
+ # masks
292
+ # =====================================================
293
+
294
+ valid_mask = labels != -100
295
+ pos_mask = labels > 0
296
+ neg_mask = labels == 0
297
+
298
+ # =====================================================
299
+ # mỗi N phải có ít nhất 1 positive
300
+ # =====================================================
301
+
302
+ has_pos = pos_mask.any(dim=-1)
303
+ # (B, N)
304
+
305
+ # =====================================================
306
+ # masked scores
307
+ # =====================================================
308
+ pos_scores = span_scores.masked_fill(
309
+ ~pos_mask,
310
+ float("-inf")
311
+ )
312
+
313
+ neg_scores = span_scores.masked_fill(
314
+ ~neg_mask,
315
+ float("-inf")
316
+ )
317
+
318
+ # =====================================================
319
+ # hardest positive / hardest negative
320
+ # =====================================================
321
+
322
+ best_pos = pos_scores.max(dim=-1).values
323
+ # (B, N)
324
+
325
+ best_neg = neg_scores.max(dim=-1).values
326
+ # (B, N)
327
+
328
+ # =====================================================
329
+ # nếu không có negative
330
+ # =====================================================
331
+
332
+ no_neg = ~neg_mask.any(dim=-1)
333
+
334
+ best_neg = torch.where(
335
+ no_neg,
336
+ torch.zeros_like(best_neg),
337
+ best_neg
338
+ )
339
+
340
+ # =====================================================
341
+ # margin ranking
342
+ # =====================================================
343
+
344
+ loss = F.relu(
345
+ self.margin - best_pos + best_neg
346
+ )
347
+
348
+ # =====================================================
349
+ # chỉ tính những N có positive
350
+ # =====================================================
351
+
352
+ loss = loss[has_pos]
353
+
354
+ if loss.numel() == 0:
355
+ return span_scores.new_tensor(0.0)
356
+
357
+ return loss.mean()
358
+
359
+ # =========================================================
360
+ # forward
361
+ # =========================================================
362
+
363
+ def forward(
364
+ self,
365
+ start_logits, start_labels,
366
+ end_logits, end_labels,
367
+ span_scores, labels
368
+ ):
369
+
370
+ # =====================================================
371
+ # flatten token logits
372
+ # =====================================================
373
+
374
+ B, L, C = start_logits.shape
375
+
376
+ start_logits_flat = start_logits.view(B * L, C)
377
+ start_labels_flat = start_labels.view(-1)
378
+
379
+ end_logits_flat = end_logits.view(B * L, C)
380
+ end_labels_flat = end_labels.view(-1)
381
+
382
+ # =====================================================
383
+ # token margin
384
+ # =====================================================
385
+
386
+ start_margin = self.margin_loss(
387
+ start_logits_flat,
388
+ start_labels_flat
389
+ )
390
+
391
+ end_margin = self.margin_loss(
392
+ end_logits_flat,
393
+ end_labels_flat
394
+ )
395
+
396
+ token_margin_loss = (
397
+ start_margin + end_margin
398
+ )
399
+
400
+ # =====================================================
401
+ # span loss
402
+ # =====================================================
403
+
404
+ span_loss = self.span_loss(
405
+ span_scores,
406
+ labels
407
+ )
408
+
409
+ # =====================================================
410
+ # total
411
+ # =====================================================
412
+
413
+ total_loss = (
414
+ self.lambda_margin * token_margin_loss
415
+ +
416
+ self.lambda_span * span_loss
417
+ )
418
+
419
+ return {
420
+ "total": total_loss,
421
+ "token_margin_loss": token_margin_loss,
422
+ "span_loss": span_loss,
423
+ "start_margin": start_margin,
424
+ "end_margin": end_margin,
425
+ }
426
+
427
+ # %% [code]
428
+ ## Viết eval_fn vào đây
429
+
430
+ # Bỏ hết eval_fn và trọng số vào đây
431
+ class CustomEvalFn(nn.Module):
432
+ def __init__(self):
433
+ super().__init__()
434
+
435
+ def compute_f1(self, tp, fp, fn):
436
+ precision = tp / (tp + fp + 1e-8)
437
+ recall = tp / (tp + fn + 1e-8)
438
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
439
+ return precision, recall, f1
440
+
441
+ def forward(self, pred, gold):
442
+ pred_set = set(pred)
443
+ gold_set = set(gold)
444
+
445
+ tp = len(pred_set & gold_set)
446
+ fp = len(pred_set - gold_set)
447
+ fn = len(gold_set - pred_set)
448
+
449
+ precision, recall, f1 = self.compute_f1(tp, fp, fn)
450
+
451
+ return {
452
+ f"precision": precision,
453
+ f"recall": recall,
454
+ f"f1": f1,
455
+ }
456
+
457
+ class SpanErrorAnalyzer:
458
+ def __init__(self, pad_token_id=0):
459
+ self.pad_token_id = pad_token_id
460
+
461
+ # ===== helper =====
462
+ def _to_set(self, data):
463
+ """
464
+ data: list of (b, tuple(ids))
465
+ -> dict[b] = set(tuple(ids))
466
+ """
467
+ res = defaultdict(set)
468
+ for b, ids in data:
469
+ ids = tuple([i for i in ids if i != self.pad_token_id])
470
+ if len(ids) > 0:
471
+ res[b].add(ids)
472
+ return res
473
+
474
+ def _iou(self, a, b):
475
+ """
476
+ a, b: tuple(ids)
477
+ """
478
+ set_a, set_b = set(a), set(b)
479
+ inter = len(set_a & set_b)
480
+ union = len(set_a | set_b)
481
+ if union == 0:
482
+ return 0.0
483
+ return inter / union
484
+
485
+ def _boundary_error(self, pred, gold):
486
+ """
487
+ đo lệch boundary dựa trên overlap prefix/suffix
488
+ """
489
+ # left match
490
+ left = 0
491
+ for i in range(min(len(pred), len(gold))):
492
+ if pred[i] == gold[i]:
493
+ left += 1
494
+ else:
495
+ break
496
+
497
+ # right match
498
+ right = 0
499
+ for i in range(1, min(len(pred), len(gold)) + 1):
500
+ if pred[-i] == gold[-i]:
501
+ right += 1
502
+ else:
503
+ break
504
+
505
+ return {
506
+ "left_match": left,
507
+ "right_match": right,
508
+ "pred_len": len(pred),
509
+ "gold_len": len(gold),
510
+ }
511
+
512
+ # ===== main =====
513
+ def analyze(self, preds, golds):
514
+ pred_map = self._to_set(preds)
515
+ gold_map = self._to_set(golds)
516
+
517
+ all_batches = set(pred_map.keys()) | set(gold_map.keys())
518
+
519
+ stats = Counter()
520
+
521
+ detailed_errors = []
522
+
523
+ for b in all_batches:
524
+ pset = pred_map.get(b, set())
525
+ gset = gold_map.get(b, set())
526
+
527
+ matched_gold = set()
528
+
529
+ # ===== check predictions =====
530
+ for p in pset:
531
+ if p in gset:
532
+ stats["exact_match"] += 1
533
+ matched_gold.add(p)
534
+ else:
535
+ # tìm gold gần nhất
536
+ best_iou = 0
537
+ best_g = None
538
+
539
+ for g in gset:
540
+ iou = self._iou(p, g)
541
+ if iou > best_iou:
542
+ best_iou = iou
543
+ best_g = g
544
+
545
+ if best_iou > 0:
546
+ stats["partial_match"] += 1
547
+
548
+ boundary = self._boundary_error(p, best_g)
549
+
550
+ detailed_errors.append({
551
+ "type": "boundary_error",
552
+ "batch": b,
553
+ "pred": p,
554
+ "gold": best_g,
555
+ "iou": best_iou,
556
+ **boundary
557
+ })
558
+ else:
559
+ if b not in gold_map:
560
+ stats["no_event_sample"] += 1
561
+ err_type = "no_event_sample"
562
+ else:
563
+ stats["completely_wrong"] += 1
564
+ err_type = "completely_wrong"
565
+
566
+ detailed_errors.append({
567
+ "type": err_type,
568
+ "batch": b,
569
+ "pred": p
570
+ })
571
+
572
+ # ===== check missing =====
573
+ for g in gset:
574
+ if g not in matched_gold:
575
+ # check if any pred overlaps
576
+ overlap = any(self._iou(p, g) > 0 for p in pset)
577
+
578
+ if overlap:
579
+ stats["miss_with_overlap"] += 1
580
+ else:
581
+ stats["miss"] += 1
582
+
583
+ detailed_errors.append({
584
+ "type": "miss",
585
+ "batch": b,
586
+ "gold": g
587
+ })
588
+
589
+ return {
590
+ "summary": {
591
+ "exact_match": (stats["exact_match"], stats["exact_match"] / len(preds)),
592
+ "partial_match": (stats["partial_match"], stats["partial_match"] / len(preds)),
593
+ "no_event_sample": (stats["no_event_sample"], stats["no_event_sample"] / len(preds)),
594
+ "completely_wrong": (stats["completely_wrong"], stats["completely_wrong"] / len(preds)),
595
+ "miss": (stats["miss"], stats["miss"] / len(golds)),
596
+ "miss_with_overlap": (stats["miss_with_overlap"], stats["miss_with_overlap"] / len(golds)),
597
+ },
598
+ "details": detailed_errors
599
+ }
600
+
601
+ # %% [code]
602
+ class DataParallelProxy(nn.DataParallel):
603
+
604
+ def __getattr__(self, name):
605
+ try:
606
+ return super().__getattr__(name)
607
+
608
+ except AttributeError:
609
+
610
+ attr = getattr(self.module, name)
611
+
612
+ if callable(attr):
613
+
614
+ def wrapper(*args, **kwargs):
615
+ return self._parallel_apply_method(
616
+ name,
617
+ *args,
618
+ **kwargs
619
+ )
620
+
621
+ return wrapper
622
+
623
+ return attr
624
+
625
+ # =========================================================
626
+ # parallel custom method
627
+ # =========================================================
628
+
629
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
630
+
631
+ if not self.device_ids:
632
+ return getattr(self.module, method_name)(*inputs, **kwargs)
633
+
634
+ inputs_scattered, kwargs_scattered = self.scatter(
635
+ inputs,
636
+ kwargs,
637
+ self.device_ids
638
+ )
639
+
640
+ replicas = self.replicate(
641
+ self.module,
642
+ self.device_ids[:len(inputs_scattered)]
643
+ )
644
+
645
+ outputs = self.parallel_apply(
646
+ [getattr(replica, method_name) for replica in replicas],
647
+ inputs_scattered,
648
+ kwargs_scattered
649
+ )
650
+
651
+ return self._custom_gather(
652
+ outputs,
653
+ self.output_device
654
+ )
655
+
656
+ # =========================================================
657
+ # OVERRIDE FORWARD GATHER
658
+ # =========================================================
659
+
660
+ def gather(self, outputs, output_device):
661
+
662
+ return self._custom_gather(
663
+ outputs,
664
+ output_device
665
+ )
666
+
667
+ # =========================================================
668
+ # recursive gather
669
+ # =========================================================
670
+
671
+ def _custom_gather(self, outputs, output_device):
672
+
673
+ first = outputs[0]
674
+
675
+ # =====================================================
676
+ # tensor
677
+ # =====================================================
678
+
679
+ if torch.is_tensor(first):
680
+
681
+ return self._gather_tensor(
682
+ outputs,
683
+ output_device
684
+ )
685
+
686
+ # =====================================================
687
+ # tuple
688
+ # =====================================================
689
+
690
+ if isinstance(first, tuple):
691
+
692
+ return tuple(
693
+ self._custom_gather(
694
+ list(items),
695
+ output_device
696
+ )
697
+ for items in zip(*outputs)
698
+ )
699
+
700
+ # =====================================================
701
+ # list
702
+ # =====================================================
703
+
704
+ if isinstance(first, list):
705
+
706
+ # list[tensor]
707
+ if len(first) > 0 and torch.is_tensor(first[0]):
708
+
709
+ return self._gather_tensor_list(
710
+ outputs,
711
+ output_device
712
+ )
713
+
714
+ merged = []
715
+
716
+ for out in outputs:
717
+ merged.extend(out)
718
+
719
+ return merged
720
+
721
+ # =====================================================
722
+ # dict
723
+ # =====================================================
724
+
725
+ if isinstance(first, dict):
726
+
727
+ return {
728
+ k: self._custom_gather(
729
+ [o[k] for o in outputs],
730
+ output_device
731
+ )
732
+ for k in first.keys()
733
+ }
734
+
735
+ # =====================================================
736
+ # fallback
737
+ # =====================================================
738
+
739
+ return outputs
740
+
741
+ # =========================================================
742
+ # gather tensor with auto pad
743
+ # =========================================================
744
+
745
+ def _gather_tensor(self, tensors, output_device):
746
+
747
+ # move same device first
748
+ tensors = [
749
+ t.to(output_device)
750
+ for t in tensors
751
+ ]
752
+
753
+ # =====================================================
754
+ # fast path
755
+ # =====================================================
756
+
757
+ try:
758
+ return torch.cat(tensors, dim=0)
759
+
760
+ except RuntimeError:
761
+ pass
762
+
763
+ # =====================================================
764
+ # auto max shape
765
+ # =====================================================
766
+
767
+ max_shape = list(tensors[0].shape)
768
+
769
+ for t in tensors[1:]:
770
+
771
+ for d in range(len(max_shape)):
772
+
773
+ max_shape[d] = max(
774
+ max_shape[d],
775
+ t.shape[d]
776
+ )
777
+
778
+ # =====================================================
779
+ # pad tensors
780
+ # =====================================================
781
+
782
+ padded = []
783
+
784
+ for t in tensors:
785
+
786
+ pad = []
787
+
788
+ # reverse order for F.pad
789
+ for d in reversed(range(len(max_shape))):
790
+
791
+ # never pad batch dim
792
+ if d == 0:
793
+ pad.extend([0, 0])
794
+ continue
795
+
796
+ diff = max_shape[d] - t.shape[d]
797
+
798
+ pad.extend([0, diff])
799
+
800
+ t = F.pad(t, pad)
801
+
802
+ padded.append(t)
803
+
804
+ return torch.cat(padded, dim=0)
805
+
806
+ # =========================================================
807
+ # list[tensor]
808
+ # =========================================================
809
+
810
+ def _gather_tensor_list(self, outputs, output_device):
811
+
812
+ merged = []
813
+
814
+ for out in outputs:
815
+ merged.extend(out)
816
+
817
+ return self._gather_tensor(
818
+ merged,
819
+ output_device
820
+ )
821
+
822
+ # %% [code]
823
+ ## Viết cấu trúc model vào đây
824
+ def extract_spans_and_labels(start_logits, end_logits):
825
+ """
826
+ Args:
827
+ start_logits: Tensor (B, L, C)
828
+ end_logits: Tensor (B, L, C)
829
+
830
+ Returns:
831
+ spans: Tensor (B, N, 2)
832
+ padding = (0, 0)
833
+
834
+ labels: Tensor (B, N)
835
+ padding = -100
836
+
837
+ Nếu không extract được span nào:
838
+ spans.shape = (B, 0, 2)
839
+ labels.shape = (B, 0)
840
+ """
841
+
842
+ start_labels = start_logits.argmax(dim=-1) # (B, L)
843
+ end_labels = end_logits.argmax(dim=-1) # (B, L)
844
+
845
+ B, L = start_labels.shape
846
+
847
+ batch_spans = []
848
+ batch_labels = []
849
+
850
+ max_n = 0
851
+
852
+ for bidx in range(B):
853
+
854
+ used_start = set()
855
+ used_end = set()
856
+
857
+ spans = []
858
+ labels = []
859
+
860
+ for s in range(L):
861
+
862
+ s_label = start_labels[bidx, s].item()
863
+
864
+ # bỏ qua O
865
+ if s_label == 0:
866
+ continue
867
+
868
+ if s in used_start:
869
+ continue
870
+
871
+ nearest_e = None
872
+
873
+ # tìm end gần nhất cùng class
874
+ for e in range(s, L):
875
+
876
+ if e in used_end:
877
+ continue
878
+
879
+ e_label = end_labels[bidx, e].item()
880
+
881
+ if e_label == s_label:
882
+ nearest_e = e
883
+ break
884
+
885
+ if nearest_e is None:
886
+ continue
887
+
888
+ used_start.add(s)
889
+ used_end.add(nearest_e)
890
+
891
+ spans.append([s, nearest_e])
892
+ labels.append(s_label)
893
+
894
+ batch_spans.append(spans)
895
+ batch_labels.append(labels)
896
+
897
+ max_n = max(max_n, len(spans))
898
+
899
+ # =========================================================
900
+ # không extract được gì
901
+ # =========================================================
902
+
903
+ if max_n == 0:
904
+
905
+ spans = torch.empty(
906
+ (B, 0, 2),
907
+ dtype=torch.long,
908
+ device=start_logits.device
909
+ )
910
+
911
+ labels = torch.empty(
912
+ (B, 0),
913
+ dtype=torch.long,
914
+ device=start_logits.device
915
+ )
916
+
917
+ return spans, labels
918
+
919
+ # =========================================================
920
+ # padding
921
+ # =========================================================
922
+
923
+ padded_spans = []
924
+ padded_labels = []
925
+
926
+ for spans, labels in zip(batch_spans, batch_labels):
927
+
928
+ pad_n = max_n - len(spans)
929
+
930
+ spans = spans + [[0, 0]] * pad_n
931
+ labels = labels + [-100] * pad_n
932
+
933
+ padded_spans.append(spans)
934
+ padded_labels.append(labels)
935
+
936
+ spans = torch.tensor(
937
+ padded_spans,
938
+ dtype=torch.long,
939
+ device=start_logits.device
940
+ ) # (B, N, 2)
941
+
942
+ labels = torch.tensor(
943
+ padded_labels,
944
+ dtype=torch.long,
945
+ device=start_logits.device
946
+ ) # (B, N)
947
+
948
+ return spans, labels
949
+
950
+ def expand_spans(spans, r, attention_mask):
951
+ """
952
+ Args:
953
+ spans: Tensor (B, N, 2)
954
+ r: Tensor (B, N)
955
+ attention_mask: Tensor (B, L)
956
+
957
+ Returns:
958
+ expanded_spans: Tensor (B, N, M, 2)
959
+
960
+ Với:
961
+ M = (2*max_r + 1)^2
962
+
963
+ padding = (0, 0)
964
+
965
+ Nếu N = 0:
966
+ return shape (B, 0, 0, 2)
967
+ """
968
+
969
+ device = spans.device
970
+
971
+ B, N, _ = spans.shape
972
+
973
+ # =========================================================
974
+ # empty case
975
+ # =========================================================
976
+
977
+ if N == 0:
978
+ return torch.empty(
979
+ (B, 0, 0, 2),
980
+ dtype=torch.long,
981
+ device=device
982
+ )
983
+
984
+ # =========================================================
985
+ # max radius
986
+ # =========================================================
987
+
988
+ max_r = r.max().item()
989
+
990
+ # =========================================================
991
+ # create all (ds, de)
992
+ # =========================================================
993
+
994
+ shifts = torch.arange(
995
+ -max_r,
996
+ max_r + 1,
997
+ device=device
998
+ )
999
+
1000
+ ds_grid, de_grid = torch.meshgrid(
1001
+ shifts,
1002
+ shifts,
1003
+ indexing='ij'
1004
+ )
1005
+
1006
+ ds = ds_grid.reshape(-1) # (M,)
1007
+ de = de_grid.reshape(-1) # (M,)
1008
+
1009
+ M = ds.size(0)
1010
+
1011
+ # =========================================================
1012
+ # original spans
1013
+ # =========================================================
1014
+
1015
+ spans = spans.unsqueeze(2) # (B, N, 1, 2)
1016
+
1017
+ start = spans[..., 0] # (B, N, 1)
1018
+ end = spans[..., 1] # (B, N, 1)
1019
+
1020
+ # =========================================================
1021
+ # expand independently
1022
+ # =========================================================
1023
+
1024
+ expanded_s = start + ds.view(1, 1, M)
1025
+ expanded_e = end + de.view(1, 1, M)
1026
+
1027
+ expanded = torch.stack(
1028
+ [expanded_s, expanded_e],
1029
+ dim=-1
1030
+ ) # (B, N, M, 2)
1031
+
1032
+ # =========================================================
1033
+ # valid shift range
1034
+ # =========================================================
1035
+
1036
+ valid_shift = (
1037
+ (ds.abs().view(1, 1, M) <= r.unsqueeze(-1))
1038
+ &
1039
+ (de.abs().view(1, 1, M) <= r.unsqueeze(-1))
1040
+ )
1041
+
1042
+ # =========================================================
1043
+ # non-padding spans
1044
+ # =========================================================
1045
+
1046
+ non_pad_span = (
1047
+ spans.squeeze(2).sum(dim=-1) != 0
1048
+ ) # (B, N)
1049
+
1050
+ # =========================================================
1051
+ # boundary check
1052
+ # =========================================================
1053
+
1054
+ valid_length = attention_mask.sum(dim=-1)
1055
+ valid_length = valid_length.view(B, 1, 1)
1056
+
1057
+ valid_boundary = (
1058
+ (expanded_s > 0)
1059
+ &
1060
+ (expanded_e < valid_length)
1061
+ &
1062
+ (expanded_s <= expanded_e)
1063
+ )
1064
+
1065
+ # =========================================================
1066
+ # final valid mask
1067
+ # =========================================================
1068
+
1069
+ valid_mask = (
1070
+ valid_shift
1071
+ &
1072
+ non_pad_span.unsqueeze(-1)
1073
+ &
1074
+ valid_boundary
1075
+ )
1076
+
1077
+ # =========================================================
1078
+ # flatten valid spans
1079
+ # =========================================================
1080
+
1081
+ expanded_flat = expanded.reshape(B, N, M, 2)
1082
+ valid_flat = valid_mask.reshape(B, N, M)
1083
+
1084
+ # =========================================================
1085
+ # số valid max toàn batch
1086
+ # =========================================================
1087
+
1088
+ valid_counts = valid_flat.sum(dim=-1) # (B, N)
1089
+
1090
+ M_max = valid_counts.max().item()
1091
+
1092
+ # =========================================================
1093
+ # allocate output
1094
+ # =========================================================
1095
+
1096
+ filtered_expanded = torch.zeros(
1097
+ (B, N, M_max, 2),
1098
+ dtype=expanded.dtype,
1099
+ device=device
1100
+ )
1101
+
1102
+ # =========================================================
1103
+ # gather only valid spans
1104
+ # =========================================================
1105
+
1106
+ for b in range(B):
1107
+ for n in range(N):
1108
+
1109
+ cur_valid = valid_flat[b, n] # (M,)
1110
+
1111
+ valid_spans = expanded_flat[b, n][cur_valid]
1112
+
1113
+ k = valid_spans.size(0)
1114
+
1115
+ if k > 0:
1116
+ filtered_expanded[b, n, :k] = valid_spans
1117
+
1118
+ return filtered_expanded
1119
+
1120
+ def get_span_reprs(hidden, spans):
1121
+ """
1122
+ Args:
1123
+ hidden: (B, L, H)
1124
+ spans: (B, N, M, 2)
1125
+
1126
+ Return:
1127
+ span_repr: (B, N, M, 4H)
1128
+
1129
+ Nếu N = 0:
1130
+ return shape (B, 0, 0, 4H)
1131
+ """
1132
+
1133
+ B, N, M, _ = spans.shape
1134
+ H = hidden.size(-1)
1135
+
1136
+ # =========================================================
1137
+ # empty case
1138
+ # =========================================================
1139
+
1140
+ if N == 0:
1141
+ return torch.empty(
1142
+ (B, 0, 0, 4 * H),
1143
+ dtype=hidden.dtype,
1144
+ device=hidden.device
1145
+ )
1146
+
1147
+ batch_idx = torch.arange(
1148
+ B,
1149
+ device=hidden.device
1150
+ ).view(B, 1, 1)
1151
+
1152
+ start_idx = spans[..., 0] # (B, N, M)
1153
+ end_idx = spans[..., 1] # (B, N, M)
1154
+
1155
+ # =========================================================
1156
+ # gather hidden
1157
+ # =========================================================
1158
+
1159
+ start_h = hidden[batch_idx, start_idx]
1160
+ end_h = hidden[batch_idx, end_idx]
1161
+
1162
+ # (B, N, M, H)
1163
+
1164
+ # =========================================================
1165
+ # span representation
1166
+ # =========================================================
1167
+
1168
+ span_repr = torch.cat(
1169
+ [
1170
+ start_h,
1171
+ end_h,
1172
+ end_h - start_h,
1173
+ end_h * start_h
1174
+ ],
1175
+ dim=-1
1176
+ ) # (B, N, M, 4H)
1177
+
1178
+ return span_repr
1179
+
1180
+ class MLP(nn.Module):
1181
+ def __init__(self, in_size, hid_size, out_size):
1182
+ super().__init__()
1183
+ self.mlp = nn.Sequential(
1184
+ nn.Linear(in_size, hid_size),
1185
+ nn.ReLU(),
1186
+ nn.Linear(hid_size, out_size)
1187
+ )
1188
+
1189
+ def forward(self, x):
1190
+ return self.mlp(x)
1191
+
1192
+ class IEModel(nn.Module):
1193
+ def __init__(self, backbone_model_name, num_labels, max_r):
1194
+ super().__init__()
1195
+ self.max_r = max_r
1196
+
1197
+ self.encoder = AutoModel.from_pretrained(backbone_model_name)
1198
+ hidden_size = self.encoder.config.hidden_size
1199
+
1200
+ self.start_classifier = MLP(hidden_size, hidden_size, num_labels)
1201
+ self.end_classifier = MLP(hidden_size, hidden_size, num_labels)
1202
+
1203
+ self.span_scorer = MLP(4*hidden_size, hidden_size, 1)
1204
+
1205
+ def encode(self, input_ids, attention_mask):
1206
+ B, n_parts, L = input_ids.shape
1207
+ input_ids = input_ids.view(-1, L)
1208
+ attention_mask = attention_mask.view(-1, L)
1209
+
1210
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
1211
+ hidden_states = outputs.last_hidden_state # B * n_parts, L, H
1212
+
1213
+ hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1) # B, L, H
1214
+ attention_mask = attention_mask.view(B, n_parts, L).reshape(B, n_parts*L) # B, L
1215
+ return hidden_states, attention_mask
1216
+
1217
+ def get_logits(self, hidden_states):
1218
+ start_logits = self.start_classifier(hidden_states) # B, N, classes
1219
+ end_logits = self.end_classifier(hidden_states) # B, N, classes
1220
+ return start_logits, end_logits
1221
+
1222
+ def get_scores(self, expanded_span_reprs):
1223
+ return self.span_scorer(expanded_span_reprs).squeeze(-1)
1224
+
1225
+ def forward(self, input_ids, attention_mask, spans=None):
1226
+ hidden_states, attention_mask = self.encode(input_ids, attention_mask)
1227
+ start_logits, end_logits = self.get_logits(hidden_states)
1228
+
1229
+ if spans is None:
1230
+ spans, _ = extract_spans_and_labels(start_logits, end_logits) # (B, N, 2)
1231
+ B, N, _ = spans.shape
1232
+ r = torch.randint(1, self.max_r+1, (B, N), device=spans.device)
1233
+ expanded_spans = expand_spans(spans, r, attention_mask) # (B, N, M, 2)
1234
+
1235
+ expanded_span_reprs = get_span_reprs(hidden_states, expanded_spans) # (B, N, M, 4H)
1236
+ expanded_span_scores = self.get_scores(expanded_span_reprs) # (B, N, M)
1237
+
1238
+ return start_logits, end_logits, expanded_spans, expanded_span_scores
1239
+
1240
+ def test_model():
1241
+ model = DataParallelProxy(IEModel(backbone_model_name, 7, 6)).to(device)
1242
+ model.eval()
1243
+ total_params = sum(p.numel() for p in model.parameters())
1244
+ print(f"Total params: {total_params:,}")
1245
+
1246
+ vocab_size = model.module.encoder.config.vocab_size
1247
+ max_len = model.module.encoder.config.max_position_embeddings
1248
+
1249
+ bz = 32
1250
+ i = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
1251
+ a = torch.ones(bz, 5, 10).to(device)
1252
+ g = torch.ones(bz, 3, 2, dtype=torch.long).to(device)
1253
+
1254
+ with torch.no_grad():
1255
+ r = model(i, a)
1256
+
1257
+ if type(r) == tuple:
1258
+ print([r[i].shape if type(r[i]) == type(torch.Tensor()) else len(r[i]) for i in range(len(r))])
1259
+ else:
1260
+ print(r.shape)
1261
+
1262
+ test_model()
1263
+
1264
+ # %% [code]
1265
+ def configure_optimizers(network, optim_params, scheduler_params):
1266
+ try:
1267
+ optim_params = copy.copy(optim_params)
1268
+ scheduler_params = copy.copy(scheduler_params)
1269
+
1270
+ optim_name = optim_params.pop('name')
1271
+ scheduler_name = scheduler_params.pop('name')
1272
+
1273
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
1274
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
1275
+
1276
+ if optimizer_cls is None:
1277
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
1278
+
1279
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
1280
+
1281
+ scheduler = None
1282
+ if scheduler_params and scheduler_cls: # Chỉ tạo scheduler nếu có tham số
1283
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
1284
+
1285
+ return optimizer, scheduler
1286
+
1287
+ except KeyError as e:
1288
+ raise ValueError(f"Missing {e} in config!!")
1289
+
1290
+ def freeze(self, model):
1291
+ model.eval()
1292
+ for param in model.parameters():
1293
+ param.requires_grad = False
1294
+
1295
+ def unfreeze(self, model):
1296
+ model.train()
1297
+ for param in model.parameters():
1298
+ param.requires_grad = True
1299
+
1300
+ def reduce_batch_size(loader, ratio=0.5):
1301
+ new_bs = max(1, int(loader.batch_size * ratio))
1302
+
1303
+ shuffle = isinstance(loader.sampler, RandomSampler)
1304
+
1305
+ new_loader = DataLoader(
1306
+ dataset=loader.dataset,
1307
+ batch_size=new_bs,
1308
+ shuffle=shuffle,
1309
+ sampler=None if shuffle else loader.sampler,
1310
+ num_workers=loader.num_workers,
1311
+ collate_fn=loader.collate_fn,
1312
+ pin_memory=loader.pin_memory,
1313
+ drop_last=loader.drop_last,
1314
+ timeout=loader.timeout,
1315
+ worker_init_fn=loader.worker_init_fn,
1316
+ multiprocessing_context=loader.multiprocessing_context,
1317
+ generator=loader.generator,
1318
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
1319
+ persistent_workers=loader.persistent_workers,
1320
+ pin_memory_device=loader.pin_memory_device
1321
+ )
1322
+
1323
+ return new_loader
1324
+
1325
+ def list_to_tuple(x):
1326
+ if isinstance(x, (list, tuple)):
1327
+ return tuple(list_to_tuple(i) for i in x)
1328
+ return x
1329
+
1330
+ def fmt(x):
1331
+ if isinstance(x, float):
1332
+ return round(x, 5)
1333
+ if isinstance(x, dict):
1334
+ return {k: fmt(v) for k, v in x.items()}
1335
+ if isinstance(x, list):
1336
+ return [fmt(v) for v in x]
1337
+ return x
1338
+
1339
+ class ModelEmaV3Proxy(ModelEmaV3):
1340
+ def __getattr__(self, name):
1341
+ try:
1342
+ return super().__getattr__(name)
1343
+ except AttributeError:
1344
+ return getattr(self.module, name)
1345
+
1346
+ def align(spans, gold_spans, gold_labels):
1347
+ """
1348
+ Args:
1349
+ spans: (B, N1, M, 2)
1350
+ gold_spans: (B, N2, 2)
1351
+ gold_labels: (B, N2)
1352
+
1353
+ Returns:
1354
+ labels: (B, N1, M)
1355
+
1356
+ matched -> gold_labels
1357
+ padding (0,0) -> -100
1358
+ unmatched -> 0
1359
+ """
1360
+
1361
+ device = spans.device
1362
+
1363
+ B, N1, M, _ = spans.shape
1364
+ N2 = gold_spans.shape[1]
1365
+
1366
+ # =========================================================
1367
+ # compare spans
1368
+ # =========================================================
1369
+
1370
+ pred = spans.unsqueeze(3)
1371
+ # (B, N1, M, 1, 2)
1372
+
1373
+ gold = gold_spans.unsqueeze(1).unsqueeze(1)
1374
+ # (B, 1, 1, N2, 2)
1375
+
1376
+ matched = (pred == gold).all(dim=-1)
1377
+ # (B, N1, M, N2)
1378
+
1379
+ # =========================================================
1380
+ # default labels = 0
1381
+ # =========================================================
1382
+
1383
+ labels = torch.zeros(
1384
+ (B, N1, M),
1385
+ dtype=gold_labels.dtype,
1386
+ device=device
1387
+ )
1388
+
1389
+ # =========================================================
1390
+ # assign matched labels
1391
+ # =========================================================
1392
+
1393
+ has_match = matched.any(dim=-1)
1394
+ # (B, N1, M)
1395
+
1396
+ matched_idx = matched.float().argmax(dim=-1)
1397
+ # (B, N1, M)
1398
+
1399
+ expanded_gold_labels = gold_labels.unsqueeze(1).unsqueeze(1).expand(
1400
+ B, N1, M, N2
1401
+ )
1402
+
1403
+ gathered_labels = torch.gather(
1404
+ expanded_gold_labels,
1405
+ dim=-1,
1406
+ index=matched_idx.unsqueeze(-1)
1407
+ ).squeeze(-1)
1408
+
1409
+ labels = torch.where(
1410
+ has_match,
1411
+ gathered_labels,
1412
+ labels
1413
+ )
1414
+
1415
+ # =========================================================
1416
+ # padding span -> -100
1417
+ # =========================================================
1418
+
1419
+ is_padding = (spans.sum(dim=-1) == 0)
1420
+
1421
+ labels = torch.where(
1422
+ is_padding,
1423
+ torch.full_like(labels, -100),
1424
+ labels
1425
+ )
1426
+
1427
+ return labels
1428
+
1429
+ def select_best_spans(spans, span_scores):
1430
+ """
1431
+ Args:
1432
+ spans: (B, N, M, 2)
1433
+ span_scores: (B, N, M)
1434
+
1435
+ Returns:
1436
+ best_spans: (B, N, 2)
1437
+ """
1438
+
1439
+ # =========================================================
1440
+ # padding spans -> -inf
1441
+ # =========================================================
1442
+
1443
+ is_padding = (spans.sum(dim=-1) == 0)
1444
+ # (B, N, M)
1445
+
1446
+ masked_scores = span_scores.masked_fill(
1447
+ is_padding,
1448
+ float("-inf")
1449
+ )
1450
+
1451
+ # =========================================================
1452
+ # best index theo chiều M
1453
+ # =========================================================
1454
+
1455
+ best_idx = masked_scores.argmax(dim=-1)
1456
+ # (B, N)
1457
+
1458
+ # =========================================================
1459
+ # gather spans
1460
+ # =========================================================
1461
+
1462
+ best_spans = torch.gather(
1463
+ spans,
1464
+ dim=2,
1465
+ index=best_idx.unsqueeze(-1).unsqueeze(-1).expand(
1466
+ -1, -1, 1, 2
1467
+ )
1468
+ ).squeeze(2)
1469
+
1470
+ # =========================================================
1471
+ # nếu toàn bộ M đều là padding
1472
+ # -> output = (0,0)
1473
+ # =========================================================
1474
+
1475
+ all_padding = is_padding.all(dim=-1)
1476
+ # (B, N)
1477
+
1478
+ best_spans = torch.where(
1479
+ all_padding.unsqueeze(-1),
1480
+ torch.zeros_like(best_spans),
1481
+ best_spans
1482
+ )
1483
+
1484
+ return best_spans
1485
+
1486
+ def extract_entities(input_ids, spans, labels, id2label):
1487
+ """
1488
+ Args:
1489
+ input_ids: (B, L)
1490
+ spans: (B, N, 2)
1491
+ labels: (B, N)
1492
+
1493
+ Returns:
1494
+ List[
1495
+ (
1496
+ bidx,
1497
+ (
1498
+ tuple(token_ids),
1499
+ label_name
1500
+ )
1501
+ )
1502
+ ]
1503
+ """
1504
+
1505
+ B, N, _ = spans.shape
1506
+
1507
+ results = []
1508
+
1509
+ for bidx in range(B):
1510
+
1511
+ for n in range(N):
1512
+
1513
+ lb = labels[bidx, n].item()
1514
+
1515
+ # ignore padding / negative
1516
+ if lb <= 0:
1517
+ continue
1518
+
1519
+ s, e = spans[bidx, n].tolist()
1520
+
1521
+ # skip padding span
1522
+ if s == 0 and e == 0:
1523
+ continue
1524
+
1525
+ token_ids = tuple(
1526
+ input_ids[bidx, s:e + 1].tolist()
1527
+ )
1528
+
1529
+ results.append(
1530
+ (
1531
+ bidx,
1532
+ (
1533
+ token_ids,
1534
+ id2label[lb]
1535
+ )
1536
+ )
1537
+ )
1538
+
1539
+ return results
1540
+
1541
+ class Trainer:
1542
+ def __init__(
1543
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
1544
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
1545
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
1546
+ ):
1547
+ self.ema_net = None
1548
+
1549
+ self.training_time = self._time_str_to_seconds(training_time)
1550
+ self.mode = eval_mode
1551
+ self.topk = topk
1552
+ self.device = device
1553
+ self.logging = logging if logging < epochs else 1
1554
+ self.logging_file = logging_file
1555
+ self.checkpoints_dir = checkpoints_dir
1556
+ self.early_stopping = early_stopping
1557
+ self.eval_from_ratio = eval_from_ratio
1558
+ self.eval_every = eval_every
1559
+ self.save_name = save_name
1560
+ self.save_best = save_best
1561
+ self.save_last = save_last
1562
+ self.return_best = return_best
1563
+ self.return_last = return_last
1564
+ self.max_grad_norm = max_grad_norm
1565
+ self.schedule_in_step = schedule_in_step
1566
+ self.use_ema = use_ema
1567
+ self.ema_from_ratio = ema_from_ratio
1568
+ self.ema_decay = ema_decay
1569
+
1570
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
1571
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
1572
+
1573
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
1574
+ if eval_fn is None:
1575
+ if self.mode == "max":
1576
+ eval_fn = lambda *x: -loss_fn(*x)
1577
+ else:
1578
+ eval_fn = lambda *x: loss_fn(*x)
1579
+
1580
+ if torch.cuda.device_count() > 1:
1581
+ network = DataParallelProxy(network)
1582
+ network = network.to(self.device)
1583
+
1584
+ if not start_training_time:
1585
+ start_training_time = time.time()
1586
+
1587
+ start_ema = int(epochs * self.ema_from_ratio)
1588
+ start_eval = int(epochs * self.eval_from_ratio)
1589
+
1590
+ if val_loader is None:
1591
+ print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
1592
+ else:
1593
+ model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
1594
+ start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
1595
+ print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)
1596
+
1597
+ training_log = {}
1598
+ for epoch in range(start_epoch, epochs+start_epoch):
1599
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
1600
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1601
+
1602
+ try:
1603
+ teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
1604
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1605
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1606
+ logging_dict.update(train_loss_epoch_dict)
1607
+
1608
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
1609
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1610
+
1611
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
1612
+ update = self._update_best_network(eval_net, val_score, epoch)
1613
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
1614
+ logging_dict.update(val_score_dict)
1615
+ if not self.schedule_in_step and scheduler:
1616
+ scheduler.step()
1617
+
1618
+ except RuntimeError as e:
1619
+ if "out of memory" in str(e).lower():
1620
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
1621
+ torch.cuda.empty_cache()
1622
+ gc.collect()
1623
+ if torch.cuda.is_available():
1624
+ torch.cuda.synchronize()
1625
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
1626
+
1627
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
1628
+ if val_loader is not None:
1629
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
1630
+
1631
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
1632
+ else:
1633
+ raise
1634
+
1635
+ training_log[epoch] = logging_dict
1636
+ if self.is_early_stopping(epoch):
1637
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
1638
+ break
1639
+ if self.logging:
1640
+ if epoch % self.logging == 0:
1641
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
1642
+ else:
1643
+ print(f'{epoch}...', end=' ')
1644
+
1645
+ if self._at_time_limit(start_training_time):
1646
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
1647
+ break
1648
+
1649
+ if self.logging_file:
1650
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
1651
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
1652
+ f.write(json.dumps(training_log))
1653
+
1654
+ if self.use_ema and self.ema_net is not None:
1655
+ self._save_state_dict(self.ema_net.module)
1656
+ else:
1657
+ self._save_state_dict(network)
1658
+ print(f'[Trainer CallBack] 📢 Kết thúc training.\n')
1659
+
1660
+ best_model, last_model = None, None
1661
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1662
+ if self.return_best :
1663
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
1664
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
1665
+ if self.return_last:
1666
+ last_model = eval_net.state_dict()
1667
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
1668
+
1669
+ del network
1670
+ torch.cuda.empty_cache()
1671
+ gc.collect()
1672
+ return training_log, best_model, last_model
1673
+
1674
+ def _time_str_to_seconds(self, time_str):
1675
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
1676
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
1677
+
1678
+ def _update_best_network(self, network, val_score, epoch):
1679
+ topk = max(1, self.topk)
1680
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
1681
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
1682
+ if val_score in [x[0] for x in self.best_stage]:
1683
+ return True
1684
+ return False
1685
+
1686
+ def is_early_stopping(self, epoch):
1687
+ if self.best_stage[0][1] is None:
1688
+ return False
1689
+ if not self.early_stopping:
1690
+ return False
1691
+ return epoch - self.best_stage[0][1] >= self.early_stopping
1692
+
1693
+ def _at_time_limit(self, start_training_time):
1694
+ return time.time() - start_training_time >= self.training_time
1695
+
1696
+ def _save_state_dict(self, network):
1697
+ if self.topk <= 0:
1698
+ return
1699
+
1700
+ if self.save_best:
1701
+ for r in range(self.topk):
1702
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
1703
+
1704
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
1705
+ if state_dict is None:
1706
+ continue
1707
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
1708
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
1709
+ if self.save_last:
1710
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
1711
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
1712
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
1713
+
1714
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate):
1715
+ network.train()
1716
+ total_loss = 0
1717
+ total_loss_dict = {}
1718
+ for batch_idx, batch in enumerate(train_loader):
1719
+ optimizer.zero_grad()
1720
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
1721
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate)
1722
+
1723
+ for k, v in loss_dict.items():
1724
+ t = total_loss_dict.get(k, 0)
1725
+ total_loss_dict[k] = t + v
1726
+ self.grad_scaler.scale(loss).backward()
1727
+ self.grad_scaler.unscale_(optimizer)
1728
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
1729
+ # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu...
1730
+ self.grad_scaler.step(optimizer)
1731
+ self.grad_scaler.update()
1732
+ if self.schedule_in_step and scheduler:
1733
+ scheduler.step()
1734
+ if self.use_ema and self.ema_net is not None:
1735
+ self.ema_net.update(network)
1736
+ total_loss += loss
1737
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
1738
+
1739
+ def _eval_epoch(self, network, val_loader, eval_fn, id2label):
1740
+ network.eval()
1741
+ total_score = 0.0
1742
+ total_score_dict = {}
1743
+ object_lists = None # sẽ init sau
1744
+
1745
+ with torch.no_grad():
1746
+ for batch_idx, batch in enumerate(val_loader):
1747
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
1748
+ total_score += score
1749
+
1750
+ for k, v in score_dict.items():
1751
+ t = total_score_dict.get(k, 0)
1752
+ total_score_dict[k] = t + v
1753
+
1754
+ if objects:
1755
+ if object_lists is None:
1756
+ object_lists = [[] for _ in range(len(objects))]
1757
+
1758
+ for i, obj in enumerate(objects):
1759
+ object_lists[i].append(obj.detach())
1760
+
1761
+ if object_lists is not None:
1762
+ object_arrays = [
1763
+ torch.concat(obj_list, dim=0).cpu().numpy()
1764
+ for obj_list in object_lists
1765
+ ]
1766
+ else:
1767
+ object_arrays = []
1768
+
1769
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
1770
+
1771
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate):
1772
+ # Bạn cần override _cal_loss để tính loss
1773
+ input_ids = batch['input_ids'].to(self.device)
1774
+ attention_mask = batch['attention_mask'].to(self.device)
1775
+ gold_spans = batch['gold_spans'].to(self.device)
1776
+ gold_labels = batch['gold_labels'].to(self.device)
1777
+ start_labels = batch['start_labels'].to(self.device)
1778
+ end_labels = batch['end_labels'].to(self.device)
1779
+
1780
+ choice = random.random()
1781
+ if choice < teaching_rate:
1782
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask, gold_spans)
1783
+ else:
1784
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask)
1785
+
1786
+ labels = align(spans, gold_spans, gold_labels)
1787
+
1788
+ loss_dict = loss_fn(
1789
+ start_logits, start_labels,
1790
+ end_logits, end_labels,
1791
+ span_scores, labels
1792
+ )
1793
+ return loss_dict['total'], loss_dict
1794
+
1795
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
1796
+ # Bạn cần override _cal_val_score để tính val score, list bên cạnh là để trả về y hay pred gì đó (nếu cần)
1797
+ input_ids = batch['input_ids'].to(self.device)
1798
+ attention_mask = batch['attention_mask'].to(self.device)
1799
+ gold_entities = batch['gold_entities']
1800
+
1801
+ B, _, _ = input_ids.shape
1802
+
1803
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask)
1804
+ _, labels = extract_spans_and_labels(start_logits, end_logits)
1805
+ spans = select_best_spans(spans, span_scores)
1806
+
1807
+ pred_ids = extract_entities(input_ids.reshape(B, -1), spans, labels, id2label)
1808
+ pred_ids = list_to_tuple(pred_ids)
1809
+
1810
+ gold_ids = list_to_tuple(gold_entities)
1811
+
1812
+ score_dict = eval_fn(pred_ids, gold_ids)
1813
+ return score_dict['f1'], score_dict, []
1814
+
1815
+ # %% [code]
1816
+ class PhoBERTSpanAligner:
1817
+ def __init__(self, tokenizer, max_len):
1818
+ self.tokenizer = tokenizer
1819
+ self.max_len = max_len
1820
+
1821
+ # ===== 1. Extract discontinuous spans =====
1822
+ def extract_spans(self, sample):
1823
+ entity_spans = []
1824
+
1825
+ for event in sample["entities"]:
1826
+ entity_type = event["label"]
1827
+ spans = [tuple(event["offset"])]
1828
+ entity_spans.append({
1829
+ "spans": spans,
1830
+ "label": entity_type
1831
+ })
1832
+
1833
+ return entity_spans
1834
+
1835
+ # ===== 2. Word offsets =====
1836
+ def build_word_offsets(self, text, words):
1837
+ offsets = []
1838
+ pointer = 0
1839
+
1840
+ for word in words:
1841
+ start = text.find(word, pointer)
1842
+ end = start + len(word)
1843
+ offsets.append((start, end))
1844
+ pointer = end
1845
+
1846
+ return offsets
1847
+
1848
+ # ===== 3. Char → word =====
1849
+ def char_span_to_word_span(self, word_offsets, start, end):
1850
+ start_word = None
1851
+ end_word = None
1852
+
1853
+ for i, (w_start, w_end) in enumerate(word_offsets):
1854
+ if w_start <= start < w_end:
1855
+ start_word = i
1856
+ if w_start < end <= w_end:
1857
+ end_word = i
1858
+
1859
+ return start_word, end_word
1860
+
1861
+ # ===== 4. Word → subword =====
1862
+ def word_to_subword_map(self, words):
1863
+ mapping = []
1864
+ subword_index = 1 # <s>
1865
+
1866
+ for word in words:
1867
+ sub_tokens = self.tokenizer.tokenize(word)
1868
+ start = subword_index
1869
+ end = subword_index + len(sub_tokens) - 1
1870
+ mapping.append((start, end))
1871
+ subword_index += len(sub_tokens)
1872
+
1873
+ return mapping
1874
+
1875
+ # ===== 5. Span → subword =====
1876
+ def span_to_subword(self, word_offsets, word_subword_map, spans):
1877
+ sub_spans = []
1878
+
1879
+ for span_start, span_end in spans:
1880
+ w_start, w_end = self.char_span_to_word_span(
1881
+ word_offsets, span_start, span_end
1882
+ )
1883
+ if w_start is None or w_end is None:
1884
+ continue
1885
+
1886
+ sub_start = word_subword_map[w_start][0]
1887
+ sub_end = word_subword_map[w_end][1]
1888
+ sub_spans.append((sub_start, sub_end))
1889
+
1890
+ return sub_spans
1891
+
1892
+ def extract_valid_spans(self, sub_spans):
1893
+ valid_spans = []
1894
+ for s, e in sub_spans:
1895
+ if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
1896
+ continue
1897
+ valid_spans.append((s, e))
1898
+ return valid_spans
1899
+
1900
+ def encode(self, sample):
1901
+ text = sample["text"]
1902
+ entities = self.extract_spans(sample)
1903
+
1904
+ # ===== 1. Word tokenize =====
1905
+ words = word_tokenize(text)
1906
+ sentence = " ".join(words)
1907
+
1908
+ # ===== 2. Mapping =====
1909
+ word_offsets = self.build_word_offsets(text, words)
1910
+ word_subword_map = self.word_to_subword_map(words)
1911
+
1912
+ # ===== 3. Tokenize FULL =====
1913
+ encoding = self.tokenizer(
1914
+ sentence,
1915
+ max_length=self.max_len,
1916
+ truncation=True,
1917
+ padding="max_length",
1918
+ return_tensors="pt"
1919
+ )
1920
+ input_ids = encoding["input_ids"][0]
1921
+ attention_mask = encoding["attention_mask"][0]
1922
+
1923
+ # ===== 5. Convert spans =====
1924
+ entities_gold_spans = []
1925
+
1926
+ for ent in entities:
1927
+ label = ent["label"]
1928
+
1929
+ sub_spans = self.span_to_subword(
1930
+ word_offsets,
1931
+ word_subword_map,
1932
+ ent["spans"]
1933
+ )
1934
+ valid_spans = self.extract_valid_spans(sub_spans)
1935
+ if len(valid_spans) == 0:
1936
+ continue
1937
+ entities_gold_spans.append((tuple(valid_spans), label))
1938
+
1939
+ return {
1940
+ "input_ids": input_ids,
1941
+ "attention_mask": attention_mask,
1942
+ "entities_gold_spans": entities_gold_spans,
1943
+ }
1944
+
1945
+ def generate_spans(attention_mask, max_span_len):
1946
+ seq_len = attention_mask.sum().item() - 2
1947
+ spans = []
1948
+ for i in range(1, seq_len+1):
1949
+ for j in range(i, min(i+max_span_len, seq_len+1)):
1950
+ spans.append((i, j))
1951
+ return spans
1952
+
1953
+ def match_gold_labels(
1954
+ gold_spans, # (N, 2)
1955
+ gold_labels, # (N,)
1956
+ pred_spans, # (M, 2)
1957
+ default_label=-100
1958
+ ):
1959
+ """
1960
+ Return:
1961
+ pred_labels: (M,)
1962
+ """
1963
+
1964
+ pred_labels = torch.full(
1965
+ (pred_spans.size(0),),
1966
+ default_label,
1967
+ dtype=gold_labels.dtype,
1968
+ device=gold_labels.device
1969
+ )
1970
+ if gold_spans.size(0) == 0:
1971
+ return pred_labels
1972
+
1973
+ # (M, N)
1974
+ matched = (pred_spans[:, None, :] == gold_spans[None, :, :]).all(dim=-1)
1975
+ has_match = matched.any(dim=1)
1976
+
1977
+ # lấy index gold đầu tiên match
1978
+ gold_idx = matched.float().argmax(dim=1)
1979
+
1980
+ pred_labels[has_match] = gold_labels[gold_idx[has_match]]
1981
+
1982
+ return pred_labels
1983
+
1984
+ class KLTNDataset(Dataset):
1985
+ def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
1986
+ super().__init__()
1987
+ self.tokenizer = tokenizer
1988
+ self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)
1989
+ self.all_data = all_data
1990
+ self.using_idxes = using_idxes
1991
+ self.label2id = label2id
1992
+ self.max_len = max_len
1993
+ self.max_n_parts = max_n_parts
1994
+
1995
+ def __len__(self):
1996
+ return len(self.using_idxes)
1997
+
1998
+ def __getitem__(self, idx):
1999
+ ridx = self.using_idxes[idx]
2000
+ sample = self.all_data[ridx]
2001
+ result = self.aligner.encode(sample)
2002
+
2003
+ input_ids = result["input_ids"].squeeze(0)
2004
+ attention_mask = result["attention_mask"].squeeze(0)
2005
+ entities_gold_spans = result["entities_gold_spans"]
2006
+
2007
+ all_spans = torch.tensor(generate_spans(attention_mask, 10))
2008
+ gold_spans = torch.tensor([spans[0] for spans, _ in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, 2, dtype=torch.long)
2009
+ gold_labels = torch.tensor([self.label2id[label] for _, label in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, dtype=torch.long)
2010
+ all_labels = match_gold_labels(
2011
+ gold_spans, # (N, 2)
2012
+ gold_labels, # (N,)
2013
+ all_spans, # (M, 2)
2014
+ default_label=0
2015
+ )
2016
+
2017
+ # Get label
2018
+ gold_entities = []
2019
+ start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
2020
+ end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
2021
+ for spans, label in entities_gold_spans:
2022
+ s, e = spans[0]
2023
+
2024
+ start_labels[s] = self.label2id[f'{label}']
2025
+ end_labels[e] = self.label2id[f'{label}']
2026
+
2027
+ gold_entities.append((tuple(input_ids[s:e+1].tolist()), label))
2028
+
2029
+ input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
2030
+ attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)
2031
+
2032
+ n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
2033
+ input_ids = input_ids[:n_valid_parts]
2034
+ attention_mask = attention_mask[:n_valid_parts]
2035
+ start_labels = start_labels[:n_valid_parts*self.max_len]
2036
+ end_labels = end_labels[:n_valid_parts*self.max_len]
2037
+
2038
+ return {
2039
+ "input_ids": input_ids,
2040
+ "attention_mask": attention_mask,
2041
+
2042
+ "all_spans": all_spans,
2043
+ "all_labels": all_labels,
2044
+
2045
+ "gold_spans": gold_spans,
2046
+ "gold_labels": gold_labels,
2047
+
2048
+ "start_labels": start_labels,
2049
+ "end_labels": end_labels,
2050
+
2051
+ "gold_entities": gold_entities,
2052
+ }
2053
+
2054
+ def _pad_batch(tensor_list, pad_value=0):
2055
+ """
2056
+ tensor_list: list of tensors
2057
+ mỗi tensor shape: (Nk, n_parts_i, max_len_i)
2058
+
2059
+ return:
2060
+ padded tensor shape: (B, max_Nk, max_n_parts, max_len)
2061
+ """
2062
+
2063
+ # lấy max toàn batch
2064
+ max_Nk = max(t.size(0) for t in tensor_list)
2065
+ max_n_parts = max(t.size(1) for t in tensor_list)
2066
+ max_len = max(t.size(2) for t in tensor_list)
2067
+
2068
+ padded = []
2069
+
2070
+ for t in tensor_list:
2071
+ Nk, n_parts_i, max_len_i = t.shape
2072
+
2073
+ # pad chiều n_parts và max_len trước
2074
+ if n_parts_i < max_n_parts or max_len_i < max_len:
2075
+ new_t = t.new_full(
2076
+ (Nk, max_n_parts, max_len),
2077
+ pad_value
2078
+ )
2079
+ new_t[:, :n_parts_i, :max_len_i] = t
2080
+ t = new_t
2081
+
2082
+ # pad chiều Nk
2083
+ if Nk < max_Nk:
2084
+ pad_tensor = t.new_full(
2085
+ (max_Nk - Nk, max_n_parts, max_len),
2086
+ pad_value
2087
+ )
2088
+ t = torch.cat([t, pad_tensor], dim=0)
2089
+
2090
+ padded.append(t)
2091
+
2092
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
2093
+
2094
+ def collate_fn(batch):
2095
+ gold_entities = []
2096
+ for bidx, b in enumerate(batch):
2097
+ for entity in b['gold_entities']:
2098
+ gold_entities.append([bidx, entity])
2099
+
2100
+ input_ids = [b["input_ids"].unsqueeze(-1) for b in batch]
2101
+ attention_mask = [b["attention_mask"].unsqueeze(-1) for b in batch]
2102
+
2103
+ all_spans = [b["all_spans"].unsqueeze(-1) for b in batch]
2104
+ all_labels = [b["all_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2105
+
2106
+ gold_spans = [b["gold_spans"].unsqueeze(-1) for b in batch]
2107
+ gold_labels = [b["gold_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2108
+
2109
+ start_labels = [b["start_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2110
+ end_labels = [b["end_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2111
+
2112
+ # pad theo Nk
2113
+ input_ids = _pad_batch(input_ids, pad_value=0).squeeze(-1)
2114
+ attention_mask = _pad_batch(attention_mask, pad_value=0).squeeze(-1)
2115
+
2116
+ all_spans = _pad_batch(all_spans, pad_value=0).squeeze(-1)
2117
+ all_labels = _pad_batch(all_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2118
+
2119
+ gold_spans = _pad_batch(gold_spans, pad_value=0).squeeze(-1)
2120
+ gold_labels = _pad_batch(gold_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2121
+
2122
+ start_labels = _pad_batch(start_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2123
+ end_labels = _pad_batch(end_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2124
+
2125
+ return {
2126
+ "input_ids": input_ids,
2127
+ "attention_mask": attention_mask,
2128
+
2129
+ "all_spans": all_spans,
2130
+ "all_labels": all_labels,
2131
+
2132
+ "gold_spans": gold_spans,
2133
+ "gold_labels": gold_labels,
2134
+
2135
+ "start_labels": start_labels,
2136
+ "end_labels": end_labels,
2137
+
2138
+ "gold_entities": gold_entities,
2139
+ }
2140
+
2141
+ # %% [code]
2142
+ def shift_bidx(spans, batch_idx):
2143
+ shifted = []
2144
+ for bidx, ent in spans:
2145
+ new_bidx = bidx + batch_idx * batch_size
2146
+ shifted.append((new_bidx, ent))
2147
+ return shifted
2148
+
2149
+ def refactor_entities(entities, save_dict):
2150
+ i, c = [], []
2151
+ for bidx, (ids, lb) in entities:
2152
+ if (bidx, ids) not in i:
2153
+ i.append((bidx, ids))
2154
+
2155
+ if (bidx, (ids, lb)) not in c:
2156
+ c.append((bidx, (ids, lb)))
2157
+
2158
+ save_dict['Ent-I'].extend(i)
2159
+ save_dict['Ent-C'].extend(c)
2160
+
2161
+ def check_spans(
2162
+ all_spans, # (N, 2)
2163
+ all_labels, # (N,)
2164
+ ensemble_start_logits, # (L, C)
2165
+ ensemble_end_logits, # (L, C)
2166
+ expand_dict,
2167
+ max_expand=10
2168
+ ):
2169
+ """
2170
+ Check minimum expansion radius required to achieve recall = 1.
2171
+
2172
+ Args:
2173
+ all_spans: gold spans
2174
+ all_labels: gold labels
2175
+ ensemble_start_logits: (L, C)
2176
+ ensemble_end_logits: (L, C)
2177
+ expand_dict:
2178
+ dict like:
2179
+ {
2180
+ 0: count,
2181
+ 1: count,
2182
+ ...
2183
+ }
2184
+
2185
+ Return:
2186
+ matched_expand_k
2187
+ """
2188
+
2189
+ # =========================================================
2190
+ # Gold spans
2191
+ # =========================================================
2192
+
2193
+ gold_set = set()
2194
+
2195
+ valid = all_labels > 0
2196
+
2197
+ valid_idxes = valid.nonzero(as_tuple=False).squeeze(-1)
2198
+
2199
+ for idx in valid_idxes:
2200
+
2201
+ s = int(all_spans[idx, 0])
2202
+ e = int(all_spans[idx, 1])
2203
+
2204
+ gold_set.add((s, e))
2205
+
2206
+ # no gold
2207
+ if len(gold_set) == 0:
2208
+ expand_dict[0] += 1
2209
+ return 0
2210
+
2211
+ # =========================================================
2212
+ # Decode base spans
2213
+ # =========================================================
2214
+
2215
+ start_labels = ensemble_start_logits.argmax(dim=-1) # (L,)
2216
+ end_labels = ensemble_end_logits.argmax(dim=-1) # (L,)
2217
+
2218
+ L = start_labels.shape[0]
2219
+
2220
+ base_pred = []
2221
+
2222
+ used_start = set()
2223
+ used_end = set()
2224
+
2225
+ for s in range(L):
2226
+
2227
+ s_label = start_labels[s].item()
2228
+
2229
+ if s_label == 0:
2230
+ continue
2231
+
2232
+ if s in used_start:
2233
+ continue
2234
+
2235
+ nearest_e = None
2236
+
2237
+ for e in range(s, L):
2238
+
2239
+ if e in used_end:
2240
+ continue
2241
+
2242
+ e_label = end_labels[e].item()
2243
+
2244
+ if e_label == s_label:
2245
+ nearest_e = e
2246
+ break
2247
+
2248
+ if nearest_e is None:
2249
+ continue
2250
+
2251
+ used_start.add(s)
2252
+ used_end.add(nearest_e)
2253
+
2254
+ base_pred.append((s, nearest_e))
2255
+
2256
+ # =========================================================
2257
+ # Try expansion radius
2258
+ # =========================================================
2259
+
2260
+ for k in range(max_expand + 1):
2261
+
2262
+ pred_set = set()
2263
+
2264
+ for s, e in base_pred:
2265
+
2266
+ for ds in range(-k, k + 1):
2267
+ for de in range(-k, k + 1):
2268
+
2269
+ ns = s + ds
2270
+ ne = e + de
2271
+
2272
+ if ns > 0 and ne > 0 and ns <= ne:
2273
+ pred_set.add((ns, ne))
2274
+
2275
+ # recall = 1
2276
+ if gold_set.issubset(pred_set):
2277
+
2278
+ expand_dict[k] += 1
2279
+ return k
2280
+
2281
+ # =========================================================
2282
+ # cannot recover
2283
+ # =========================================================
2284
+
2285
+ expand_dict[-1] = expand_dict.get(-1, 0) + 1
2286
+
2287
+ return -1
2288
+
2289
+ def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
2290
+ if torch.cuda.device_count() > 1:
2291
+ network = DataParallelProxy(network)
2292
+ network = network.to(device)
2293
+ network.eval()
2294
+
2295
+ eval_types = ['Ent-I', 'Ent-C']
2296
+
2297
+ all_pred = {eval_type: [] for eval_type in eval_types}
2298
+ all_gold = {eval_type: [] for eval_type in eval_types}
2299
+
2300
+ list_input_ids = []
2301
+ expand_dict = {i: 0 for i in range(13)}
2302
+
2303
+ with torch.no_grad():
2304
+ for batch_idx, batch in enumerate(test_loader):
2305
+ input_ids = batch['input_ids'].to(device)
2306
+ attention_mask = batch['attention_mask'].to(device)
2307
+ all_spans = batch['all_spans'].to(device)
2308
+ all_labels = batch['all_labels'].to(device)
2309
+ gold_entities = batch['gold_entities']
2310
+
2311
+ B, _, _ = input_ids.shape
2312
+ list_input_ids.extend(input_ids.reshape(B, -1).tolist())
2313
+
2314
+ list_hidden_states = []
2315
+ list_start_logits = []
2316
+ list_end_logits = []
2317
+ list_scores = []
2318
+ for sd in state_dicts:
2319
+ if torch.cuda.device_count() > 1:
2320
+ network.module.load_state_dict(sd)
2321
+ else:
2322
+ network.load_state_dict(sd)
2323
+
2324
+ hidden_states, attention_mask = network.encode(input_ids, attention_mask)
2325
+ start_logits, end_logits = network.get_logits(hidden_states)
2326
+ list_hidden_states.append(hidden_states)
2327
+ list_start_logits.append(start_logits)
2328
+ list_end_logits.append(end_logits)
2329
+
2330
+ ensemble_start_logits = torch.stack(list_start_logits, dim=0).mean(dim=0)
2331
+ ensemble_end_logits = torch.stack(list_end_logits, dim=0).mean(dim=0)
2332
+ spans, pred_labels = extract_spans_and_labels(ensemble_start_logits, ensemble_end_logits)
2333
+ B, N, _ = spans.shape
2334
+ r = torch.ones((B, N), device=spans.device, dtype=torch.long) * network.max_r
2335
+ expanded_spans = expand_spans(spans, r, attention_mask)
2336
+
2337
+ for sd, hidden_states in zip(state_dicts, list_hidden_states):
2338
+ if torch.cuda.device_count() > 1:
2339
+ network.module.load_state_dict(sd)
2340
+ else:
2341
+ network.load_state_dict(sd)
2342
+ expanded_span_reprs = get_span_reprs(hidden_states, expanded_spans)
2343
+ expanded_span_scores = network.get_scores(expanded_span_reprs)
2344
+ list_scores.append(expanded_span_scores)
2345
+
2346
+ ensemble_scores = torch.stack(list_scores, dim=0).mean(dim=0)
2347
+ pred_spans = select_best_spans(expanded_spans, ensemble_scores)
2348
+
2349
+ pred_entities = extract_entities(input_ids.reshape(B, -1), pred_spans, pred_labels, id2label)
2350
+ pred_entities = shift_bidx(pred_entities, batch_idx)
2351
+ refactor_entities(pred_entities, all_pred)
2352
+
2353
+ gold_entities = shift_bidx(gold_entities, batch_idx)
2354
+ refactor_entities(gold_entities, all_gold)
2355
+
2356
+ for b in range(B):
2357
+ check_spans(
2358
+ all_spans[b], # (N, 2)
2359
+ all_labels[b], # (N,)
2360
+ ensemble_start_logits[b], # (L, C)
2361
+ ensemble_end_logits[b], # (L, C)
2362
+ expand_dict,
2363
+ max_expand=10
2364
+ )
2365
+
2366
+ # ===== GLOBAL EVAL =====
2367
+ final_score = {}
2368
+ for eval_type in eval_types:
2369
+ score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))
2370
+ final_score[eval_type] = score
2371
+
2372
+ analyze_result = analyzer.analyze(list_to_tuple(all_pred['Ent-I']), list_to_tuple(all_gold['Ent-I']))
2373
+
2374
+ # ===== PREDICT =====
2375
+ predictions = []
2376
+ for input_ids in list_input_ids:
2377
+ predictions.append([tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
2378
+ for bidx, (ids, lb) in all_pred['Ent-C']:
2379
+ predictions[bidx].append((tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True), lb))
2380
+
2381
+ return final_score, analyze_result, predictions, expand_dict
2382
+
2383
+ # %% [code]
2384
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
2385
+ data_train = json.load(f)
2386
+
2387
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
2388
+ data_test = json.load(f)
2389
+
2390
+ print('Train:', len(data_train))
2391
+ print('Test:', len(data_test))
2392
+
2393
+ # %% [code]
2394
+ entity_types = ['O'] + sorted(list(set([e['label'] for d in data_train + data_test for e in d['entities']])))
2395
+ # bio_entity_type = ['O'] + [f'{prefix}-{ent}' for ent in entity_types for prefix in ['B', 'I']]
2396
+ label2id = {l: i for i, l in enumerate(entity_types)}
2397
+ id2label = {i: l for l, i in label2id.items()}
2398
+
2399
+ # %% [code]
2400
+ zero_entities_idxes = []
2401
+ for idx, d in enumerate(data_train):
2402
+ if len(d['entities']) == 0:
2403
+ zero_entities_idxes.append(idx)
2404
+
2405
+ n_zero_entities_samples = len(zero_entities_idxes)
2406
+ n_has_entities_samples = len(data_train) - n_zero_entities_samples
2407
+
2408
+ random.seed(42)
2409
+ k = min(int(n_has_entities_samples * zero_entities_rate), len(zero_entities_idxes))
2410
+ sampled_zero_entities_idxes = random.sample(zero_entities_idxes, k)
2411
+
2412
+ new_data_train = []
2413
+ for idx, d in enumerate(data_train):
2414
+ if len(d['entities']) == 0:
2415
+ if idx in sampled_zero_entities_idxes:
2416
+ new_data_train.append(d)
2417
+ else:
2418
+ new_data_train.append(d)
2419
+ data_train = new_data_train
2420
+
2421
+ print('Train:', len(data_train))
2422
+
2423
+ # %% [code]
2424
+ if debug_only:
2425
+ data_train = data_train[:10]
2426
+ data_test = data_test[:10]
2427
+
2428
+ print('Train:', len(data_train))
2429
+ print('Test:', len(data_test))
2430
+
2431
+ # %% [code]
2432
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
2433
+
2434
+ # %% [code]
2435
+ print('Experiment name:', state_dict_save_name)
2436
+
2437
+ # %% [code]
2438
+ if not test_only:
2439
+ full_idxes = np.array(range(len(data_train)))
2440
+ training_logs, best_models, last_models = [], [], []
2441
+ start_training_time = time.time()
2442
+ for seed in SEEDS:
2443
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
2444
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
2445
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
2446
+ continue
2447
+ set_seed(seed)
2448
+
2449
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
2450
+
2451
+ trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
2452
+ valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)
2453
+
2454
+ generator = torch.Generator()
2455
+ generator.manual_seed(seed)
2456
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
2457
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2458
+
2459
+ my_model = IEModel(
2460
+ num_labels=len(label2id),
2461
+ **model_params
2462
+ )
2463
+ total_params = sum(p.numel() for p in my_model.parameters())
2464
+ print(f"Total params: {total_params:,}")
2465
+
2466
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
2467
+ encoder_params = set(map(id, my_model.encoder.parameters()))
2468
+ other_params = [
2469
+ p for p in my_model.parameters()
2470
+ if id(p) not in encoder_params
2471
+ ]
2472
+ optimizer = optim.AdamW([
2473
+ {"params": my_model.encoder.parameters(), "lr": 2e-5},
2474
+ {"params": other_params}
2475
+ ], lr=5e-4)
2476
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
2477
+
2478
+ loss_fn = CustomLoss(
2479
+ **loss_func_params
2480
+ )
2481
+ eval_fn = CustomEvalFn(**eval_func_params)
2482
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
2483
+ trainer = Trainer(**trainer_params)
2484
+
2485
+ print(f'Start Training Fold {fold_idx}...')
2486
+ training_log, best_model, last_model = trainer.fit(
2487
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
2488
+ start_epoch=1, start_training_time=start_training_time, id2label=id2label
2489
+ )
2490
+
2491
+ training_logs.append(training_log)
2492
+ best_models.append(best_model)
2493
+ last_models.append(last_model)
2494
+
2495
+ # %% [code]
2496
+ def load_all_state_dicts(folder):
2497
+ files = []
2498
+
2499
+ for file in os.listdir(folder):
2500
+ if file.endswith(".pt") or file.endswith(".pth"):
2501
+ m = re.search(r"f(\d+)", file) # tìm f<số>
2502
+ if m:
2503
+ fold = int(m.group(1))
2504
+ files.append((fold, file))
2505
+
2506
+ # sort theo fold
2507
+ files.sort(key=lambda x: x[0])
2508
+
2509
+ state_dicts = []
2510
+ for fold, file in files:
2511
+ path = os.path.join(folder, file)
2512
+ print(f"Loading fold {fold}: {file}")
2513
+ state_dict = torch.load(path, map_location="cpu")
2514
+ state_dicts.append(state_dict)
2515
+
2516
+ return state_dicts
2517
+
2518
+ if test_only:
2519
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
2520
+ get_ipython().system('rm -rf .cache .gitattributes')
2521
+
2522
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
2523
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
2524
+
2525
+ # %% [code]
2526
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
2527
+ testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
2528
+ generator = torch.Generator()
2529
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2530
+ eval_fn = CustomEvalFn(**eval_func_params)
2531
+ analyzer = SpanErrorAnalyzer()
2532
+ my_model = IEModel(
2533
+ num_labels=len(label2id),
2534
+ **model_params
2535
+ )
2536
+ total_params = sum(p.numel() for p in my_model.parameters())
2537
+ print(f"Total params: {total_params:,}")
2538
+
2539
+ # %% [code]
2540
+ start_time = time.time()
2541
+ result_test = None
2542
+ analyze_result = None
2543
+
2544
+ best_score, best_analyze_result, best_pred_test, expand_dict = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2545
+ last_score, last_analyze_result, last_pred_test, _ = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2546
+
2547
+ result_test = {"Best model": best_score, "Last model": last_score}
2548
+ analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
2549
+ analyze_result_sumary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']}
2550
+ pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}
2551
+
2552
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
2553
+ json.dump(result_test, f, ensure_ascii=False, indent=2)
2554
+
2555
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", "w", encoding="utf-8") as f:
2556
+ json.dump(analyze_result, f, ensure_ascii=False, indent=2)
2557
+
2558
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", "w", encoding="utf-8") as f:
2559
+ json.dump(pred_test, f, ensure_ascii=False, indent=2)
2560
+
2561
+ print('Test:', time.time() - start_time, 's --> Done!')
2562
+ print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))
2563
+
2564
+ # %% [code]
2565
+ expand_dict
2566
+
2567
+ # %% [code]
2568
+ expand_dict_sum = sum(list(expand_dict.values()))
2569
+ {key: value / expand_dict_sum for key, value in expand_dict.items()}
2570
+
2571
+ # %% [code]
2572
+ best_pred_test[:10]
2573
+
2574
+ # %% [code]
2575
+ last_pred_test[:10]
2576
+
2577
+ # %% [code]
2578
+ def dict_to_df(data):
2579
+ row_tuples = []
2580
+ row_values = []
2581
+
2582
+ metrics = ["precision", "recall", "f1"]
2583
+
2584
+ # Lấy model đầu tiên
2585
+ first_model = next(iter(data.values()))
2586
+
2587
+ # eval_keys
2588
+ eval_keys = list(first_model.keys())
2589
+
2590
+ for eval_key in eval_keys:
2591
+ row_tuples.append(eval_key)
2592
+ row = {}
2593
+
2594
+ for model_name, model_data in data.items():
2595
+ for metric in metrics:
2596
+ row[(model_name, metric)] = model_data[eval_key][metric]
2597
+
2598
+ row_values.append(row)
2599
+
2600
+ # ===== DataFrame =====
2601
+ df = pd.DataFrame(row_values)
2602
+
2603
+ # MultiIndex columns
2604
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
2605
+
2606
+ # Index
2607
+ df.index = pd.Index(row_tuples, name="evaluation")
2608
+
2609
+ # ===== Sort =====
2610
+ sort_keys = []
2611
+ if ("Best model", "f1") in df.columns:
2612
+ sort_keys.append(("Best model", "f1"))
2613
+ if ("Last model", "f1") in df.columns:
2614
+ sort_keys.append(("Last model", "f1"))
2615
+
2616
+ if sort_keys:
2617
+ df = df.sort_values(by=sort_keys, ascending=False)
2618
+
2619
+ return df
2620
+
2621
+ result_test_df = dict_to_df(result_test)
2622
+ result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
2623
+ result_test_df
2624
+
2625
+ # %% [code]
2626
+ key = ("Best model", "f1")
2627
+ result_test_df_best = result_test_df.sort_values(by=key, ascending=False).groupby(level="evaluation").head(1)
2628
+ result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
2629
+ result_test_df_best
2630
+
2631
+ # %% [code]
2632
+ def get_avg_best_score(logs):
2633
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
2634
+
2635
+ def get_avg_log(logs, epochs):
2636
+ avg_log = {}
2637
+
2638
+ for epoch in range(1, epochs + 1):
2639
+ val_score = 0.0
2640
+ train_loss = 0.0
2641
+ n_eval = 0
2642
+
2643
+ for idx in range(len(logs)):
2644
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
2645
+ if log is None:
2646
+ continue
2647
+
2648
+ val_score += log.get('val_score', 0.0)
2649
+ train_loss += log.get('train_loss', 0.0)
2650
+ n_eval += 1
2651
+
2652
+ if n_eval == 0:
2653
+ continue
2654
+
2655
+ avg_log[epoch] = {
2656
+ 'train_loss': train_loss / n_eval,
2657
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
2658
+ }
2659
+
2660
+ return avg_log
2661
+
2662
+ def parse_label_key(label: str):
2663
+ try:
2664
+ first = float(label.split('_', 1)[0]) # số đầu: trước dấu _
2665
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
2666
+ return first, last
2667
+ except:
2668
+ return (0, 0)
2669
+
2670
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
2671
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
2672
+
2673
+ # ===== Plot Train Loss =====
2674
+ for name, log in logs_dict.items():
2675
+ epochs = sorted(log.keys())
2676
+ train_loss = [log[e]['train_loss'] for e in epochs]
2677
+ axes[0].plot(epochs, train_loss, label=name)
2678
+
2679
+ axes[0].set_xlabel('Epoch')
2680
+ axes[0].set_ylabel('Train Loss')
2681
+ axes[0].set_title('Training Loss')
2682
+ axes[0].grid(True)
2683
+
2684
+ # ===== Plot Validation Score =====
2685
+ for name, log in logs_dict.items():
2686
+ epochs = sorted(log.keys())
2687
+ val_score = [log[e]['val_score'] for e in epochs]
2688
+ axes[1].plot(epochs, val_score, label=name)
2689
+
2690
+ axes[1].set_xlabel('Epoch')
2691
+ axes[1].set_ylabel('Validation Score')
2692
+ axes[1].set_title('Validation Score')
2693
+ axes[1].grid(True)
2694
+
2695
+ # ===== Shared Legend =====
2696
+ handles, labels = axes[0].get_legend_handles_labels()
2697
+ pairs = list(zip(handles, labels))
2698
+ pairs_sorted = sorted(
2699
+ pairs,
2700
+ key=lambda x: parse_label_key(x[1])
2701
+ )
2702
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
2703
+
2704
+ axes[0].legend(
2705
+ handles_sorted,
2706
+ labels_sorted,
2707
+ loc='center left',
2708
+ bbox_to_anchor=(1.01, 0.5),
2709
+ borderaxespad=0.
2710
+ )
2711
+
2712
+ plt.tight_layout(rect=[0, 0, 1, 1])
2713
+
2714
+ if save_path is not None:
2715
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
2716
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
2717
+
2718
+ plt.show()
2719
+
2720
+ # %% [code]
2721
+ if not test_only:
2722
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*entities*.json"])
2723
+ get_ipython().system('rm -rf .cache .gitattributes')
2724
+
2725
+ # %% [code]
2726
+ if not test_only:
2727
+ experiments = {}
2728
+ for experiment in os.listdir(pretrained_dir):
2729
+ if '.virtual_documents' in experiment:
2730
+ continue
2731
+ experiment_logs = []
2732
+ try:
2733
+ for seed in SEEDS:
2734
+ for fold_idx in range(nfolds):
2735
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
2736
+ experiment_log = json.load(f)
2737
+ experiment_logs.append(experiment_log)
2738
+ except:
2739
+ pass
2740
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
2741
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
2742
+
2743
+ # %% [code]
2744
+ if not test_only:
2745
+ score = get_avg_best_score(training_logs)
2746
+ state_dict_save_name, score
2747
+
2748
+ # %% [code]
2749
+ if not test_only:
2750
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
2751
+
4.2_add_span_rerank_branch_4.5/lasts/4.2_add_span_rerank_branch_4.5_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9366440dbc8ba6e599af7db1fe997834b33d18987933e945063fa8b3b0476d2
3
+ size 554283006
4.2_add_span_rerank_branch_4.5/logs/4.2_add_span_rerank_branch_4.5_log_plot.jpg ADDED

Git LFS Details

  • SHA256: ecf1d538c6371c531581ca10c0f4bd8e187ddb55e0d843add7a38185541986f8
  • Pointer size: 131 Bytes
  • Size of remote file: 558 kB
4.2_add_span_rerank_branch_4.5/logs/4.2_add_span_rerank_branch_4.5_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 0.1468505859375, "total": 0.14686975964225824, "token_margin_loss": 0.05261319172722191, "span_loss": 0.09663219675796535, "start_margin": 0.024280324203465622, "end_margin": 0.02826299608719955}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 0.11334228515625, "total": 0.11333147009502516, "token_margin_loss": 0.04049049748462828, "span_loss": 0.07022079373951928, "start_margin": 0.019179709334823925, "end_margin": 0.02199203465623253}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 0.1005859375, "total": 0.10061486864169927, "token_margin_loss": 0.03818474007825601, "span_loss": 0.06005449972051426, "start_margin": 0.01807923420905534, "end_margin": 0.02041992733370598}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 0.091064453125, "total": 0.09104248183342649, "token_margin_loss": 0.036053661263275576, "span_loss": 0.051669927333705984, "start_margin": 0.01755519843487982, "end_margin": 0.01916224147568474}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 0.08251953125, "total": 0.08251816657350475, "token_margin_loss": 0.035075461151481274, "span_loss": 0.0451369480156512, "start_margin": 0.017162171604248183, "end_margin": 0.018323784237003912}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 0.07568359375, "total": 0.07567076579094466, "token_margin_loss": 0.034219536053661265, "span_loss": 0.03902319731693683, "start_margin": 0.016769144773616546, "end_margin": 0.01759886808272778}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 0.0699462890625, "total": 0.06997624371157071, "token_margin_loss": 0.03364309670206819, "span_loss": 0.03380030743432085, "start_margin": 0.016463457238680826, "end_margin": 0.01731938233650084}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 0.06414794921875, "total": 0.06417691447736165, "token_margin_loss": 0.032752235885969816, "span_loss": 0.03002724986025713, "start_margin": 0.01611410005589715, "end_margin": 0.016900153717160426, "val_score": 0.6863627350530408, "best_score": 0.6863627350530408, "new_best_model": true, "precision": 0.7682875104263381, "recall": 0.6226203064307612, "f1": 0.6863627350530408}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 0.0589599609375, "total": 0.058971492453884854, "token_margin_loss": 0.03236794298490777, "span_loss": 0.025171185019564002, "start_margin": 0.01585208216880939, "end_margin": 0.016550796534376747, "val_score": 0.6859065883350041, "best_score": 0.6863627350530408, "new_best_model": false, "precision": 0.7664244548681001, "recall": 0.6230380104041229, "f1": 0.6859065883350041}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 0.05474853515625, "total": 0.05474427054220235, "token_margin_loss": 0.031424678591391836, "span_loss": 0.021974566797093347, "start_margin": 0.015467789267747344, "end_margin": 0.01621017328116266, "val_score": 0.686874862767166, "best_score": 0.686874862767166, "new_best_model": true, "precision": 0.76764898675689, "recall": 0.6239278347187776, "f1": 0.686874862767166}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 0.050750732421875, "total": 0.05076159865846842, "token_margin_loss": 0.031145192845164895, "span_loss": 0.019267048630519843, "start_margin": 0.01526690888764673, "end_margin": 0.01590448574622694, "val_score": 0.6872180862141289, "best_score": 0.6872180862141289, "new_best_model": true, "precision": 0.766240335643193, "recall": 0.6254058784056634, "f1": 0.6872180862141289}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 0.046966552734375, "total": 0.046953605366126326, "token_margin_loss": 0.030568753493571826, "span_loss": 0.016524594745667972, "start_margin": 0.01500489100055897, "end_margin": 0.015694871436556734, "val_score": 0.6905636463106959, "best_score": 0.6905636463106959, "new_best_model": true, "precision": 0.7693898045943346, "recall": 0.6288075361008783, "f1": 0.6905636463106959}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 0.044830322265625, "total": 0.044822526551145894, "token_margin_loss": 0.030236864169927335, "span_loss": 0.015048560648406931, "start_margin": 0.01477780883174958, "end_margin": 0.015467789267747344, "val_score": 0.6903134720088675, "best_score": 0.6905636463106959, "new_best_model": false, "precision": 0.7691587246053712, "recall": 0.6284828908420117, "f1": 0.6903134720088675}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 0.041961669921875, "total": 0.04195779765231973, "token_margin_loss": 0.029380939072107322, "span_loss": 0.012882546115148127, "start_margin": 0.014358580212409168, "end_margin": 0.014987423141419787, "val_score": 0.6894675601432522, "best_score": 0.6905636463106959, "new_best_model": false, "precision": 0.7691616105349628, "recall": 0.6270536907329344, "f1": 0.6894675601432522}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 0.039398193359375, "total": 0.03940749021799888, "token_margin_loss": 0.029206260480715483, "span_loss": 0.01079513694801565, "start_margin": 0.014210103409726104, "end_margin": 0.014961221352711012, "val_score": 0.6902389874528908, "best_score": 0.6905636463106959, "new_best_model": false, "precision": 0.7696662485785113, "recall": 0.6279449177801002, "f1": 0.6902389874528908}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 0.0382080078125, "total": 0.038219675796534375, "token_margin_loss": 0.02869969256567915, "span_loss": 0.009965413638904417, "start_margin": 0.01392188373392957, "end_margin": 0.014611864169927333, "val_score": 0.6909374894379251, "best_score": 0.6909374894379251, "new_best_model": true, "precision": 0.7695734193089366, "recall": 0.6293325769979943, "f1": 0.6909374894379251}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 0.036468505859375, "total": 0.03647288988261599, "token_margin_loss": 0.028315399664617106, "span_loss": 0.008567984907769704, "start_margin": 0.013703535494689771, "end_margin": 0.014428451648965902, "val_score": 0.6899378747368216, "best_score": 0.6909374894379251, "new_best_model": false, "precision": 0.7691208417645939, "recall": 0.6279288525760165, "f1": 0.6899378747368216}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 0.0355224609375, "total": 0.03551215762996087, "token_margin_loss": 0.028053381777529345, "span_loss": 0.007423840134153158, "start_margin": 0.013546324762437115, "end_margin": 0.014253773057574064, "val_score": 0.6889585967473153, "best_score": 0.6909374894379251, "new_best_model": false, "precision": 0.7676446066334408, "recall": 0.6272630717039508, "f1": 0.6889585967473153}}
4.2_add_span_rerank_branch_4.5/r1s/4.2_add_span_rerank_branch_4.5_s26092004_f0_r1_vs0.69094_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05dcaf378904d9084f13eec451dd3fab35306a10332baa3949c0a25374ad9b5c
3
+ size 554284870
4.2_add_span_rerank_branch_4.5/results/4.2_add_span_rerank_branch_4.5_error_analyze_result.json ADDED
The diff for this file is too large to render. See raw diff
 
4.2_add_span_rerank_branch_4.5/results/4.2_add_span_rerank_branch_4.5_pred_test.json ADDED
The diff for this file is too large to render. See raw diff
 
4.2_add_span_rerank_branch_4.5/results/4.2_add_span_rerank_branch_4.5_test.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Best model": {
3
+ "Ent-I": {
4
+ "precision": 0.7962769051765022,
5
+ "recall": 0.636118598382158,
6
+ "f1": 0.7072439756342697
7
+ },
8
+ "Ent-C": {
9
+ "precision": 0.6529948599594534,
10
+ "recall": 0.5780481010301995,
11
+ "f1": 0.6132400698873992
12
+ }
13
+ },
14
+ "Last model": {
15
+ "Ent-I": {
16
+ "precision": 0.7959656851370321,
17
+ "recall": 0.6381633980847308,
18
+ "f1": 0.7083827652429374
19
+ },
20
+ "Ent-C": {
21
+ "precision": 0.6533472803340445,
22
+ "recall": 0.5799981428168075,
23
+ "f1": 0.6144916079837778
24
+ }
25
+ }
26
+ }
4.2_add_span_rerank_branch_4.5/results/4.2_add_span_rerank_branch_4.5_test_df.xlsx ADDED
Binary file (5.29 kB). View file
 
4.2_add_span_rerank_branch_4.5/results/4.2_add_span_rerank_branch_4.5_test_df_best.xlsx ADDED
Binary file (5.29 kB). View file