SS3M commited on
Commit
0fa87ff
·
verified ·
1 Parent(s): fcca8cb

Upload 4.2_add_span_rerank_branch_phoner_contrastive_4.5's state dict

Browse files
.gitattributes CHANGED
@@ -97,3 +97,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
97
  4.2_2_clone_4.6/logs/4.2_2_clone_4.6_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
98
  4.1_entities_no_ce_phoner_4.2/logs/4.1_entities_no_ce_phoner_4.2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
99
  4.2_add_span_rerank_branch_phoner_4.5/logs/4.2_add_span_rerank_branch_phoner_4.5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
 
 
97
  4.2_2_clone_4.6/logs/4.2_2_clone_4.6_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
98
  4.1_entities_no_ce_phoner_4.2/logs/4.1_entities_no_ce_phoner_4.2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
99
  4.2_add_span_rerank_branch_phoner_4.5/logs/4.2_add_span_rerank_branch_phoner_4.5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
100
+ 4.2_add_span_rerank_branch_phoner_contrastive_4.5/logs/4.2_add_span_rerank_branch_phoner_contrastive_4.5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
4.2_add_span_rerank_branch_phoner_contrastive_4.5/4.2_add_span_rerank_branch_phoner_contrastive_4.5.py ADDED
@@ -0,0 +1,2744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch] pytorch-crf')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+ from torchcrf import CRF
19
+
20
+ from sklearn.metrics import f1_score
21
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
22
+ from scipy.spatial.transform import Rotation as R
23
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
24
+ from sklearn.metrics import precision_recall_fscore_support
25
+ from timm.utils import ModelEmaV3
26
+ import timm
27
+
28
+ import os
29
+ import gc
30
+ import json
31
+ from pathlib import Path
32
+ import pickle
33
+ from tqdm.auto import tqdm
34
+ import copy
35
+ import numpy as np
36
+ import pandas as pd
37
+ import polars as pl
38
+ from PIL import Image
39
+ import time
40
+ from tqdm import tqdm
41
+ from matplotlib import pyplot as plt
42
+ import seaborn as sns
43
+ from multiprocessing import Manager as MemoryManager
44
+ from functools import lru_cache
45
+ import shutil
46
+ import glob
47
+ import cv2
48
+ import random
49
+ import re
50
+ import joblib
51
+ import math
52
+ from huggingface_hub import HfApi, snapshot_download
53
+ import evaluate
54
+ from underthesea import word_tokenize as vi_tokenize_tool
55
+ import spacy
56
+ en_tokenize_tool = spacy.load("en_core_web_sm")
57
+ from collections import defaultdict, Counter
58
+
59
+ # %% [code]
60
+ # Global config
61
+ SEEDS = [26092004]
62
+ topk = 1
63
+ nfolds = 5
64
+ only_fold_idx = 0
65
+ test_only = 0
66
+ debug_only = 0
67
+
68
+ # Config thư mục
69
+ dataset = 'phoner' # conll003, ontonotes, phoner, vietbio, vietmed, vimed, kltn/only_entities, kltn/raw
70
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Thư mục chứa file train, val, test
71
+ train_dir = f'{root_dir}'
72
+ # val_dir = f'{root_dir}/val'
73
+ test_dir = f'{root_dir}'
74
+
75
+ # Config checkpoints
76
+
77
+ # Config training
78
+ epochs = 18 if not debug_only else 2
79
+ batch_size = 32
80
+ device = "cuda" if torch.cuda.is_available() else "cpu"
81
+ # # Thêm biến toàn cục nào đó vào đây
82
+ repo_name = 'SS3M/kltn-experiments'
83
+ state_dict_save_name = "4.2_add_span_rerank_branch_phoner_contrastive_4.5"
84
+ checkpoints_dir = state_dict_save_name
85
+ pretrained_dir = "/kaggle/working"
86
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
87
+
88
+ backbone_model_name = "bert-base-uncased" if dataset in ["conll003", "ontonotes"] else "vinai/phobert-base"
89
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset == dataset in ["conll003", "ontonotes"] else vi_tokenize_tool(text)
90
+ max_len_dict = {
91
+ 'kltn/raw': 256,
92
+ 'kltn/only_entities': 68,
93
+ 'conll003': 46,
94
+ 'ontonotes': 61,
95
+ 'phoner': 68,
96
+ 'vietbio': 125,
97
+ 'vietmed': 36,
98
+ 'vimed': 100,
99
+ }
100
+ zero_entities_rate_dict = {
101
+ 'kltn/raw': 1000,
102
+ 'kltn/only_entities': 0.2,
103
+ 'conll003': 1000, # mean keep all zero-entities samples
104
+ 'ontonotes': 1000,
105
+ 'phoner': 1000,
106
+ 'vietbio': 1000,
107
+ 'vietmed': 1000,
108
+ 'vimed': 1000,
109
+ }
110
+
111
+ max_len = max_len_dict[dataset]
112
+ max_n_parts = 1
113
+ max_span_len = 10
114
+ zero_entities_rate = zero_entities_rate_dict[dataset]
115
+
116
+ # Trainer
117
+ trainer_params = {
118
+ "training_time": "00:11:30:00",
119
+ "eval_mode": "max",
120
+ "topk": topk,
121
+ "save_name": state_dict_save_name,
122
+ "save_best": True,
123
+ "save_last": True,
124
+ "device": device,
125
+ "logging": True,
126
+ "logging_file": True,
127
+ "checkpoints_dir": checkpoints_dir,
128
+ "early_stopping": 30,
129
+ "eval_from_ratio": 0.4,
130
+ "eval_every": 1,
131
+ "schedule_in_step": False,
132
+ "use_ema": True,
133
+ "ema_from_ratio": 0.3,
134
+ "ema_decay": 0.9995,
135
+ "max_grad_norm": 200.0,
136
+ "return_best": True,
137
+ "return_last": True,
138
+ }
139
+
140
+ # Memory
141
+ train_memory_params = {
142
+ 'max_len': max_len,
143
+ 'max_n_parts': max_n_parts,
144
+ }
145
+ val_memory_params = {
146
+ 'max_len': max_len,
147
+ 'max_n_parts': max_n_parts,
148
+ }
149
+
150
+ # Data Loader
151
+ def seed_worker(worker_id):
152
+ worker_seed = torch.initial_seed() % 2**32
153
+ np.random.seed(worker_seed)
154
+ random.seed(worker_seed)
155
+
156
+ train_loader_params = {
157
+ 'batch_size': batch_size,
158
+ 'shuffle': True,
159
+ 'pin_memory':True,
160
+ 'num_workers': 2,
161
+ 'drop_last': False,
162
+ 'worker_init_fn': seed_worker,
163
+ 'persistent_workers': False,
164
+ }
165
+ val_loader_params = {
166
+ 'batch_size': batch_size,
167
+ 'shuffle': False,
168
+ 'pin_memory':True,
169
+ 'num_workers': 1,
170
+ 'drop_last': False,
171
+ 'worker_init_fn': seed_worker,
172
+ 'persistent_workers': False,
173
+ }
174
+
175
+ # Model
176
+ model_params = {
177
+ 'backbone_model_name': backbone_model_name,
178
+ 'max_r': 6,
179
+ }
180
+
181
+ # Loss Func
182
+ loss_func_params = {
183
+ 'lambda_span': 1.0,
184
+ 'lambda_margin': 1.0,
185
+ 'margin': 0.5,
186
+ }
187
+ eval_func_params = {}
188
+
189
+ # Optim
190
+ optim_params = {
191
+ 'name': 'AdamW',
192
+ 'lr': 1e-4,
193
+ 'weight_decay': 1e-4,
194
+ }
195
+ scheduler_params = {
196
+ 'name': 'CosineAnnealingLR',
197
+ 'T_max': 20, # Số epoch để hoàn thành một chu kỳ giảm LR
198
+ 'eta_min': 1e-6 # Learning rate nhỏ nhất trong chu kỳ
199
+ }
200
+
201
+ # %% [code]
202
+ def set_seed(seed=42):
203
+ random.seed(seed)
204
+ np.random.seed(seed)
205
+ torch.manual_seed(seed)
206
+ torch.cuda.manual_seed(seed)
207
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
208
+ torch.use_deterministic_algorithms(False)
209
+ torch.backends.cudnn.deterministic = True
210
+ torch.backends.cudnn.benchmark = False
211
+ os.environ['PYTHONHASHSEED'] = str(seed)
212
+
213
+ # %% [code]
214
+ class CustomLoss(nn.Module):
215
+ def __init__(
216
+ self,
217
+ lambda_margin=1.0,
218
+ lambda_span=1.0,
219
+ lambda_contrastive=1.0,
220
+ margin=1.0,
221
+ temperature=0.07
222
+ ):
223
+ super().__init__()
224
+
225
+ self.lambda_margin = lambda_margin
226
+ self.lambda_span = lambda_span
227
+ self.lambda_contrastive = lambda_contrastive
228
+
229
+ self.margin = margin
230
+ self.temperature = temperature
231
+
232
+ def class_loss(self, logits, labels):
233
+ valid_mask = labels != -100
234
+ logits, labels = logits[valid_mask], labels[valid_mask]
235
+
236
+ if labels.numel() == 0:
237
+ return logits.new_tensor(0.0)
238
+
239
+ pos_logits = logits.gather(1, labels.unsqueeze(-1)).squeeze(-1)
240
+
241
+ neg_logits = logits.clone()
242
+ neg_logits.scatter_(1, labels.unsqueeze(-1), float("-inf"))
243
+
244
+ hardest_neg = neg_logits.max(dim=-1).values
245
+
246
+ loss = F.relu(self.margin - pos_logits + hardest_neg)
247
+
248
+ return loss.mean()
249
+
250
+ def span_loss(self, span_scores, labels):
251
+ """
252
+ span_scores: (B, N, M)
253
+ labels: (B, N, M)
254
+
255
+ labels:
256
+ -100 -> ignore
257
+ 0 -> negative
258
+ >0 -> positive
259
+ """
260
+
261
+ if span_scores.numel() == 0:
262
+ return span_scores.new_tensor(0.0)
263
+
264
+ pos_mask = labels > 0
265
+ neg_mask = labels == 0
266
+
267
+ has_pos = pos_mask.any(dim=-1)
268
+
269
+ best_pos = (
270
+ span_scores
271
+ .masked_fill(~pos_mask, float("-inf"))
272
+ .max(dim=-1)
273
+ .values
274
+ )
275
+
276
+ best_neg = (
277
+ span_scores
278
+ .masked_fill(~neg_mask, float("-inf"))
279
+ .max(dim=-1)
280
+ .values
281
+ )
282
+
283
+ best_neg = torch.where(
284
+ neg_mask.any(dim=-1),
285
+ best_neg,
286
+ torch.zeros_like(best_neg)
287
+ )
288
+
289
+ loss = F.relu(self.margin - best_pos + best_neg)
290
+
291
+ loss = loss[has_pos]
292
+
293
+ if loss.numel() == 0:
294
+ return span_scores.new_tensor(0.0)
295
+
296
+ return loss.mean()
297
+
298
+ def contrastive_loss(self, span_scores, labels):
299
+ """
300
+ span_scores: (B, N, M)
301
+ labels: (B, N, M)
302
+
303
+ labels:
304
+ -100 -> ignore
305
+ 0 -> negative
306
+ >0 -> positive
307
+ """
308
+
309
+ device = span_scores.device
310
+
311
+ pos_mask = labels > 0
312
+ neg_mask = labels == 0
313
+
314
+ losses = []
315
+
316
+ B, N, M = span_scores.shape
317
+
318
+ for b in range(B):
319
+ for n in range(N):
320
+
321
+ pos_scores = span_scores[b, n][pos_mask[b, n]]
322
+ neg_scores = span_scores[b, n][neg_mask[b, n]]
323
+
324
+ # không có positive
325
+ if pos_scores.numel() == 0:
326
+ continue
327
+
328
+ logits = torch.cat(
329
+ [pos_scores, neg_scores],
330
+ dim=0
331
+ )
332
+
333
+ logits = logits / self.temperature
334
+
335
+ # ===== multi-positive labels =====
336
+
337
+ labels_vec = torch.zeros_like(logits)
338
+ labels_vec[:pos_scores.size(0)] = 1.0
339
+
340
+ # ===== log-softmax =====
341
+
342
+ log_prob = logits - torch.logsumexp(
343
+ logits,
344
+ dim=0,
345
+ keepdim=True
346
+ )
347
+
348
+ loss = -(
349
+ labels_vec * log_prob
350
+ ).sum() / labels_vec.sum().clamp(min=1)
351
+
352
+ losses.append(loss)
353
+
354
+ if len(losses) == 0:
355
+ return span_scores.new_tensor(0.0)
356
+
357
+ return torch.stack(losses).mean()
358
+
359
+ def forward(
360
+ self,
361
+ start_logits,
362
+ start_labels,
363
+
364
+ end_logits,
365
+ end_labels,
366
+
367
+ span_scores,
368
+ labels
369
+ ):
370
+
371
+ B, L, C = start_logits.shape
372
+
373
+ # ===== token margin =====
374
+
375
+ start_margin = self.class_loss(
376
+ start_logits.view(-1, C),
377
+ start_labels.view(-1)
378
+ )
379
+
380
+ end_margin = self.class_loss(
381
+ end_logits.view(-1, C),
382
+ end_labels.view(-1)
383
+ )
384
+
385
+ token_margin_loss = start_margin + end_margin
386
+
387
+ # ===== span margin =====
388
+
389
+ span_margin_loss = self.span_loss(
390
+ span_scores,
391
+ labels
392
+ )
393
+
394
+ # ===== contrastive =====
395
+
396
+ contrastive_loss = self.contrastive_loss(
397
+ span_scores,
398
+ labels
399
+ )
400
+
401
+ # ===== total =====
402
+
403
+ total_loss = (
404
+ self.lambda_margin * token_margin_loss +
405
+ self.lambda_span * span_margin_loss +
406
+ self.lambda_contrastive * contrastive_loss
407
+ )
408
+
409
+ return {
410
+ "total": total_loss,
411
+
412
+ "token_margin_loss": token_margin_loss,
413
+ "span_margin_loss": span_margin_loss,
414
+ "contrastive_loss": contrastive_loss,
415
+
416
+ "start_margin": start_margin,
417
+ "end_margin": end_margin,
418
+ }
419
+
420
+ # %% [code]
421
+ ## Viết eval_fn vào đây
422
+
423
+ # Bỏ hết eval_fn và trọng số vào đây
424
+ class CustomEvalFn(nn.Module):
425
+ def __init__(self):
426
+ super().__init__()
427
+
428
+ def compute_f1(self, tp, fp, fn):
429
+ precision = tp / (tp + fp + 1e-8)
430
+ recall = tp / (tp + fn + 1e-8)
431
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
432
+ return precision, recall, f1
433
+
434
+ def forward(self, pred, gold):
435
+ pred_set = set(pred)
436
+ gold_set = set(gold)
437
+
438
+ tp = len(pred_set & gold_set)
439
+ fp = len(pred_set - gold_set)
440
+ fn = len(gold_set - pred_set)
441
+
442
+ precision, recall, f1 = self.compute_f1(tp, fp, fn)
443
+
444
+ return {
445
+ f"precision": precision,
446
+ f"recall": recall,
447
+ f"f1": f1,
448
+ }
449
+
450
+ class SpanErrorAnalyzer:
451
+ def __init__(self, pad_token_id=0):
452
+ self.pad_token_id = pad_token_id
453
+
454
+ # ===== helper =====
455
+ def _to_set(self, data):
456
+ """
457
+ data: list of (b, tuple(ids))
458
+ -> dict[b] = set(tuple(ids))
459
+ """
460
+ res = defaultdict(set)
461
+ for b, ids in data:
462
+ ids = tuple([i for i in ids if i != self.pad_token_id])
463
+ if len(ids) > 0:
464
+ res[b].add(ids)
465
+ return res
466
+
467
+ def _iou(self, a, b):
468
+ """
469
+ a, b: tuple(ids)
470
+ """
471
+ set_a, set_b = set(a), set(b)
472
+ inter = len(set_a & set_b)
473
+ union = len(set_a | set_b)
474
+ if union == 0:
475
+ return 0.0
476
+ return inter / union
477
+
478
+ def _boundary_error(self, pred, gold):
479
+ """
480
+ đo lệch boundary dựa trên overlap prefix/suffix
481
+ """
482
+ # left match
483
+ left = 0
484
+ for i in range(min(len(pred), len(gold))):
485
+ if pred[i] == gold[i]:
486
+ left += 1
487
+ else:
488
+ break
489
+
490
+ # right match
491
+ right = 0
492
+ for i in range(1, min(len(pred), len(gold)) + 1):
493
+ if pred[-i] == gold[-i]:
494
+ right += 1
495
+ else:
496
+ break
497
+
498
+ return {
499
+ "left_match": left,
500
+ "right_match": right,
501
+ "pred_len": len(pred),
502
+ "gold_len": len(gold),
503
+ }
504
+
505
+ # ===== main =====
506
+ def analyze(self, preds, golds):
507
+ pred_map = self._to_set(preds)
508
+ gold_map = self._to_set(golds)
509
+
510
+ all_batches = set(pred_map.keys()) | set(gold_map.keys())
511
+
512
+ stats = Counter()
513
+
514
+ detailed_errors = []
515
+
516
+ for b in all_batches:
517
+ pset = pred_map.get(b, set())
518
+ gset = gold_map.get(b, set())
519
+
520
+ matched_gold = set()
521
+
522
+ # ===== check predictions =====
523
+ for p in pset:
524
+ if p in gset:
525
+ stats["exact_match"] += 1
526
+ matched_gold.add(p)
527
+ else:
528
+ # tìm gold gần nhất
529
+ best_iou = 0
530
+ best_g = None
531
+
532
+ for g in gset:
533
+ iou = self._iou(p, g)
534
+ if iou > best_iou:
535
+ best_iou = iou
536
+ best_g = g
537
+
538
+ if best_iou > 0:
539
+ stats["partial_match"] += 1
540
+
541
+ boundary = self._boundary_error(p, best_g)
542
+
543
+ detailed_errors.append({
544
+ "type": "boundary_error",
545
+ "batch": b,
546
+ "pred": p,
547
+ "gold": best_g,
548
+ "iou": best_iou,
549
+ **boundary
550
+ })
551
+ else:
552
+ if b not in gold_map:
553
+ stats["no_event_sample"] += 1
554
+ err_type = "no_event_sample"
555
+ else:
556
+ stats["completely_wrong"] += 1
557
+ err_type = "completely_wrong"
558
+
559
+ detailed_errors.append({
560
+ "type": err_type,
561
+ "batch": b,
562
+ "pred": p
563
+ })
564
+
565
+ # ===== check missing =====
566
+ for g in gset:
567
+ if g not in matched_gold:
568
+ # check if any pred overlaps
569
+ overlap = any(self._iou(p, g) > 0 for p in pset)
570
+
571
+ if overlap:
572
+ stats["miss_with_overlap"] += 1
573
+ else:
574
+ stats["miss"] += 1
575
+
576
+ detailed_errors.append({
577
+ "type": "miss",
578
+ "batch": b,
579
+ "gold": g
580
+ })
581
+
582
+ return {
583
+ "summary": {
584
+ "exact_match": (stats["exact_match"], stats["exact_match"] / len(preds)),
585
+ "partial_match": (stats["partial_match"], stats["partial_match"] / len(preds)),
586
+ "no_event_sample": (stats["no_event_sample"], stats["no_event_sample"] / len(preds)),
587
+ "completely_wrong": (stats["completely_wrong"], stats["completely_wrong"] / len(preds)),
588
+ "miss": (stats["miss"], stats["miss"] / len(golds)),
589
+ "miss_with_overlap": (stats["miss_with_overlap"], stats["miss_with_overlap"] / len(golds)),
590
+ },
591
+ "details": detailed_errors
592
+ }
593
+
594
+ # %% [code]
595
+ class DataParallelProxy(nn.DataParallel):
596
+
597
+ def __getattr__(self, name):
598
+ try:
599
+ return super().__getattr__(name)
600
+
601
+ except AttributeError:
602
+
603
+ attr = getattr(self.module, name)
604
+
605
+ if callable(attr):
606
+
607
+ def wrapper(*args, **kwargs):
608
+ return self._parallel_apply_method(
609
+ name,
610
+ *args,
611
+ **kwargs
612
+ )
613
+
614
+ return wrapper
615
+
616
+ return attr
617
+
618
+ # =========================================================
619
+ # parallel custom method
620
+ # =========================================================
621
+
622
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
623
+
624
+ if not self.device_ids:
625
+ return getattr(self.module, method_name)(*inputs, **kwargs)
626
+
627
+ inputs_scattered, kwargs_scattered = self.scatter(
628
+ inputs,
629
+ kwargs,
630
+ self.device_ids
631
+ )
632
+
633
+ replicas = self.replicate(
634
+ self.module,
635
+ self.device_ids[:len(inputs_scattered)]
636
+ )
637
+
638
+ outputs = self.parallel_apply(
639
+ [getattr(replica, method_name) for replica in replicas],
640
+ inputs_scattered,
641
+ kwargs_scattered
642
+ )
643
+
644
+ return self._custom_gather(
645
+ outputs,
646
+ self.output_device
647
+ )
648
+
649
+ # =========================================================
650
+ # OVERRIDE FORWARD GATHER
651
+ # =========================================================
652
+
653
+ def gather(self, outputs, output_device):
654
+
655
+ return self._custom_gather(
656
+ outputs,
657
+ output_device
658
+ )
659
+
660
+ # =========================================================
661
+ # recursive gather
662
+ # =========================================================
663
+
664
+ def _custom_gather(self, outputs, output_device):
665
+
666
+ first = outputs[0]
667
+
668
+ # =====================================================
669
+ # tensor
670
+ # =====================================================
671
+
672
+ if torch.is_tensor(first):
673
+
674
+ return self._gather_tensor(
675
+ outputs,
676
+ output_device
677
+ )
678
+
679
+ # =====================================================
680
+ # tuple
681
+ # =====================================================
682
+
683
+ if isinstance(first, tuple):
684
+
685
+ return tuple(
686
+ self._custom_gather(
687
+ list(items),
688
+ output_device
689
+ )
690
+ for items in zip(*outputs)
691
+ )
692
+
693
+ # =====================================================
694
+ # list
695
+ # =====================================================
696
+
697
+ if isinstance(first, list):
698
+
699
+ # list[tensor]
700
+ if len(first) > 0 and torch.is_tensor(first[0]):
701
+
702
+ return self._gather_tensor_list(
703
+ outputs,
704
+ output_device
705
+ )
706
+
707
+ merged = []
708
+
709
+ for out in outputs:
710
+ merged.extend(out)
711
+
712
+ return merged
713
+
714
+ # =====================================================
715
+ # dict
716
+ # =====================================================
717
+
718
+ if isinstance(first, dict):
719
+
720
+ return {
721
+ k: self._custom_gather(
722
+ [o[k] for o in outputs],
723
+ output_device
724
+ )
725
+ for k in first.keys()
726
+ }
727
+
728
+ # =====================================================
729
+ # fallback
730
+ # =====================================================
731
+
732
+ return outputs
733
+
734
+ # =========================================================
735
+ # gather tensor with auto pad
736
+ # =========================================================
737
+
738
+ def _gather_tensor(self, tensors, output_device):
739
+
740
+ # move same device first
741
+ tensors = [
742
+ t.to(output_device)
743
+ for t in tensors
744
+ ]
745
+
746
+ # =====================================================
747
+ # fast path
748
+ # =====================================================
749
+
750
+ try:
751
+ return torch.cat(tensors, dim=0)
752
+
753
+ except RuntimeError:
754
+ pass
755
+
756
+ # =====================================================
757
+ # auto max shape
758
+ # =====================================================
759
+
760
+ max_shape = list(tensors[0].shape)
761
+
762
+ for t in tensors[1:]:
763
+
764
+ for d in range(len(max_shape)):
765
+
766
+ max_shape[d] = max(
767
+ max_shape[d],
768
+ t.shape[d]
769
+ )
770
+
771
+ # =====================================================
772
+ # pad tensors
773
+ # =====================================================
774
+
775
+ padded = []
776
+
777
+ for t in tensors:
778
+
779
+ pad = []
780
+
781
+ # reverse order for F.pad
782
+ for d in reversed(range(len(max_shape))):
783
+
784
+ # never pad batch dim
785
+ if d == 0:
786
+ pad.extend([0, 0])
787
+ continue
788
+
789
+ diff = max_shape[d] - t.shape[d]
790
+
791
+ pad.extend([0, diff])
792
+
793
+ t = F.pad(t, pad)
794
+
795
+ padded.append(t)
796
+
797
+ return torch.cat(padded, dim=0)
798
+
799
+ # =========================================================
800
+ # list[tensor]
801
+ # =========================================================
802
+
803
+ def _gather_tensor_list(self, outputs, output_device):
804
+
805
+ merged = []
806
+
807
+ for out in outputs:
808
+ merged.extend(out)
809
+
810
+ return self._gather_tensor(
811
+ merged,
812
+ output_device
813
+ )
814
+
815
+ # %% [code]
816
+ ## Viết cấu trúc model vào đây
817
+ def extract_spans_and_labels(start_logits, end_logits):
818
+ """
819
+ Args:
820
+ start_logits: Tensor (B, L, C)
821
+ end_logits: Tensor (B, L, C)
822
+
823
+ Returns:
824
+ spans: Tensor (B, N, 2)
825
+ padding = (0, 0)
826
+
827
+ labels: Tensor (B, N)
828
+ padding = -100
829
+
830
+ Nếu không extract được span nào:
831
+ spans.shape = (B, 0, 2)
832
+ labels.shape = (B, 0)
833
+ """
834
+
835
+ start_labels = start_logits.argmax(dim=-1) # (B, L)
836
+ end_labels = end_logits.argmax(dim=-1) # (B, L)
837
+
838
+ B, L = start_labels.shape
839
+
840
+ batch_spans = []
841
+ batch_labels = []
842
+
843
+ max_n = 0
844
+
845
+ for bidx in range(B):
846
+
847
+ used_start = set()
848
+ used_end = set()
849
+
850
+ spans = []
851
+ labels = []
852
+
853
+ for s in range(L):
854
+
855
+ s_label = start_labels[bidx, s].item()
856
+
857
+ # bỏ qua O
858
+ if s_label == 0:
859
+ continue
860
+
861
+ if s in used_start:
862
+ continue
863
+
864
+ nearest_e = None
865
+
866
+ # tìm end gần nhất cùng class
867
+ for e in range(s, L):
868
+
869
+ if e in used_end:
870
+ continue
871
+
872
+ e_label = end_labels[bidx, e].item()
873
+
874
+ if e_label == s_label:
875
+ nearest_e = e
876
+ break
877
+
878
+ if nearest_e is None:
879
+ continue
880
+
881
+ used_start.add(s)
882
+ used_end.add(nearest_e)
883
+
884
+ spans.append([s, nearest_e])
885
+ labels.append(s_label)
886
+
887
+ batch_spans.append(spans)
888
+ batch_labels.append(labels)
889
+
890
+ max_n = max(max_n, len(spans))
891
+
892
+ # =========================================================
893
+ # không extract được gì
894
+ # =========================================================
895
+
896
+ if max_n == 0:
897
+
898
+ spans = torch.empty(
899
+ (B, 0, 2),
900
+ dtype=torch.long,
901
+ device=start_logits.device
902
+ )
903
+
904
+ labels = torch.empty(
905
+ (B, 0),
906
+ dtype=torch.long,
907
+ device=start_logits.device
908
+ )
909
+
910
+ return spans, labels
911
+
912
+ # =========================================================
913
+ # padding
914
+ # =========================================================
915
+
916
+ padded_spans = []
917
+ padded_labels = []
918
+
919
+ for spans, labels in zip(batch_spans, batch_labels):
920
+
921
+ pad_n = max_n - len(spans)
922
+
923
+ spans = spans + [[0, 0]] * pad_n
924
+ labels = labels + [-100] * pad_n
925
+
926
+ padded_spans.append(spans)
927
+ padded_labels.append(labels)
928
+
929
+ spans = torch.tensor(
930
+ padded_spans,
931
+ dtype=torch.long,
932
+ device=start_logits.device
933
+ ) # (B, N, 2)
934
+
935
+ labels = torch.tensor(
936
+ padded_labels,
937
+ dtype=torch.long,
938
+ device=start_logits.device
939
+ ) # (B, N)
940
+
941
+ return spans, labels
942
+
943
+ def expand_spans(spans, r, attention_mask):
944
+ """
945
+ Args:
946
+ spans: Tensor (B, N, 2)
947
+ r: Tensor (B, N)
948
+ attention_mask: Tensor (B, L)
949
+
950
+ Returns:
951
+ expanded_spans: Tensor (B, N, M, 2)
952
+
953
+ Với:
954
+ M = (2*max_r + 1)^2
955
+
956
+ padding = (0, 0)
957
+
958
+ Nếu N = 0:
959
+ return shape (B, 0, 0, 2)
960
+ """
961
+
962
+ device = spans.device
963
+
964
+ B, N, _ = spans.shape
965
+
966
+ # =========================================================
967
+ # empty case
968
+ # =========================================================
969
+
970
+ if N == 0:
971
+ return torch.empty(
972
+ (B, 0, 0, 2),
973
+ dtype=torch.long,
974
+ device=device
975
+ )
976
+
977
+ # =========================================================
978
+ # max radius
979
+ # =========================================================
980
+
981
+ max_r = r.max().item()
982
+
983
+ # =========================================================
984
+ # create all (ds, de)
985
+ # =========================================================
986
+
987
+ shifts = torch.arange(
988
+ -max_r,
989
+ max_r + 1,
990
+ device=device
991
+ )
992
+
993
+ ds_grid, de_grid = torch.meshgrid(
994
+ shifts,
995
+ shifts,
996
+ indexing='ij'
997
+ )
998
+
999
+ ds = ds_grid.reshape(-1) # (M,)
1000
+ de = de_grid.reshape(-1) # (M,)
1001
+
1002
+ M = ds.size(0)
1003
+
1004
+ # =========================================================
1005
+ # original spans
1006
+ # =========================================================
1007
+
1008
+ spans = spans.unsqueeze(2) # (B, N, 1, 2)
1009
+
1010
+ start = spans[..., 0] # (B, N, 1)
1011
+ end = spans[..., 1] # (B, N, 1)
1012
+
1013
+ # =========================================================
1014
+ # expand independently
1015
+ # =========================================================
1016
+
1017
+ expanded_s = start + ds.view(1, 1, M)
1018
+ expanded_e = end + de.view(1, 1, M)
1019
+
1020
+ expanded = torch.stack(
1021
+ [expanded_s, expanded_e],
1022
+ dim=-1
1023
+ ) # (B, N, M, 2)
1024
+
1025
+ # =========================================================
1026
+ # valid shift range
1027
+ # =========================================================
1028
+
1029
+ valid_shift = (
1030
+ (ds.abs().view(1, 1, M) <= r.unsqueeze(-1))
1031
+ &
1032
+ (de.abs().view(1, 1, M) <= r.unsqueeze(-1))
1033
+ )
1034
+
1035
+ # =========================================================
1036
+ # non-padding spans
1037
+ # =========================================================
1038
+
1039
+ non_pad_span = (
1040
+ spans.squeeze(2).sum(dim=-1) != 0
1041
+ ) # (B, N)
1042
+
1043
+ # =========================================================
1044
+ # boundary check
1045
+ # =========================================================
1046
+
1047
+ valid_length = attention_mask.sum(dim=-1)
1048
+ valid_length = valid_length.view(B, 1, 1)
1049
+
1050
+ valid_boundary = (
1051
+ (expanded_s > 0)
1052
+ &
1053
+ (expanded_e < valid_length)
1054
+ &
1055
+ (expanded_s <= expanded_e)
1056
+ )
1057
+
1058
+ # =========================================================
1059
+ # final valid mask
1060
+ # =========================================================
1061
+
1062
+ valid_mask = (
1063
+ valid_shift
1064
+ &
1065
+ non_pad_span.unsqueeze(-1)
1066
+ &
1067
+ valid_boundary
1068
+ )
1069
+
1070
+ # =========================================================
1071
+ # flatten valid spans
1072
+ # =========================================================
1073
+
1074
+ expanded_flat = expanded.reshape(B, N, M, 2)
1075
+ valid_flat = valid_mask.reshape(B, N, M)
1076
+
1077
+ # =========================================================
1078
+ # số valid max toàn batch
1079
+ # =========================================================
1080
+
1081
+ valid_counts = valid_flat.sum(dim=-1) # (B, N)
1082
+
1083
+ M_max = valid_counts.max().item()
1084
+
1085
+ # =========================================================
1086
+ # allocate output
1087
+ # =========================================================
1088
+
1089
+ filtered_expanded = torch.zeros(
1090
+ (B, N, M_max, 2),
1091
+ dtype=expanded.dtype,
1092
+ device=device
1093
+ )
1094
+
1095
+ # =========================================================
1096
+ # gather only valid spans
1097
+ # =========================================================
1098
+
1099
+ for b in range(B):
1100
+ for n in range(N):
1101
+
1102
+ cur_valid = valid_flat[b, n] # (M,)
1103
+
1104
+ valid_spans = expanded_flat[b, n][cur_valid]
1105
+
1106
+ k = valid_spans.size(0)
1107
+
1108
+ if k > 0:
1109
+ filtered_expanded[b, n, :k] = valid_spans
1110
+
1111
+ return filtered_expanded
1112
+
1113
+ def get_span_reprs(hidden, spans):
1114
+ """
1115
+ Args:
1116
+ hidden: (B, L, H)
1117
+ spans: (B, N, M, 2)
1118
+
1119
+ Return:
1120
+ span_repr: (B, N, M, 4H)
1121
+
1122
+ Nếu N = 0:
1123
+ return shape (B, 0, 0, 4H)
1124
+ """
1125
+
1126
+ B, N, M, _ = spans.shape
1127
+ H = hidden.size(-1)
1128
+
1129
+ # =========================================================
1130
+ # empty case
1131
+ # =========================================================
1132
+
1133
+ if N == 0:
1134
+ return torch.empty(
1135
+ (B, 0, 0, 4 * H),
1136
+ dtype=hidden.dtype,
1137
+ device=hidden.device
1138
+ )
1139
+
1140
+ batch_idx = torch.arange(
1141
+ B,
1142
+ device=hidden.device
1143
+ ).view(B, 1, 1)
1144
+
1145
+ start_idx = spans[..., 0] # (B, N, M)
1146
+ end_idx = spans[..., 1] # (B, N, M)
1147
+
1148
+ # =========================================================
1149
+ # gather hidden
1150
+ # =========================================================
1151
+
1152
+ start_h = hidden[batch_idx, start_idx]
1153
+ end_h = hidden[batch_idx, end_idx]
1154
+
1155
+ # (B, N, M, H)
1156
+
1157
+ # =========================================================
1158
+ # span representation
1159
+ # =========================================================
1160
+
1161
+ span_repr = torch.cat(
1162
+ [
1163
+ start_h,
1164
+ end_h,
1165
+ end_h - start_h,
1166
+ end_h * start_h
1167
+ ],
1168
+ dim=-1
1169
+ ) # (B, N, M, 4H)
1170
+
1171
+ return span_repr
1172
+
1173
+ class MLP(nn.Module):
1174
+ def __init__(self, in_size, hid_size, out_size):
1175
+ super().__init__()
1176
+ self.mlp = nn.Sequential(
1177
+ nn.Linear(in_size, hid_size),
1178
+ nn.ReLU(),
1179
+ nn.Linear(hid_size, out_size)
1180
+ )
1181
+
1182
+ def forward(self, x):
1183
+ return self.mlp(x)
1184
+
1185
+ class IEModel(nn.Module):
1186
+ def __init__(self, backbone_model_name, num_labels, max_r):
1187
+ super().__init__()
1188
+ self.max_r = max_r
1189
+
1190
+ self.encoder = AutoModel.from_pretrained(backbone_model_name)
1191
+ hidden_size = self.encoder.config.hidden_size
1192
+
1193
+ self.start_classifier = MLP(hidden_size, hidden_size, num_labels)
1194
+ self.end_classifier = MLP(hidden_size, hidden_size, num_labels)
1195
+
1196
+ self.span_scorer = MLP(4*hidden_size, hidden_size, 1)
1197
+
1198
+ def encode(self, input_ids, attention_mask):
1199
+ B, n_parts, L = input_ids.shape
1200
+ input_ids = input_ids.view(-1, L)
1201
+ attention_mask = attention_mask.view(-1, L)
1202
+
1203
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
1204
+ hidden_states = outputs.last_hidden_state # B * n_parts, L, H
1205
+
1206
+ hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1) # B, L, H
1207
+ attention_mask = attention_mask.view(B, n_parts, L).reshape(B, n_parts*L) # B, L
1208
+ return hidden_states, attention_mask
1209
+
1210
+ def get_logits(self, hidden_states):
1211
+ start_logits = self.start_classifier(hidden_states) # B, N, classes
1212
+ end_logits = self.end_classifier(hidden_states) # B, N, classes
1213
+ return start_logits, end_logits
1214
+
1215
+ def get_scores(self, expanded_span_reprs):
1216
+ return self.span_scorer(expanded_span_reprs).squeeze(-1)
1217
+
1218
+ def forward(self, input_ids, attention_mask, spans=None):
1219
+ hidden_states, attention_mask = self.encode(input_ids, attention_mask)
1220
+ start_logits, end_logits = self.get_logits(hidden_states)
1221
+
1222
+ if spans is None:
1223
+ spans, _ = extract_spans_and_labels(start_logits, end_logits) # (B, N, 2)
1224
+ B, N, _ = spans.shape
1225
+ r = torch.randint(1, self.max_r+1, (B, N), device=spans.device)
1226
+ expanded_spans = expand_spans(spans, r, attention_mask) # (B, N, M, 2)
1227
+
1228
+ expanded_span_reprs = get_span_reprs(hidden_states, expanded_spans) # (B, N, M, 4H)
1229
+ expanded_span_scores = self.get_scores(expanded_span_reprs) # (B, N, M)
1230
+
1231
+ return start_logits, end_logits, expanded_spans, expanded_span_scores
1232
+
1233
+ def test_model():
1234
+ model = DataParallelProxy(IEModel(backbone_model_name, 7, 6)).to(device)
1235
+ model.eval()
1236
+ total_params = sum(p.numel() for p in model.parameters())
1237
+ print(f"Total params: {total_params:,}")
1238
+
1239
+ vocab_size = model.module.encoder.config.vocab_size
1240
+ max_len = model.module.encoder.config.max_position_embeddings
1241
+
1242
+ bz = 32
1243
+ i = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
1244
+ a = torch.ones(bz, 5, 10).to(device)
1245
+ g = torch.ones(bz, 3, 2, dtype=torch.long).to(device)
1246
+
1247
+ with torch.no_grad():
1248
+ r = model(i, a)
1249
+
1250
+ if type(r) == tuple:
1251
+ print([r[i].shape if type(r[i]) == type(torch.Tensor()) else len(r[i]) for i in range(len(r))])
1252
+ else:
1253
+ print(r.shape)
1254
+
1255
+ test_model()
1256
+
1257
+ # %% [code]
1258
+ def configure_optimizers(network, optim_params, scheduler_params):
1259
+ try:
1260
+ optim_params = copy.copy(optim_params)
1261
+ scheduler_params = copy.copy(scheduler_params)
1262
+
1263
+ optim_name = optim_params.pop('name')
1264
+ scheduler_name = scheduler_params.pop('name')
1265
+
1266
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
1267
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
1268
+
1269
+ if optimizer_cls is None:
1270
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
1271
+
1272
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
1273
+
1274
+ scheduler = None
1275
+ if scheduler_params and scheduler_cls: # Chỉ tạo scheduler nếu có tham số
1276
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
1277
+
1278
+ return optimizer, scheduler
1279
+
1280
+ except KeyError as e:
1281
+ raise ValueError(f"Missing {e} in config!!")
1282
+
1283
+ def freeze(self, model):
1284
+ model.eval()
1285
+ for param in model.parameters():
1286
+ param.requires_grad = False
1287
+
1288
+ def unfreeze(self, model):
1289
+ model.train()
1290
+ for param in model.parameters():
1291
+ param.requires_grad = True
1292
+
1293
+ def reduce_batch_size(loader, ratio=0.5):
1294
+ new_bs = max(1, int(loader.batch_size * ratio))
1295
+
1296
+ shuffle = isinstance(loader.sampler, RandomSampler)
1297
+
1298
+ new_loader = DataLoader(
1299
+ dataset=loader.dataset,
1300
+ batch_size=new_bs,
1301
+ shuffle=shuffle,
1302
+ sampler=None if shuffle else loader.sampler,
1303
+ num_workers=loader.num_workers,
1304
+ collate_fn=loader.collate_fn,
1305
+ pin_memory=loader.pin_memory,
1306
+ drop_last=loader.drop_last,
1307
+ timeout=loader.timeout,
1308
+ worker_init_fn=loader.worker_init_fn,
1309
+ multiprocessing_context=loader.multiprocessing_context,
1310
+ generator=loader.generator,
1311
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
1312
+ persistent_workers=loader.persistent_workers,
1313
+ pin_memory_device=loader.pin_memory_device
1314
+ )
1315
+
1316
+ return new_loader
1317
+
1318
+ def list_to_tuple(x):
1319
+ if isinstance(x, (list, tuple)):
1320
+ return tuple(list_to_tuple(i) for i in x)
1321
+ return x
1322
+
1323
+ def fmt(x):
1324
+ if isinstance(x, float):
1325
+ return round(x, 5)
1326
+ if isinstance(x, dict):
1327
+ return {k: fmt(v) for k, v in x.items()}
1328
+ if isinstance(x, list):
1329
+ return [fmt(v) for v in x]
1330
+ return x
1331
+
1332
+ class ModelEmaV3Proxy(ModelEmaV3):
1333
+ def __getattr__(self, name):
1334
+ try:
1335
+ return super().__getattr__(name)
1336
+ except AttributeError:
1337
+ return getattr(self.module, name)
1338
+
1339
+ def align(spans, gold_spans, gold_labels):
1340
+ """
1341
+ Args:
1342
+ spans: (B, N1, M, 2)
1343
+ gold_spans: (B, N2, 2)
1344
+ gold_labels: (B, N2)
1345
+
1346
+ Returns:
1347
+ labels: (B, N1, M)
1348
+
1349
+ matched -> gold_labels
1350
+ padding (0,0) -> -100
1351
+ unmatched -> 0
1352
+ """
1353
+
1354
+ device = spans.device
1355
+
1356
+ B, N1, M, _ = spans.shape
1357
+ N2 = gold_spans.shape[1]
1358
+
1359
+ # =========================================================
1360
+ # compare spans
1361
+ # =========================================================
1362
+
1363
+ pred = spans.unsqueeze(3)
1364
+ # (B, N1, M, 1, 2)
1365
+
1366
+ gold = gold_spans.unsqueeze(1).unsqueeze(1)
1367
+ # (B, 1, 1, N2, 2)
1368
+
1369
+ matched = (pred == gold).all(dim=-1)
1370
+ # (B, N1, M, N2)
1371
+
1372
+ # =========================================================
1373
+ # default labels = 0
1374
+ # =========================================================
1375
+
1376
+ labels = torch.zeros(
1377
+ (B, N1, M),
1378
+ dtype=gold_labels.dtype,
1379
+ device=device
1380
+ )
1381
+
1382
+ # =========================================================
1383
+ # assign matched labels
1384
+ # =========================================================
1385
+
1386
+ has_match = matched.any(dim=-1)
1387
+ # (B, N1, M)
1388
+
1389
+ matched_idx = matched.float().argmax(dim=-1)
1390
+ # (B, N1, M)
1391
+
1392
+ expanded_gold_labels = gold_labels.unsqueeze(1).unsqueeze(1).expand(
1393
+ B, N1, M, N2
1394
+ )
1395
+
1396
+ gathered_labels = torch.gather(
1397
+ expanded_gold_labels,
1398
+ dim=-1,
1399
+ index=matched_idx.unsqueeze(-1)
1400
+ ).squeeze(-1)
1401
+
1402
+ labels = torch.where(
1403
+ has_match,
1404
+ gathered_labels,
1405
+ labels
1406
+ )
1407
+
1408
+ # =========================================================
1409
+ # padding span -> -100
1410
+ # =========================================================
1411
+
1412
+ is_padding = (spans.sum(dim=-1) == 0)
1413
+
1414
+ labels = torch.where(
1415
+ is_padding,
1416
+ torch.full_like(labels, -100),
1417
+ labels
1418
+ )
1419
+
1420
+ return labels
1421
+
1422
+ def select_best_spans(spans, span_scores):
1423
+ """
1424
+ Args:
1425
+ spans: (B, N, M, 2)
1426
+ span_scores: (B, N, M)
1427
+
1428
+ Returns:
1429
+ best_spans: (B, N, 2)
1430
+ """
1431
+
1432
+ # =========================================================
1433
+ # padding spans -> -inf
1434
+ # =========================================================
1435
+
1436
+ is_padding = (spans.sum(dim=-1) == 0)
1437
+ # (B, N, M)
1438
+
1439
+ masked_scores = span_scores.masked_fill(
1440
+ is_padding,
1441
+ float("-inf")
1442
+ )
1443
+
1444
+ # =========================================================
1445
+ # best index theo chiều M
1446
+ # =========================================================
1447
+
1448
+ best_idx = masked_scores.argmax(dim=-1)
1449
+ # (B, N)
1450
+
1451
+ # =========================================================
1452
+ # gather spans
1453
+ # =========================================================
1454
+
1455
+ best_spans = torch.gather(
1456
+ spans,
1457
+ dim=2,
1458
+ index=best_idx.unsqueeze(-1).unsqueeze(-1).expand(
1459
+ -1, -1, 1, 2
1460
+ )
1461
+ ).squeeze(2)
1462
+
1463
+ # =========================================================
1464
+ # nếu toàn bộ M đều là padding
1465
+ # -> output = (0,0)
1466
+ # =========================================================
1467
+
1468
+ all_padding = is_padding.all(dim=-1)
1469
+ # (B, N)
1470
+
1471
+ best_spans = torch.where(
1472
+ all_padding.unsqueeze(-1),
1473
+ torch.zeros_like(best_spans),
1474
+ best_spans
1475
+ )
1476
+
1477
+ return best_spans
1478
+
1479
+ def extract_entities(input_ids, spans, labels, id2label):
1480
+ """
1481
+ Args:
1482
+ input_ids: (B, L)
1483
+ spans: (B, N, 2)
1484
+ labels: (B, N)
1485
+
1486
+ Returns:
1487
+ List[
1488
+ (
1489
+ bidx,
1490
+ (
1491
+ tuple(token_ids),
1492
+ label_name
1493
+ )
1494
+ )
1495
+ ]
1496
+ """
1497
+
1498
+ B, N, _ = spans.shape
1499
+
1500
+ results = []
1501
+
1502
+ for bidx in range(B):
1503
+
1504
+ for n in range(N):
1505
+
1506
+ lb = labels[bidx, n].item()
1507
+
1508
+ # ignore padding / negative
1509
+ if lb <= 0:
1510
+ continue
1511
+
1512
+ s, e = spans[bidx, n].tolist()
1513
+
1514
+ # skip padding span
1515
+ if s == 0 and e == 0:
1516
+ continue
1517
+
1518
+ token_ids = tuple(
1519
+ input_ids[bidx, s:e + 1].tolist()
1520
+ )
1521
+
1522
+ results.append(
1523
+ (
1524
+ bidx,
1525
+ (
1526
+ token_ids,
1527
+ id2label[lb]
1528
+ )
1529
+ )
1530
+ )
1531
+
1532
+ return results
1533
+
1534
+ class Trainer:
1535
+ def __init__(
1536
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
1537
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
1538
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
1539
+ ):
1540
+ self.ema_net = None
1541
+
1542
+ self.training_time = self._time_str_to_seconds(training_time)
1543
+ self.mode = eval_mode
1544
+ self.topk = topk
1545
+ self.device = device
1546
+ self.logging = logging if logging < epochs else 1
1547
+ self.logging_file = logging_file
1548
+ self.checkpoints_dir = checkpoints_dir
1549
+ self.early_stopping = early_stopping
1550
+ self.eval_from_ratio = eval_from_ratio
1551
+ self.eval_every = eval_every
1552
+ self.save_name = save_name
1553
+ self.save_best = save_best
1554
+ self.save_last = save_last
1555
+ self.return_best = return_best
1556
+ self.return_last = return_last
1557
+ self.max_grad_norm = max_grad_norm
1558
+ self.schedule_in_step = schedule_in_step
1559
+ self.use_ema = use_ema
1560
+ self.ema_from_ratio = ema_from_ratio
1561
+ self.ema_decay = ema_decay
1562
+
1563
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
1564
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
1565
+
1566
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
1567
+ if eval_fn is None:
1568
+ if self.mode == "max":
1569
+ eval_fn = lambda *x: -loss_fn(*x)
1570
+ else:
1571
+ eval_fn = lambda *x: loss_fn(*x)
1572
+
1573
+ if torch.cuda.device_count() > 1:
1574
+ network = DataParallelProxy(network)
1575
+ network = network.to(self.device)
1576
+
1577
+ if not start_training_time:
1578
+ start_training_time = time.time()
1579
+
1580
+ start_ema = int(epochs * self.ema_from_ratio)
1581
+ start_eval = int(epochs * self.eval_from_ratio)
1582
+
1583
+ if val_loader is None:
1584
+ print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
1585
+ else:
1586
+ model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
1587
+ start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
1588
+ print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)
1589
+
1590
+ training_log = {}
1591
+ for epoch in range(start_epoch, epochs+start_epoch):
1592
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
1593
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1594
+
1595
+ try:
1596
+ teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
1597
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1598
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1599
+ logging_dict.update(train_loss_epoch_dict)
1600
+
1601
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
1602
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1603
+
1604
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
1605
+ update = self._update_best_network(eval_net, val_score, epoch)
1606
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
1607
+ logging_dict.update(val_score_dict)
1608
+ if not self.schedule_in_step and scheduler:
1609
+ scheduler.step()
1610
+
1611
+ except RuntimeError as e:
1612
+ if "out of memory" in str(e).lower():
1613
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
1614
+ torch.cuda.empty_cache()
1615
+ gc.collect()
1616
+ if torch.cuda.is_available():
1617
+ torch.cuda.synchronize()
1618
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
1619
+
1620
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
1621
+ if val_loader is not None:
1622
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
1623
+
1624
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
1625
+ else:
1626
+ raise
1627
+
1628
+ training_log[epoch] = logging_dict
1629
+ if self.is_early_stopping(epoch):
1630
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
1631
+ break
1632
+ if self.logging:
1633
+ if epoch % self.logging == 0:
1634
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
1635
+ else:
1636
+ print(f'{epoch}...', end=' ')
1637
+
1638
+ if self._at_time_limit(start_training_time):
1639
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
1640
+ break
1641
+
1642
+ if self.logging_file:
1643
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
1644
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
1645
+ f.write(json.dumps(training_log))
1646
+
1647
+ if self.use_ema and self.ema_net is not None:
1648
+ self._save_state_dict(self.ema_net.module)
1649
+ else:
1650
+ self._save_state_dict(network)
1651
+ print(f'[Trainer CallBack] 📢 Kết thúc training.\n')
1652
+
1653
+ best_model, last_model = None, None
1654
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1655
+ if self.return_best :
1656
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
1657
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
1658
+ if self.return_last:
1659
+ last_model = eval_net.state_dict()
1660
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
1661
+
1662
+ del network
1663
+ torch.cuda.empty_cache()
1664
+ gc.collect()
1665
+ return training_log, best_model, last_model
1666
+
1667
+ def _time_str_to_seconds(self, time_str):
1668
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
1669
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
1670
+
1671
+ def _update_best_network(self, network, val_score, epoch):
1672
+ topk = max(1, self.topk)
1673
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
1674
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
1675
+ if val_score in [x[0] for x in self.best_stage]:
1676
+ return True
1677
+ return False
1678
+
1679
+ def is_early_stopping(self, epoch):
1680
+ if self.best_stage[0][1] is None:
1681
+ return False
1682
+ if not self.early_stopping:
1683
+ return False
1684
+ return epoch - self.best_stage[0][1] >= self.early_stopping
1685
+
1686
+ def _at_time_limit(self, start_training_time):
1687
+ return time.time() - start_training_time >= self.training_time
1688
+
1689
+ def _save_state_dict(self, network):
1690
+ if self.topk <= 0:
1691
+ return
1692
+
1693
+ if self.save_best:
1694
+ for r in range(self.topk):
1695
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
1696
+
1697
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
1698
+ if state_dict is None:
1699
+ continue
1700
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
1701
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
1702
+ if self.save_last:
1703
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
1704
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
1705
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
1706
+
1707
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate):
1708
+ network.train()
1709
+ total_loss = 0
1710
+ total_loss_dict = {}
1711
+ for batch_idx, batch in enumerate(train_loader):
1712
+ optimizer.zero_grad()
1713
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
1714
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate)
1715
+
1716
+ for k, v in loss_dict.items():
1717
+ t = total_loss_dict.get(k, 0)
1718
+ total_loss_dict[k] = t + v
1719
+ self.grad_scaler.scale(loss).backward()
1720
+ self.grad_scaler.unscale_(optimizer)
1721
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
1722
+ # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu...
1723
+ self.grad_scaler.step(optimizer)
1724
+ self.grad_scaler.update()
1725
+ if self.schedule_in_step and scheduler:
1726
+ scheduler.step()
1727
+ if self.use_ema and self.ema_net is not None:
1728
+ self.ema_net.update(network)
1729
+ total_loss += loss
1730
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
1731
+
1732
+ def _eval_epoch(self, network, val_loader, eval_fn, id2label):
1733
+ network.eval()
1734
+ total_score = 0.0
1735
+ total_score_dict = {}
1736
+ object_lists = None # sẽ init sau
1737
+
1738
+ with torch.no_grad():
1739
+ for batch_idx, batch in enumerate(val_loader):
1740
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
1741
+ total_score += score
1742
+
1743
+ for k, v in score_dict.items():
1744
+ t = total_score_dict.get(k, 0)
1745
+ total_score_dict[k] = t + v
1746
+
1747
+ if objects:
1748
+ if object_lists is None:
1749
+ object_lists = [[] for _ in range(len(objects))]
1750
+
1751
+ for i, obj in enumerate(objects):
1752
+ object_lists[i].append(obj.detach())
1753
+
1754
+ if object_lists is not None:
1755
+ object_arrays = [
1756
+ torch.concat(obj_list, dim=0).cpu().numpy()
1757
+ for obj_list in object_lists
1758
+ ]
1759
+ else:
1760
+ object_arrays = []
1761
+
1762
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
1763
+
1764
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate):
1765
+ # Bạn cần override _cal_loss để tính loss
1766
+ input_ids = batch['input_ids'].to(self.device)
1767
+ attention_mask = batch['attention_mask'].to(self.device)
1768
+ gold_spans = batch['gold_spans'].to(self.device)
1769
+ gold_labels = batch['gold_labels'].to(self.device)
1770
+ start_labels = batch['start_labels'].to(self.device)
1771
+ end_labels = batch['end_labels'].to(self.device)
1772
+
1773
+ choice = random.random()
1774
+ if choice < teaching_rate:
1775
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask, gold_spans)
1776
+ else:
1777
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask)
1778
+
1779
+ labels = align(spans, gold_spans, gold_labels)
1780
+
1781
+ loss_dict = loss_fn(
1782
+ start_logits, start_labels,
1783
+ end_logits, end_labels,
1784
+ span_scores, labels
1785
+ )
1786
+ return loss_dict['total'], loss_dict
1787
+
1788
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
1789
+ # Bạn cần override _cal_val_score để tính val score, list bên cạnh là để trả về y hay pred gì đó (nếu cần)
1790
+ input_ids = batch['input_ids'].to(self.device)
1791
+ attention_mask = batch['attention_mask'].to(self.device)
1792
+ gold_entities = batch['gold_entities']
1793
+
1794
+ B, _, _ = input_ids.shape
1795
+
1796
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask)
1797
+ _, labels = extract_spans_and_labels(start_logits, end_logits)
1798
+ spans = select_best_spans(spans, span_scores)
1799
+
1800
+ pred_ids = extract_entities(input_ids.reshape(B, -1), spans, labels, id2label)
1801
+ pred_ids = list_to_tuple(pred_ids)
1802
+
1803
+ gold_ids = list_to_tuple(gold_entities)
1804
+
1805
+ score_dict = eval_fn(pred_ids, gold_ids)
1806
+ return score_dict['f1'], score_dict, []
1807
+
1808
+ # %% [code]
1809
+ class PhoBERTSpanAligner:
1810
+ def __init__(self, tokenizer, max_len):
1811
+ self.tokenizer = tokenizer
1812
+ self.max_len = max_len
1813
+
1814
+ # ===== 1. Extract discontinuous spans =====
1815
+ def extract_spans(self, sample):
1816
+ entity_spans = []
1817
+
1818
+ for event in sample["entities"]:
1819
+ entity_type = event["label"]
1820
+ spans = [tuple(event["offset"])]
1821
+ entity_spans.append({
1822
+ "spans": spans,
1823
+ "label": entity_type
1824
+ })
1825
+
1826
+ return entity_spans
1827
+
1828
+ # ===== 2. Word offsets =====
1829
+ def build_word_offsets(self, text, words):
1830
+ offsets = []
1831
+ pointer = 0
1832
+
1833
+ for word in words:
1834
+ start = text.find(word, pointer)
1835
+ end = start + len(word)
1836
+ offsets.append((start, end))
1837
+ pointer = end
1838
+
1839
+ return offsets
1840
+
1841
+ # ===== 3. Char → word =====
1842
+ def char_span_to_word_span(self, word_offsets, start, end):
1843
+ start_word = None
1844
+ end_word = None
1845
+
1846
+ for i, (w_start, w_end) in enumerate(word_offsets):
1847
+ if w_start <= start < w_end:
1848
+ start_word = i
1849
+ if w_start < end <= w_end:
1850
+ end_word = i
1851
+
1852
+ return start_word, end_word
1853
+
1854
+ # ===== 4. Word → subword =====
1855
+ def word_to_subword_map(self, words):
1856
+ mapping = []
1857
+ subword_index = 1 # <s>
1858
+
1859
+ for word in words:
1860
+ sub_tokens = self.tokenizer.tokenize(word)
1861
+ start = subword_index
1862
+ end = subword_index + len(sub_tokens) - 1
1863
+ mapping.append((start, end))
1864
+ subword_index += len(sub_tokens)
1865
+
1866
+ return mapping
1867
+
1868
+ # ===== 5. Span → subword =====
1869
+ def span_to_subword(self, word_offsets, word_subword_map, spans):
1870
+ sub_spans = []
1871
+
1872
+ for span_start, span_end in spans:
1873
+ w_start, w_end = self.char_span_to_word_span(
1874
+ word_offsets, span_start, span_end
1875
+ )
1876
+ if w_start is None or w_end is None:
1877
+ continue
1878
+
1879
+ sub_start = word_subword_map[w_start][0]
1880
+ sub_end = word_subword_map[w_end][1]
1881
+ sub_spans.append((sub_start, sub_end))
1882
+
1883
+ return sub_spans
1884
+
1885
+ def extract_valid_spans(self, sub_spans):
1886
+ valid_spans = []
1887
+ for s, e in sub_spans:
1888
+ if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
1889
+ continue
1890
+ valid_spans.append((s, e))
1891
+ return valid_spans
1892
+
1893
+ def encode(self, sample):
1894
+ text = sample["text"]
1895
+ entities = self.extract_spans(sample)
1896
+
1897
+ # ===== 1. Word tokenize =====
1898
+ words = word_tokenize(text)
1899
+ sentence = " ".join(words)
1900
+
1901
+ # ===== 2. Mapping =====
1902
+ word_offsets = self.build_word_offsets(text, words)
1903
+ word_subword_map = self.word_to_subword_map(words)
1904
+
1905
+ # ===== 3. Tokenize FULL =====
1906
+ encoding = self.tokenizer(
1907
+ sentence,
1908
+ max_length=self.max_len,
1909
+ truncation=True,
1910
+ padding="max_length",
1911
+ return_tensors="pt"
1912
+ )
1913
+ input_ids = encoding["input_ids"][0]
1914
+ attention_mask = encoding["attention_mask"][0]
1915
+
1916
+ # ===== 5. Convert spans =====
1917
+ entities_gold_spans = []
1918
+
1919
+ for ent in entities:
1920
+ label = ent["label"]
1921
+
1922
+ sub_spans = self.span_to_subword(
1923
+ word_offsets,
1924
+ word_subword_map,
1925
+ ent["spans"]
1926
+ )
1927
+ valid_spans = self.extract_valid_spans(sub_spans)
1928
+ if len(valid_spans) == 0:
1929
+ continue
1930
+ entities_gold_spans.append((tuple(valid_spans), label))
1931
+
1932
+ return {
1933
+ "input_ids": input_ids,
1934
+ "attention_mask": attention_mask,
1935
+ "entities_gold_spans": entities_gold_spans,
1936
+ }
1937
+
1938
+ def generate_spans(attention_mask, max_span_len):
1939
+ seq_len = attention_mask.sum().item() - 2
1940
+ spans = []
1941
+ for i in range(1, seq_len+1):
1942
+ for j in range(i, min(i+max_span_len, seq_len+1)):
1943
+ spans.append((i, j))
1944
+ return spans
1945
+
1946
+ def match_gold_labels(
1947
+ gold_spans, # (N, 2)
1948
+ gold_labels, # (N,)
1949
+ pred_spans, # (M, 2)
1950
+ default_label=-100
1951
+ ):
1952
+ """
1953
+ Return:
1954
+ pred_labels: (M,)
1955
+ """
1956
+
1957
+ pred_labels = torch.full(
1958
+ (pred_spans.size(0),),
1959
+ default_label,
1960
+ dtype=gold_labels.dtype,
1961
+ device=gold_labels.device
1962
+ )
1963
+ if gold_spans.size(0) == 0:
1964
+ return pred_labels
1965
+
1966
+ # (M, N)
1967
+ matched = (pred_spans[:, None, :] == gold_spans[None, :, :]).all(dim=-1)
1968
+ has_match = matched.any(dim=1)
1969
+
1970
+ # lấy index gold đầu tiên match
1971
+ gold_idx = matched.float().argmax(dim=1)
1972
+
1973
+ pred_labels[has_match] = gold_labels[gold_idx[has_match]]
1974
+
1975
+ return pred_labels
1976
+
1977
+ class KLTNDataset(Dataset):
1978
+ def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
1979
+ super().__init__()
1980
+ self.tokenizer = tokenizer
1981
+ self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)
1982
+ self.all_data = all_data
1983
+ self.using_idxes = using_idxes
1984
+ self.label2id = label2id
1985
+ self.max_len = max_len
1986
+ self.max_n_parts = max_n_parts
1987
+
1988
+ def __len__(self):
1989
+ return len(self.using_idxes)
1990
+
1991
+ def __getitem__(self, idx):
1992
+ ridx = self.using_idxes[idx]
1993
+ sample = self.all_data[ridx]
1994
+ result = self.aligner.encode(sample)
1995
+
1996
+ input_ids = result["input_ids"].squeeze(0)
1997
+ attention_mask = result["attention_mask"].squeeze(0)
1998
+ entities_gold_spans = result["entities_gold_spans"]
1999
+
2000
+ all_spans = torch.tensor(generate_spans(attention_mask, 10))
2001
+ gold_spans = torch.tensor([spans[0] for spans, _ in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, 2, dtype=torch.long)
2002
+ gold_labels = torch.tensor([self.label2id[label] for _, label in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, dtype=torch.long)
2003
+ all_labels = match_gold_labels(
2004
+ gold_spans, # (N, 2)
2005
+ gold_labels, # (N,)
2006
+ all_spans, # (M, 2)
2007
+ default_label=0
2008
+ )
2009
+
2010
+ # Get label
2011
+ gold_entities = []
2012
+ start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
2013
+ end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
2014
+ for spans, label in entities_gold_spans:
2015
+ s, e = spans[0]
2016
+
2017
+ start_labels[s] = self.label2id[f'{label}']
2018
+ end_labels[e] = self.label2id[f'{label}']
2019
+
2020
+ gold_entities.append((tuple(input_ids[s:e+1].tolist()), label))
2021
+
2022
+ input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
2023
+ attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)
2024
+
2025
+ n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
2026
+ input_ids = input_ids[:n_valid_parts]
2027
+ attention_mask = attention_mask[:n_valid_parts]
2028
+ start_labels = start_labels[:n_valid_parts*self.max_len]
2029
+ end_labels = end_labels[:n_valid_parts*self.max_len]
2030
+
2031
+ return {
2032
+ "input_ids": input_ids,
2033
+ "attention_mask": attention_mask,
2034
+
2035
+ "all_spans": all_spans,
2036
+ "all_labels": all_labels,
2037
+
2038
+ "gold_spans": gold_spans,
2039
+ "gold_labels": gold_labels,
2040
+
2041
+ "start_labels": start_labels,
2042
+ "end_labels": end_labels,
2043
+
2044
+ "gold_entities": gold_entities,
2045
+ }
2046
+
2047
+ def _pad_batch(tensor_list, pad_value=0):
2048
+ """
2049
+ tensor_list: list of tensors
2050
+ mỗi tensor shape: (Nk, n_parts_i, max_len_i)
2051
+
2052
+ return:
2053
+ padded tensor shape: (B, max_Nk, max_n_parts, max_len)
2054
+ """
2055
+
2056
+ # lấy max toàn batch
2057
+ max_Nk = max(t.size(0) for t in tensor_list)
2058
+ max_n_parts = max(t.size(1) for t in tensor_list)
2059
+ max_len = max(t.size(2) for t in tensor_list)
2060
+
2061
+ padded = []
2062
+
2063
+ for t in tensor_list:
2064
+ Nk, n_parts_i, max_len_i = t.shape
2065
+
2066
+ # pad chiều n_parts và max_len trước
2067
+ if n_parts_i < max_n_parts or max_len_i < max_len:
2068
+ new_t = t.new_full(
2069
+ (Nk, max_n_parts, max_len),
2070
+ pad_value
2071
+ )
2072
+ new_t[:, :n_parts_i, :max_len_i] = t
2073
+ t = new_t
2074
+
2075
+ # pad chiều Nk
2076
+ if Nk < max_Nk:
2077
+ pad_tensor = t.new_full(
2078
+ (max_Nk - Nk, max_n_parts, max_len),
2079
+ pad_value
2080
+ )
2081
+ t = torch.cat([t, pad_tensor], dim=0)
2082
+
2083
+ padded.append(t)
2084
+
2085
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
2086
+
2087
+ def collate_fn(batch):
2088
+ gold_entities = []
2089
+ for bidx, b in enumerate(batch):
2090
+ for entity in b['gold_entities']:
2091
+ gold_entities.append([bidx, entity])
2092
+
2093
+ input_ids = [b["input_ids"].unsqueeze(-1) for b in batch]
2094
+ attention_mask = [b["attention_mask"].unsqueeze(-1) for b in batch]
2095
+
2096
+ all_spans = [b["all_spans"].unsqueeze(-1) for b in batch]
2097
+ all_labels = [b["all_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2098
+
2099
+ gold_spans = [b["gold_spans"].unsqueeze(-1) for b in batch]
2100
+ gold_labels = [b["gold_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2101
+
2102
+ start_labels = [b["start_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2103
+ end_labels = [b["end_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2104
+
2105
+ # pad theo Nk
2106
+ input_ids = _pad_batch(input_ids, pad_value=0).squeeze(-1)
2107
+ attention_mask = _pad_batch(attention_mask, pad_value=0).squeeze(-1)
2108
+
2109
+ all_spans = _pad_batch(all_spans, pad_value=0).squeeze(-1)
2110
+ all_labels = _pad_batch(all_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2111
+
2112
+ gold_spans = _pad_batch(gold_spans, pad_value=0).squeeze(-1)
2113
+ gold_labels = _pad_batch(gold_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2114
+
2115
+ start_labels = _pad_batch(start_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2116
+ end_labels = _pad_batch(end_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2117
+
2118
+ return {
2119
+ "input_ids": input_ids,
2120
+ "attention_mask": attention_mask,
2121
+
2122
+ "all_spans": all_spans,
2123
+ "all_labels": all_labels,
2124
+
2125
+ "gold_spans": gold_spans,
2126
+ "gold_labels": gold_labels,
2127
+
2128
+ "start_labels": start_labels,
2129
+ "end_labels": end_labels,
2130
+
2131
+ "gold_entities": gold_entities,
2132
+ }
2133
+
2134
+ # %% [code]
2135
+ def shift_bidx(spans, batch_idx):
2136
+ shifted = []
2137
+ for bidx, ent in spans:
2138
+ new_bidx = bidx + batch_idx * batch_size
2139
+ shifted.append((new_bidx, ent))
2140
+ return shifted
2141
+
2142
+ def refactor_entities(entities, save_dict):
2143
+ i, c = [], []
2144
+ for bidx, (ids, lb) in entities:
2145
+ if (bidx, ids) not in i:
2146
+ i.append((bidx, ids))
2147
+
2148
+ if (bidx, (ids, lb)) not in c:
2149
+ c.append((bidx, (ids, lb)))
2150
+
2151
+ save_dict['Ent-I'].extend(i)
2152
+ save_dict['Ent-C'].extend(c)
2153
+
2154
+ def check_spans(
2155
+ all_spans, # (N, 2)
2156
+ all_labels, # (N,)
2157
+ ensemble_start_logits, # (L, C)
2158
+ ensemble_end_logits, # (L, C)
2159
+ expand_dict,
2160
+ max_expand=10
2161
+ ):
2162
+ """
2163
+ Check minimum expansion radius required to achieve recall = 1.
2164
+
2165
+ Args:
2166
+ all_spans: gold spans
2167
+ all_labels: gold labels
2168
+ ensemble_start_logits: (L, C)
2169
+ ensemble_end_logits: (L, C)
2170
+ expand_dict:
2171
+ dict like:
2172
+ {
2173
+ 0: count,
2174
+ 1: count,
2175
+ ...
2176
+ }
2177
+
2178
+ Return:
2179
+ matched_expand_k
2180
+ """
2181
+
2182
+ # =========================================================
2183
+ # Gold spans
2184
+ # =========================================================
2185
+
2186
+ gold_set = set()
2187
+
2188
+ valid = all_labels > 0
2189
+
2190
+ valid_idxes = valid.nonzero(as_tuple=False).squeeze(-1)
2191
+
2192
+ for idx in valid_idxes:
2193
+
2194
+ s = int(all_spans[idx, 0])
2195
+ e = int(all_spans[idx, 1])
2196
+
2197
+ gold_set.add((s, e))
2198
+
2199
+ # no gold
2200
+ if len(gold_set) == 0:
2201
+ expand_dict[0] += 1
2202
+ return 0
2203
+
2204
+ # =========================================================
2205
+ # Decode base spans
2206
+ # =========================================================
2207
+
2208
+ start_labels = ensemble_start_logits.argmax(dim=-1) # (L,)
2209
+ end_labels = ensemble_end_logits.argmax(dim=-1) # (L,)
2210
+
2211
+ L = start_labels.shape[0]
2212
+
2213
+ base_pred = []
2214
+
2215
+ used_start = set()
2216
+ used_end = set()
2217
+
2218
+ for s in range(L):
2219
+
2220
+ s_label = start_labels[s].item()
2221
+
2222
+ if s_label == 0:
2223
+ continue
2224
+
2225
+ if s in used_start:
2226
+ continue
2227
+
2228
+ nearest_e = None
2229
+
2230
+ for e in range(s, L):
2231
+
2232
+ if e in used_end:
2233
+ continue
2234
+
2235
+ e_label = end_labels[e].item()
2236
+
2237
+ if e_label == s_label:
2238
+ nearest_e = e
2239
+ break
2240
+
2241
+ if nearest_e is None:
2242
+ continue
2243
+
2244
+ used_start.add(s)
2245
+ used_end.add(nearest_e)
2246
+
2247
+ base_pred.append((s, nearest_e))
2248
+
2249
+ # =========================================================
2250
+ # Try expansion radius
2251
+ # =========================================================
2252
+
2253
+ for k in range(max_expand + 1):
2254
+
2255
+ pred_set = set()
2256
+
2257
+ for s, e in base_pred:
2258
+
2259
+ for ds in range(-k, k + 1):
2260
+ for de in range(-k, k + 1):
2261
+
2262
+ ns = s + ds
2263
+ ne = e + de
2264
+
2265
+ if ns > 0 and ne > 0 and ns <= ne:
2266
+ pred_set.add((ns, ne))
2267
+
2268
+ # recall = 1
2269
+ if gold_set.issubset(pred_set):
2270
+
2271
+ expand_dict[k] += 1
2272
+ return k
2273
+
2274
+ # =========================================================
2275
+ # cannot recover
2276
+ # =========================================================
2277
+
2278
+ expand_dict[-1] = expand_dict.get(-1, 0) + 1
2279
+
2280
+ return -1
2281
+
2282
+ def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
2283
+ if torch.cuda.device_count() > 1:
2284
+ network = DataParallelProxy(network)
2285
+ network = network.to(device)
2286
+ network.eval()
2287
+
2288
+ eval_types = ['Ent-I', 'Ent-C']
2289
+
2290
+ all_pred = {eval_type: [] for eval_type in eval_types}
2291
+ all_gold = {eval_type: [] for eval_type in eval_types}
2292
+
2293
+ list_input_ids = []
2294
+ expand_dict = {i: 0 for i in range(13)}
2295
+
2296
+ with torch.no_grad():
2297
+ for batch_idx, batch in enumerate(test_loader):
2298
+ input_ids = batch['input_ids'].to(device)
2299
+ attention_mask = batch['attention_mask'].to(device)
2300
+ all_spans = batch['all_spans'].to(device)
2301
+ all_labels = batch['all_labels'].to(device)
2302
+ gold_entities = batch['gold_entities']
2303
+
2304
+ B, _, _ = input_ids.shape
2305
+ list_input_ids.extend(input_ids.reshape(B, -1).tolist())
2306
+
2307
+ list_hidden_states = []
2308
+ list_start_logits = []
2309
+ list_end_logits = []
2310
+ list_scores = []
2311
+ for sd in state_dicts:
2312
+ if torch.cuda.device_count() > 1:
2313
+ network.module.load_state_dict(sd)
2314
+ else:
2315
+ network.load_state_dict(sd)
2316
+
2317
+ hidden_states, attention_mask = network.encode(input_ids, attention_mask)
2318
+ start_logits, end_logits = network.get_logits(hidden_states)
2319
+ list_hidden_states.append(hidden_states)
2320
+ list_start_logits.append(start_logits)
2321
+ list_end_logits.append(end_logits)
2322
+
2323
+ ensemble_start_logits = torch.stack(list_start_logits, dim=0).mean(dim=0)
2324
+ ensemble_end_logits = torch.stack(list_end_logits, dim=0).mean(dim=0)
2325
+ spans, pred_labels = extract_spans_and_labels(ensemble_start_logits, ensemble_end_logits)
2326
+ B, N, _ = spans.shape
2327
+ r = torch.ones((B, N), device=spans.device, dtype=torch.long) * network.max_r
2328
+ expanded_spans = expand_spans(spans, r, attention_mask)
2329
+
2330
+ for sd, hidden_states in zip(state_dicts, list_hidden_states):
2331
+ if torch.cuda.device_count() > 1:
2332
+ network.module.load_state_dict(sd)
2333
+ else:
2334
+ network.load_state_dict(sd)
2335
+ expanded_span_reprs = get_span_reprs(hidden_states, expanded_spans)
2336
+ expanded_span_scores = network.get_scores(expanded_span_reprs)
2337
+ list_scores.append(expanded_span_scores)
2338
+
2339
+ ensemble_scores = torch.stack(list_scores, dim=0).mean(dim=0)
2340
+ pred_spans = select_best_spans(expanded_spans, ensemble_scores)
2341
+
2342
+ pred_entities = extract_entities(input_ids.reshape(B, -1), pred_spans, pred_labels, id2label)
2343
+ pred_entities = shift_bidx(pred_entities, batch_idx)
2344
+ refactor_entities(pred_entities, all_pred)
2345
+
2346
+ gold_entities = shift_bidx(gold_entities, batch_idx)
2347
+ refactor_entities(gold_entities, all_gold)
2348
+
2349
+ for b in range(B):
2350
+ check_spans(
2351
+ all_spans[b], # (N, 2)
2352
+ all_labels[b], # (N,)
2353
+ ensemble_start_logits[b], # (L, C)
2354
+ ensemble_end_logits[b], # (L, C)
2355
+ expand_dict,
2356
+ max_expand=10
2357
+ )
2358
+
2359
+ # ===== GLOBAL EVAL =====
2360
+ final_score = {}
2361
+ for eval_type in eval_types:
2362
+ score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))
2363
+ final_score[eval_type] = score
2364
+
2365
+ analyze_result = analyzer.analyze(list_to_tuple(all_pred['Ent-I']), list_to_tuple(all_gold['Ent-I']))
2366
+
2367
+ # ===== PREDICT =====
2368
+ predictions = []
2369
+ for input_ids in list_input_ids:
2370
+ predictions.append([tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
2371
+ for bidx, (ids, lb) in all_pred['Ent-C']:
2372
+ predictions[bidx].append((tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True), lb))
2373
+
2374
+ return final_score, analyze_result, predictions, expand_dict
2375
+
2376
+ # %% [code]
2377
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
2378
+ data_train = json.load(f)
2379
+
2380
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
2381
+ data_test = json.load(f)
2382
+
2383
+ print('Train:', len(data_train))
2384
+ print('Test:', len(data_test))
2385
+
2386
+ # %% [code]
2387
+ entity_types = ['O'] + sorted(list(set([e['label'] for d in data_train + data_test for e in d['entities']])))
2388
+ # bio_entity_type = ['O'] + [f'{prefix}-{ent}' for ent in entity_types for prefix in ['B', 'I']]
2389
+ label2id = {l: i for i, l in enumerate(entity_types)}
2390
+ id2label = {i: l for l, i in label2id.items()}
2391
+
2392
+ # %% [code]
2393
+ zero_entities_idxes = []
2394
+ for idx, d in enumerate(data_train):
2395
+ if len(d['entities']) == 0:
2396
+ zero_entities_idxes.append(idx)
2397
+
2398
+ n_zero_entities_samples = len(zero_entities_idxes)
2399
+ n_has_entities_samples = len(data_train) - n_zero_entities_samples
2400
+
2401
+ random.seed(42)
2402
+ k = min(int(n_has_entities_samples * zero_entities_rate), len(zero_entities_idxes))
2403
+ sampled_zero_entities_idxes = random.sample(zero_entities_idxes, k)
2404
+
2405
+ new_data_train = []
2406
+ for idx, d in enumerate(data_train):
2407
+ if len(d['entities']) == 0:
2408
+ if idx in sampled_zero_entities_idxes:
2409
+ new_data_train.append(d)
2410
+ else:
2411
+ new_data_train.append(d)
2412
+ data_train = new_data_train
2413
+
2414
+ print('Train:', len(data_train))
2415
+
2416
+ # %% [code]
2417
+ if debug_only:
2418
+ data_train = data_train[:10]
2419
+ data_test = data_test[:10]
2420
+
2421
+ print('Train:', len(data_train))
2422
+ print('Test:', len(data_test))
2423
+
2424
+ # %% [code]
2425
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
2426
+
2427
+ # %% [code]
2428
+ print('Experiment name:', state_dict_save_name)
2429
+
2430
+ # %% [code]
2431
+ if not test_only:
2432
+ full_idxes = np.array(range(len(data_train)))
2433
+ training_logs, best_models, last_models = [], [], []
2434
+ start_training_time = time.time()
2435
+ for seed in SEEDS:
2436
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
2437
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
2438
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
2439
+ continue
2440
+ set_seed(seed)
2441
+
2442
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
2443
+
2444
+ trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
2445
+ valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)
2446
+
2447
+ generator = torch.Generator()
2448
+ generator.manual_seed(seed)
2449
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
2450
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2451
+
2452
+ my_model = IEModel(
2453
+ num_labels=len(label2id),
2454
+ **model_params
2455
+ )
2456
+ total_params = sum(p.numel() for p in my_model.parameters())
2457
+ print(f"Total params: {total_params:,}")
2458
+
2459
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
2460
+ encoder_params = set(map(id, my_model.encoder.parameters()))
2461
+ other_params = [
2462
+ p for p in my_model.parameters()
2463
+ if id(p) not in encoder_params
2464
+ ]
2465
+ optimizer = optim.AdamW([
2466
+ {"params": my_model.encoder.parameters(), "lr": 2e-5},
2467
+ {"params": other_params}
2468
+ ], lr=5e-4)
2469
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
2470
+
2471
+ loss_fn = CustomLoss(
2472
+ **loss_func_params
2473
+ )
2474
+ eval_fn = CustomEvalFn(**eval_func_params)
2475
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
2476
+ trainer = Trainer(**trainer_params)
2477
+
2478
+ print(f'Start Training Fold {fold_idx}...')
2479
+ training_log, best_model, last_model = trainer.fit(
2480
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
2481
+ start_epoch=1, start_training_time=start_training_time, id2label=id2label
2482
+ )
2483
+
2484
+ training_logs.append(training_log)
2485
+ best_models.append(best_model)
2486
+ last_models.append(last_model)
2487
+
2488
+ # %% [code]
2489
+ def load_all_state_dicts(folder):
2490
+ files = []
2491
+
2492
+ for file in os.listdir(folder):
2493
+ if file.endswith(".pt") or file.endswith(".pth"):
2494
+ m = re.search(r"f(\d+)", file) # tìm f<số>
2495
+ if m:
2496
+ fold = int(m.group(1))
2497
+ files.append((fold, file))
2498
+
2499
+ # sort theo fold
2500
+ files.sort(key=lambda x: x[0])
2501
+
2502
+ state_dicts = []
2503
+ for fold, file in files:
2504
+ path = os.path.join(folder, file)
2505
+ print(f"Loading fold {fold}: {file}")
2506
+ state_dict = torch.load(path, map_location="cpu")
2507
+ state_dicts.append(state_dict)
2508
+
2509
+ return state_dicts
2510
+
2511
+ if test_only:
2512
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
2513
+ get_ipython().system('rm -rf .cache .gitattributes')
2514
+
2515
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
2516
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
2517
+
2518
+ # %% [code]
2519
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
2520
+ testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
2521
+ generator = torch.Generator()
2522
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2523
+ eval_fn = CustomEvalFn(**eval_func_params)
2524
+ analyzer = SpanErrorAnalyzer()
2525
+ my_model = IEModel(
2526
+ num_labels=len(label2id),
2527
+ **model_params
2528
+ )
2529
+ total_params = sum(p.numel() for p in my_model.parameters())
2530
+ print(f"Total params: {total_params:,}")
2531
+
2532
+ # %% [code]
2533
+ start_time = time.time()
2534
+ result_test = None
2535
+ analyze_result = None
2536
+
2537
+ best_score, best_analyze_result, best_pred_test, expand_dict = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2538
+ last_score, last_analyze_result, last_pred_test, _ = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2539
+
2540
+ result_test = {"Best model": best_score, "Last model": last_score}
2541
+ analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
2542
+ analyze_result_sumary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']}
2543
+ pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}
2544
+
2545
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
2546
+ json.dump(result_test, f, ensure_ascii=False, indent=2)
2547
+
2548
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", "w", encoding="utf-8") as f:
2549
+ json.dump(analyze_result, f, ensure_ascii=False, indent=2)
2550
+
2551
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", "w", encoding="utf-8") as f:
2552
+ json.dump(pred_test, f, ensure_ascii=False, indent=2)
2553
+
2554
+ print('Test:', time.time() - start_time, 's --> Done!')
2555
+ print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))
2556
+
2557
+ # %% [code]
2558
+ expand_dict
2559
+
2560
+ # %% [code]
2561
+ expand_dict_sum = sum(list(expand_dict.values()))
2562
+ {key: value / expand_dict_sum for key, value in expand_dict.items()}
2563
+
2564
+ # %% [code]
2565
+ best_pred_test[:10]
2566
+
2567
+ # %% [code]
2568
+ last_pred_test[:10]
2569
+
2570
+ # %% [code]
2571
+ def dict_to_df(data):
2572
+ row_tuples = []
2573
+ row_values = []
2574
+
2575
+ metrics = ["precision", "recall", "f1"]
2576
+
2577
+ # Lấy model đầu tiên
2578
+ first_model = next(iter(data.values()))
2579
+
2580
+ # eval_keys
2581
+ eval_keys = list(first_model.keys())
2582
+
2583
+ for eval_key in eval_keys:
2584
+ row_tuples.append(eval_key)
2585
+ row = {}
2586
+
2587
+ for model_name, model_data in data.items():
2588
+ for metric in metrics:
2589
+ row[(model_name, metric)] = model_data[eval_key][metric]
2590
+
2591
+ row_values.append(row)
2592
+
2593
+ # ===== DataFrame =====
2594
+ df = pd.DataFrame(row_values)
2595
+
2596
+ # MultiIndex columns
2597
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
2598
+
2599
+ # Index
2600
+ df.index = pd.Index(row_tuples, name="evaluation")
2601
+
2602
+ # ===== Sort =====
2603
+ sort_keys = []
2604
+ if ("Best model", "f1") in df.columns:
2605
+ sort_keys.append(("Best model", "f1"))
2606
+ if ("Last model", "f1") in df.columns:
2607
+ sort_keys.append(("Last model", "f1"))
2608
+
2609
+ if sort_keys:
2610
+ df = df.sort_values(by=sort_keys, ascending=False)
2611
+
2612
+ return df
2613
+
2614
+ result_test_df = dict_to_df(result_test)
2615
+ result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
2616
+ result_test_df
2617
+
2618
+ # %% [code]
2619
+ key = ("Best model", "f1")
2620
+ result_test_df_best = result_test_df.sort_values(by=key, ascending=False).groupby(level="evaluation").head(1)
2621
+ result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
2622
+ result_test_df_best
2623
+
2624
+ # %% [code]
2625
+ def get_avg_best_score(logs):
2626
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
2627
+
2628
+ def get_avg_log(logs, epochs):
2629
+ avg_log = {}
2630
+
2631
+ for epoch in range(1, epochs + 1):
2632
+ val_score = 0.0
2633
+ train_loss = 0.0
2634
+ n_eval = 0
2635
+
2636
+ for idx in range(len(logs)):
2637
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
2638
+ if log is None:
2639
+ continue
2640
+
2641
+ val_score += log.get('val_score', 0.0)
2642
+ train_loss += log.get('train_loss', 0.0)
2643
+ n_eval += 1
2644
+
2645
+ if n_eval == 0:
2646
+ continue
2647
+
2648
+ avg_log[epoch] = {
2649
+ 'train_loss': train_loss / n_eval,
2650
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
2651
+ }
2652
+
2653
+ return avg_log
2654
+
2655
+ def parse_label_key(label: str):
2656
+ try:
2657
+ first = float(label.split('_', 1)[0]) # số đầu: trước dấu _
2658
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
2659
+ return first, last
2660
+ except:
2661
+ return (0, 0)
2662
+
2663
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
2664
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
2665
+
2666
+ # ===== Plot Train Loss =====
2667
+ for name, log in logs_dict.items():
2668
+ epochs = sorted(log.keys())
2669
+ train_loss = [log[e]['train_loss'] for e in epochs]
2670
+ axes[0].plot(epochs, train_loss, label=name)
2671
+
2672
+ axes[0].set_xlabel('Epoch')
2673
+ axes[0].set_ylabel('Train Loss')
2674
+ axes[0].set_title('Training Loss')
2675
+ axes[0].grid(True)
2676
+
2677
+ # ===== Plot Validation Score =====
2678
+ for name, log in logs_dict.items():
2679
+ epochs = sorted(log.keys())
2680
+ val_score = [log[e]['val_score'] for e in epochs]
2681
+ axes[1].plot(epochs, val_score, label=name)
2682
+
2683
+ axes[1].set_xlabel('Epoch')
2684
+ axes[1].set_ylabel('Validation Score')
2685
+ axes[1].set_title('Validation Score')
2686
+ axes[1].grid(True)
2687
+
2688
+ # ===== Shared Legend =====
2689
+ handles, labels = axes[0].get_legend_handles_labels()
2690
+ pairs = list(zip(handles, labels))
2691
+ pairs_sorted = sorted(
2692
+ pairs,
2693
+ key=lambda x: parse_label_key(x[1])
2694
+ )
2695
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
2696
+
2697
+ axes[0].legend(
2698
+ handles_sorted,
2699
+ labels_sorted,
2700
+ loc='center left',
2701
+ bbox_to_anchor=(1.01, 0.5),
2702
+ borderaxespad=0.
2703
+ )
2704
+
2705
+ plt.tight_layout(rect=[0, 0, 1, 1])
2706
+
2707
+ if save_path is not None:
2708
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
2709
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
2710
+
2711
+ plt.show()
2712
+
2713
+ # %% [code]
2714
+ if not test_only:
2715
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*entities*.json"])
2716
+ get_ipython().system('rm -rf .cache .gitattributes')
2717
+
2718
+ # %% [code]
2719
+ if not test_only:
2720
+ experiments = {}
2721
+ for experiment in os.listdir(pretrained_dir):
2722
+ if '.virtual_documents' in experiment:
2723
+ continue
2724
+ experiment_logs = []
2725
+ try:
2726
+ for seed in SEEDS:
2727
+ for fold_idx in range(nfolds):
2728
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
2729
+ experiment_log = json.load(f)
2730
+ experiment_logs.append(experiment_log)
2731
+ except:
2732
+ pass
2733
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
2734
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
2735
+
2736
+ # %% [code]
2737
+ if not test_only:
2738
+ score = get_avg_best_score(training_logs)
2739
+ state_dict_save_name, score
2740
+
2741
+ # %% [code]
2742
+ if not test_only:
2743
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
2744
+
4.2_add_span_rerank_branch_phoner_contrastive_4.5/lasts/4.2_add_span_rerank_branch_phoner_contrastive_4.5_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47bb0327640eecaa57dc994b25d79a145c09a9d3fc443407495df769c78ff58d
3
+ size 554324761
4.2_add_span_rerank_branch_phoner_contrastive_4.5/logs/4.2_add_span_rerank_branch_phoner_contrastive_4.5_log_plot.jpg ADDED

Git LFS Details

  • SHA256: c251abd26b60c87486698d3185c3ae8749794184d76fa7068eb4c94aba9171e9
  • Pointer size: 131 Bytes
  • Size of remote file: 581 kB
4.2_add_span_rerank_branch_phoner_contrastive_4.5/logs/4.2_add_span_rerank_branch_phoner_contrastive_4.5_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 1.0824006795883179, "total": 1.0824006687511096, "token_margin_loss": 0.08402876420454546, "span_margin_loss": 0.208984375, "contrastive_loss": 0.7898625460537997, "start_margin": 0.040327592329545456, "end_margin": 0.043523615056818184}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 0.704765260219574, "total": 0.7047652331265536, "token_margin_loss": 0.032049005681818184, "span_margin_loss": 0.11479048295454546, "contrastive_loss": 0.5580862652171742, "start_margin": 0.015181107954545454, "end_margin": 0.016823508522727272}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 0.6136787533760071, "total": 0.6136787154457786, "token_margin_loss": 0.025656960227272728, "span_margin_loss": 0.08442826704545454, "contrastive_loss": 0.5035970427773215, "start_margin": 0.012140447443181818, "end_margin": 0.013538707386363636}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 0.5561001896858215, "total": 0.556100151755593, "token_margin_loss": 0.02130681818181818, "span_margin_loss": 0.064453125, "contrastive_loss": 0.470334529876709, "start_margin": 0.010181773792613636, "end_margin": 0.011086203835227272}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 0.5115218162536621, "total": 0.5115217729048296, "token_margin_loss": 0.018022017045454544, "span_margin_loss": 0.05055930397727273, "contrastive_loss": 0.4429710128090598, "start_margin": 0.008839000355113636, "end_margin": 0.00921630859375}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 0.46905893087387085, "total": 0.4690589037808505, "token_margin_loss": 0.015647194602272728, "span_margin_loss": 0.03670987215909091, "contrastive_loss": 0.41658986698497424, "start_margin": 0.007590553977272727, "end_margin": 0.00814541903409091}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 0.459891676902771, "total": 0.45989166606556287, "token_margin_loss": 0.01394930752840909, "span_margin_loss": 0.03034002130681818, "contrastive_loss": 0.4155573411421342, "start_margin": 0.006808194247159091, "end_margin": 0.007113370028409091}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 0.44156908988952637, "total": 0.4415690682151101, "token_margin_loss": 0.01192959872159091, "span_margin_loss": 0.02483575994318182, "contrastive_loss": 0.4047806913202459, "start_margin": 0.005837180397727273, "end_margin": 0.006114612926136364, "val_score": 0.7549365975402133, "best_score": 0.7549365975402133, "new_best_model": true, "precision": 0.8117284476028694, "recall": 0.7067157154143974, "f1": 0.7549365975402133}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 0.4241897165775299, "total": 0.42418969761241565, "token_margin_loss": 0.010265003551136364, "span_margin_loss": 0.019331498579545456, "contrastive_loss": 0.3945862596685236, "start_margin": 0.005021528764204545, "end_margin": 0.005235151811079545, "val_score": 0.7608342080984483, "best_score": 0.7608342080984483, "new_best_model": true, "precision": 0.8143406276979905, "recall": 0.7151952475780802, "f1": 0.7608342080984483}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 0.4148719310760498, "total": 0.41487190940163354, "token_margin_loss": 0.009160822088068182, "span_margin_loss": 0.015957919034090908, "contrastive_loss": 0.38974220102483575, "start_margin": 0.004444469105113636, "end_margin": 0.004710804332386364, "val_score": 0.7579002431341856, "best_score": 0.7608342080984483, "new_best_model": false, "precision": 0.8128419530206556, "recall": 0.711186824765584, "f1": 0.7579002431341856}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 0.41513848304748535, "total": 0.4151384613730691, "token_margin_loss": 0.008250843394886364, "span_margin_loss": 0.0145263671875, "contrastive_loss": 0.39239090139215643, "start_margin": 0.004036643288352273, "end_margin": 0.004208651455965909, "val_score": 0.757576704547213, "best_score": 0.7608342080984483, "new_best_model": false, "precision": 0.8096272402445337, "recall": 0.7129860991220954, "f1": 0.757576704547213}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 0.4068813920021057, "total": 0.4068813757462935, "token_margin_loss": 0.007379705255681818, "span_margin_loss": 0.011951793323863636, "contrastive_loss": 0.38753557205200195, "start_margin": 0.0036426890980113635, "end_margin": 0.003750887784090909, "val_score": 0.7622329143046932, "best_score": 0.7622329143046932, "new_best_model": true, "precision": 0.8139959201942257, "recall": 0.7178806139498243, "f1": 0.7622329143046932}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 0.39867931604385376, "total": 0.39867929978804156, "token_margin_loss": 0.006458629261363636, "span_margin_loss": 0.010747736150568182, "contrastive_loss": 0.38149538907137787, "start_margin": 0.003107244318181818, "end_margin": 0.003354159268465909, "val_score": 0.7580825878071814, "best_score": 0.7622329143046932, "new_best_model": false, "precision": 0.8094901261480825, "recall": 0.7143271444198388, "f1": 0.7580825878071814}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 0.3968024253845215, "total": 0.3968024253845215, "token_margin_loss": 0.006025834517045455, "span_margin_loss": 0.009160822088068182, "contrastive_loss": 0.3816223578019576, "start_margin": 0.002960205078125, "end_margin": 0.0030628551136363635, "val_score": 0.7620497971505112, "best_score": 0.7622329143046932, "new_best_model": false, "precision": 0.8096285819146598, "recall": 0.7210915525418408, "f1": 0.7620497971505112}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 0.3997964859008789, "total": 0.3997964859008789, "token_margin_loss": 0.005426580255681818, "span_margin_loss": 0.008195356889204546, "contrastive_loss": 0.3861376155506481, "start_margin": 0.002634221857244318, "end_margin": 0.0027992942116477275, "val_score": 0.7616070883261663, "best_score": 0.7622329143046932, "new_best_model": false, "precision": 0.8092171944653898, "recall": 0.7205119507206753, "f1": 0.7616070883261663}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 0.3915356695652008, "total": 0.39153567227450287, "token_margin_loss": 0.005154696377840909, "span_margin_loss": 0.007451837713068182, "contrastive_loss": 0.3789240230213512, "start_margin": 0.002523248845880682, "end_margin": 0.002645319158380682, "val_score": 0.7640879116560432, "best_score": 0.7640879116560432, "new_best_model": true, "precision": 0.813630049681079, "recall": 0.7213080412102193, "f1": 0.7640879116560432}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 0.3870449960231781, "total": 0.38704499331387604, "token_margin_loss": 0.004741321910511364, "span_margin_loss": 0.006675026633522727, "contrastive_loss": 0.3756334564902566, "start_margin": 0.002306851473721591, "end_margin": 0.0024261474609375, "val_score": 0.7619674593230111, "best_score": 0.7640879116560432, "new_best_model": false, "precision": 0.810926507997774, "recall": 0.719896708677494, "f1": 0.7619674593230111}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 0.3896758556365967, "total": 0.3896758339621804, "token_margin_loss": 0.004494406960227273, "span_margin_loss": 0.005831631747159091, "contrastive_loss": 0.37934051860462537, "start_margin": 0.0022361061789772725, "end_margin": 0.0022569136186079545, "val_score": 0.7670054663430533, "best_score": 0.7670054663430533, "new_best_model": true, "precision": 0.8138082904426163, "recall": 0.7265197860397485, "f1": 0.7670054663430533}}
4.2_add_span_rerank_branch_phoner_contrastive_4.5/r1s/4.2_add_span_rerank_branch_phoner_contrastive_4.5_s26092004_f0_r1_vs0.76701_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7836a38e86c7b664097782787682c04d860a6a0dcac8dd0aaeff1ce45ec01a1
3
+ size 554327073
4.2_add_span_rerank_branch_phoner_contrastive_4.5/results/4.2_add_span_rerank_branch_phoner_contrastive_4.5_error_analyze_result.json ADDED
The diff for this file is too large to render. See raw diff
 
4.2_add_span_rerank_branch_phoner_contrastive_4.5/results/4.2_add_span_rerank_branch_phoner_contrastive_4.5_pred_test.json ADDED
The diff for this file is too large to render. See raw diff
 
4.2_add_span_rerank_branch_phoner_contrastive_4.5/results/4.2_add_span_rerank_branch_phoner_contrastive_4.5_test.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Best model": {
3
+ "Ent-I": {
4
+ "precision": 0.946609795556778,
5
+ "recall": 0.5786208756625203,
6
+ "f1": 0.7182234078558108
7
+ },
8
+ "Ent-C": {
9
+ "precision": 0.6868504292807424,
10
+ "recall": 0.545585068197644,
11
+ "f1": 0.6081216193899489
12
+ }
13
+ },
14
+ "Last model": {
15
+ "Ent-I": {
16
+ "precision": 0.946609795556778,
17
+ "recall": 0.5786208756625203,
18
+ "f1": 0.7182234078558108
19
+ },
20
+ "Ent-C": {
21
+ "precision": 0.6868504292807424,
22
+ "recall": 0.545585068197644,
23
+ "f1": 0.6081216193899489
24
+ }
25
+ }
26
+ }
4.2_add_span_rerank_branch_phoner_contrastive_4.5/results/4.2_add_span_rerank_branch_phoner_contrastive_4.5_test_df.xlsx ADDED
Binary file (5.23 kB). View file
 
4.2_add_span_rerank_branch_phoner_contrastive_4.5/results/4.2_add_span_rerank_branch_phoner_contrastive_4.5_test_df_best.xlsx ADDED
Binary file (5.23 kB). View file