SS3M commited on
Commit
faabb4b
·
verified ·
1 Parent(s): fe66fe5

Upload 2_lr_new_structure_3's state dict

Browse files
.gitattributes CHANGED
@@ -56,3 +56,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
56
  1_span_base_actions_6/logs/1_span_base_actions_6_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
57
  0_lr_1/logs/0_lr_1_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
58
  1_lr_add_bce_loss_2/logs/1_lr_add_bce_loss_2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
 
 
56
  1_span_base_actions_6/logs/1_span_base_actions_6_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
57
  0_lr_1/logs/0_lr_1_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
58
  1_lr_add_bce_loss_2/logs/1_lr_add_bce_loss_2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
59
+ 2_lr_new_structure_3/logs/2_lr_new_structure_3_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
2_lr_new_structure_3/2_lr_new_structure_3.py ADDED
@@ -0,0 +1,1566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch]')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+
19
+ from sklearn.metrics import f1_score
20
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
21
+ from scipy.spatial.transform import Rotation as R
22
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
23
+ from sklearn.metrics import precision_recall_fscore_support
24
+ from timm.utils import ModelEmaV3
25
+ import timm
26
+
27
+ import os
28
+ import gc
29
+ import json
30
+ from pathlib import Path
31
+ import pickle
32
+ from tqdm.auto import tqdm
33
+ import copy
34
+ import numpy as np
35
+ import pandas as pd
36
+ import polars as pl
37
+ from PIL import Image
38
+ import time
39
+ from tqdm import tqdm
40
+ from matplotlib import pyplot as plt
41
+ import seaborn as sns
42
+ from multiprocessing import Manager as MemoryManager
43
+ from functools import lru_cache
44
+ import shutil
45
+ import glob
46
+ import cv2
47
+ import random
48
+ import re
49
+ import joblib
50
+ import math
51
+ from huggingface_hub import HfApi, snapshot_download
52
+ import evaluate
53
+ from underthesea import word_tokenize as vi_tokenize_tool
54
+ import spacy
55
+ en_tokenize_tool = spacy.load("en_core_web_sm")
56
+ from collections import defaultdict, Counter
57
+
58
+ # %% [code]
59
+ # Global config
60
+ SEEDS = [26092004]
61
+ topk = 1
62
+ nfolds = 5
63
+ only_fold_idx = 0
64
+ test_only = 0
65
+ debug_only = 0
66
+
67
+ # Config thư mục
68
+ dataset = 'kltn/raw' # vhe, bkee, casie, kltn/only_issues, kltn/only_actions, kltn/raw
69
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Thư mục chứa file train, val, test
70
+ train_dir = f'{root_dir}'
71
+ # val_dir = f'{root_dir}/val'
72
+ test_dir = f'{root_dir}'
73
+
74
+ # Config checkpoints
75
+
76
+ # Config training
77
+ epochs = 18 if not debug_only else 2
78
+ batch_size = 32
79
+ device = "cuda" if torch.cuda.is_available() else "cpu"
80
+ # # Thêm biến toàn cục nào đó vào đây
81
+ repo_name = 'SS3M/kltn-experiments'
82
+ state_dict_save_name = "2_lr_new_structure_3"
83
+ checkpoints_dir = state_dict_save_name
84
+ pretrained_dir = "/kaggle/working"
85
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
86
+
87
+ backbone_model_name = "bert-base-uncased" if dataset == "casie" else "vinai/phobert-base"
88
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset == "casie" else vi_tokenize_tool(text)
89
+ max_len_dict = {
90
+ 'kltn/raw': 256,
91
+ 'kltn/only_issues': 52,
92
+ 'kltn/only_actions': 69,
93
+ 'vhe': 51,
94
+ 'bkee': 62,
95
+ 'casie': 40,
96
+ }
97
+ zero_events_rate_dict = {
98
+ 'kltn/raw': 1000,
99
+ 'kltn/only_issues': 0,
100
+ 'kltn/only_actions': 0.2,
101
+ 'vhe': 1000, # mean keep all zero-events samples
102
+ 'bkee': 1000,
103
+ 'casie': 1000,
104
+ }
105
+
106
+ max_len = max_len_dict[dataset]
107
+ max_n_parts = 2
108
+ max_span_len = 14
109
+ n_negs = 5 * 20
110
+ zero_events_rate = zero_events_rate_dict[dataset]
111
+
112
+ # Trainer
113
+ trainer_params = {
114
+ "training_time": "00:11:30:00",
115
+ "eval_mode": "max",
116
+ "topk": topk,
117
+ "save_name": state_dict_save_name,
118
+ "save_best": True,
119
+ "save_last": True,
120
+ "device": device,
121
+ "logging": True,
122
+ "logging_file": True,
123
+ "checkpoints_dir": checkpoints_dir,
124
+ "early_stopping": 30,
125
+ "eval_from_ratio": 0.4,
126
+ "eval_every": 1,
127
+ "schedule_in_step": False,
128
+ "use_ema": True,
129
+ "ema_from_ratio": 0.3,
130
+ "ema_decay": 0.9995,
131
+ "max_grad_norm": 200.0,
132
+ "return_best": True,
133
+ "return_last": True,
134
+ }
135
+
136
+ # Memory
137
+ train_memory_params = {
138
+ 'max_len': max_len,
139
+ 'max_n_parts': max_n_parts,
140
+ 'n_negs': n_negs,
141
+ }
142
+ val_memory_params = {
143
+ 'max_len': max_len,
144
+ 'max_n_parts': max_n_parts,
145
+ 'n_negs': n_negs,
146
+ }
147
+ corpus_memory_params = {
148
+ 'max_len': max_len,
149
+ 'max_n_parts': max_n_parts,
150
+ }
151
+
152
+ # Data Loader
153
+ def seed_worker(worker_id):
154
+ worker_seed = torch.initial_seed() % 2**32
155
+ np.random.seed(worker_seed)
156
+ random.seed(worker_seed)
157
+
158
+ train_loader_params = {
159
+ 'batch_size': batch_size,
160
+ 'shuffle': True,
161
+ 'pin_memory':True,
162
+ 'num_workers': 2,
163
+ 'drop_last': False,
164
+ 'worker_init_fn': seed_worker,
165
+ 'persistent_workers': False,
166
+ }
167
+ val_loader_params = {
168
+ 'batch_size': batch_size,
169
+ 'shuffle': False,
170
+ 'pin_memory':True,
171
+ 'num_workers': 1,
172
+ 'drop_last': False,
173
+ 'worker_init_fn': seed_worker,
174
+ 'persistent_workers': False,
175
+ }
176
+
177
+ # Model
178
+ model_params = {
179
+ 'backbone_name': backbone_model_name,
180
+ 'projection_dim': 256,
181
+ 'normalize': True,
182
+ }
183
+
184
+ # Loss Func
185
+ loss_func_params = {
186
+ 'lambda_contrastive': 1.0,
187
+ 'lambda_triplet': 5.0,
188
+ 'lambda_bce': 3.0,
189
+ }
190
+ eval_func_params = {}
191
+
192
+ # Optim
193
+ optim_params = {
194
+ 'name': 'AdamW',
195
+ 'lr': 1e-4,
196
+ 'weight_decay': 1e-4,
197
+ }
198
+ scheduler_params = {
199
+ 'name': 'CosineAnnealingLR',
200
+ 'T_max': 20, # Số epoch để hoàn thành một chu kỳ giảm LR
201
+ 'eta_min': 1e-6 # Learning rate nhỏ nhất trong chu kỳ
202
+ }
203
+
204
+ # %% [code]
205
+ def set_seed(seed=42):
206
+ random.seed(seed)
207
+ np.random.seed(seed)
208
+ torch.manual_seed(seed)
209
+ torch.cuda.manual_seed(seed)
210
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
211
+ torch.use_deterministic_algorithms(False)
212
+ torch.backends.cudnn.deterministic = True
213
+ torch.backends.cudnn.benchmark = False
214
+ os.environ['PYTHONHASHSEED'] = str(seed)
215
+
216
+ # %% [code]
217
+ class CustomLoss(nn.Module):
218
+ def __init__(
219
+ self,
220
+ temperature=0.05,
221
+ margin=0.2,
222
+ lambda_contrastive=1.0,
223
+ lambda_triplet=0.5,
224
+ lambda_bce=1.0,
225
+ ):
226
+ super().__init__()
227
+
228
+ self.temperature = temperature
229
+ self.margin = margin
230
+
231
+ self.lambda_contrastive = lambda_contrastive
232
+ self.lambda_triplet = lambda_triplet
233
+ self.lambda_bce = lambda_bce
234
+
235
+ self.bce = nn.BCEWithLogitsLoss()
236
+
237
+ def forward(
238
+ self,
239
+ encoded_text, encoded_pos, encoded_neg, pos_mask,
240
+ logits, labels
241
+ ):
242
+ loss_contrastive = self.multi_pos_contrastive_loss(encoded_text, encoded_pos, encoded_neg, pos_mask)
243
+ loss_triplet = self.hardest_triplet_loss(encoded_text, encoded_pos, encoded_neg, pos_mask)
244
+ loss_bce = self.bce(logits[labels != -100], labels[labels != -100].float())
245
+
246
+ total_loss = (
247
+ self.lambda_contrastive * loss_contrastive +
248
+ self.lambda_triplet * loss_triplet +
249
+ self.lambda_bce * loss_bce
250
+ )
251
+
252
+ return {
253
+ "total": total_loss,
254
+ "contrastive_loss": loss_contrastive,
255
+ "triplet_loss": loss_triplet,
256
+ "loss_bce": loss_bce,
257
+ }
258
+
259
+ def multi_pos_contrastive_loss(self, q, pos, neg, pos_mask):
260
+ B, P, D = pos.shape
261
+ N = neg.shape[1]
262
+
263
+ # ===== concat docs =====
264
+ docs = torch.cat([pos, neg], dim=1) # [B, P+N, D]
265
+
266
+ # ===== similarity =====
267
+ logits = torch.matmul(q.unsqueeze(1), docs.transpose(1, 2)).squeeze(1)
268
+ logits = logits / self.temperature # [B, P+N]
269
+
270
+ # ===== labels =====
271
+ labels = torch.zeros_like(logits)
272
+ labels[:, :P] = pos_mask # chỉ pos hợp lệ
273
+
274
+ # ===== log-softmax =====
275
+ log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
276
+
277
+ # ===== normalize theo số pos thật =====
278
+ pos_count = pos_mask.sum(dim=1).clamp(min=1)
279
+
280
+ loss = -(labels * log_prob).sum(dim=1) / pos_count
281
+
282
+ return loss.mean()
283
+
284
+ def hardest_triplet_loss(self, q, pos, neg, pos_mask):
285
+ # ===== similarity =====
286
+ pos_sim = torch.matmul(q.unsqueeze(1), pos.transpose(1, 2)).squeeze(1) # [B, P]
287
+ neg_sim = torch.matmul(q.unsqueeze(1), neg.transpose(1, 2)).squeeze(1) # [B, N]
288
+
289
+ # ===== mask pos =====
290
+ pos_sim_masked = pos_sim.clone()
291
+ pos_sim_masked[pos_mask == 0] = float('inf') # loại pad
292
+
293
+ # ===== hardest =====
294
+ hardest_pos = pos_sim_masked.min(dim=1).values
295
+ hardest_neg = neg_sim.max(dim=1).values
296
+
297
+ # ===== loss =====
298
+ loss = F.relu(self.margin + hardest_neg - hardest_pos)
299
+
300
+ return loss.mean()
301
+
302
+ # %% [code]
303
+ class CustomEvalFn(nn.Module):
304
+ def __init__(self):
305
+ super().__init__()
306
+
307
+ def forward(self, pred_topk, real_topk):
308
+ """
309
+ pred_topk: List[List[int]] shape [B, K]
310
+ real_topk: List[List[int]] shape [B, Ki]
311
+ """
312
+
313
+ B = len(pred_topk)
314
+
315
+ total_recall = 0.0
316
+ total_precision = 0.0
317
+ total_mrr = 0.0
318
+ total_f2 = 0.0
319
+
320
+ for i in range(B):
321
+ preds = pred_topk[i]
322
+ pred_set = set(preds)
323
+
324
+ gts = set(real_topk[i])
325
+
326
+ # ===== Recall@K =====
327
+ hit = any(p in gts for p in preds)
328
+ total_recall += 1.0 if hit else 0.0
329
+
330
+ # ===== MRR =====
331
+ rr = 0.0
332
+ for rank, p in enumerate(preds, start=1):
333
+ if p in gts:
334
+ rr = 1.0 / rank
335
+ break
336
+
337
+ total_mrr += rr
338
+
339
+ # ===== Precision / Recall / F2 =====
340
+ tp = len(pred_set & gts)
341
+
342
+ precision = tp / len(pred_set) if len(pred_set) > 0 else 0.0
343
+ recall_f = tp / len(gts) if len(gts) > 0 else 0.0
344
+
345
+ total_precision += precision
346
+
347
+ beta2 = 2 ** 2
348
+
349
+ if precision + recall_f > 0:
350
+ f2 = (1 + beta2) * precision * recall_f / (
351
+ beta2 * precision + recall_f
352
+ )
353
+ else:
354
+ f2 = 0.0
355
+
356
+ total_f2 += f2
357
+
358
+ recall = total_recall / B
359
+ precision = total_precision / B
360
+ mrr = total_mrr / B
361
+ f2 = total_f2 / B
362
+
363
+ return {
364
+ "precision": precision,
365
+ "recall": recall,
366
+ "f2": f2,
367
+ "mrr": mrr,
368
+ }
369
+
370
+ # %% [code]
371
+ class MLP(nn.Module):
372
+ def __init__(self, in_size, hid_size, out_size):
373
+ super().__init__()
374
+ self.mlp = nn.Sequential(
375
+ nn.Linear(in_size, hid_size),
376
+ nn.ReLU(),
377
+ nn.Linear(hid_size, out_size)
378
+ )
379
+
380
+ def forward(self, x):
381
+ return self.mlp(x)
382
+
383
+ class EncodeModel(nn.Module):
384
+ def __init__(self, backbone_name, projection_dim, normalize):
385
+ super().__init__()
386
+
387
+ self.encoder = AutoModel.from_pretrained(backbone_name)
388
+ hidden_size = self.encoder.config.hidden_size
389
+
390
+ self.proj = MLP(hidden_size, hidden_size, projection_dim)
391
+ self.classifier = MLP(2*projection_dim, projection_dim, 1)
392
+ self.normalize = normalize
393
+
394
+ def embed_query(self, input_ids, attention_mask):
395
+ B, n_parts, L = input_ids.shape
396
+ input_ids = input_ids.view(-1, L)
397
+ attention_mask = attention_mask.view(-1, L)
398
+
399
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
400
+ hidden = outputs.last_hidden_state # B * n_parts, L, H
401
+ hidden = hidden.view(B, n_parts, L, -1).mean(dim=1)
402
+ cls = hidden[:, 0]
403
+ embed = self.proj(cls)
404
+ return embed # B, D
405
+
406
+ def embed_doc(self, input_ids, attention_mask):
407
+ B, K, n_parts, L = input_ids.shape
408
+ input_ids = input_ids.view(-1, L)
409
+ attention_mask = attention_mask.view(-1, L)
410
+
411
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
412
+ hidden = outputs.last_hidden_state # B * K * n_parts, L, H
413
+ hidden = hidden.view(B, K, n_parts, L, -1).mean(dim=2)
414
+ cls = hidden[:, :, 0]
415
+ embed = self.proj(cls)
416
+ return embed # B, K, D
417
+
418
+ def doc_classify(self, q_emb, d_emb):
419
+ B, N, D = d_emb.shape
420
+ q_emb_expand = q_emb.unsqueeze(1).expand(-1, N, -1)
421
+ q_d_emb = torch.cat([q_emb_expand, d_emb], dim=-1) # (B, N, 2D)
422
+ logits = self.classifier(q_d_emb).squeeze(-1) # (B, N)
423
+ return logits
424
+
425
+ def forward(self, text_input_ids, text_attn_mask, doc_input_ids, doc_attn_mask):
426
+ q_emb = self.embed_query(text_input_ids, text_attn_mask)
427
+ d_emb = self.embed_doc(doc_input_ids, doc_attn_mask)
428
+ if self.normalize:
429
+ q_emb = F.normalize(q_emb, dim=-1)
430
+ d_emb = F.normalize(d_emb, dim=-1)
431
+ logits = self.doc_classify(q_emb, d_emb)
432
+
433
+ return q_emb, d_emb, logits
434
+
435
+ def test_model():
436
+ model = nn.DataParallel(EncodeModel('vinai/phobert-base', 256, True)).to(device)
437
+ model.eval()
438
+
439
+ bz = 32
440
+ vocab_size = 1000
441
+ qi = torch.randint(0, vocab_size, (bz, 1, 256)).to(device)
442
+ qa = torch.ones(bz, 1, 256).to(device)
443
+ di = torch.randint(0, vocab_size, (bz, 5, 2, 256)).to(device)
444
+ da = torch.ones(bz, 5, 2, 256).to(device)
445
+
446
+ st = time.time()
447
+
448
+ with torch.no_grad():
449
+ r = model(qi, qa, di, da)
450
+
451
+ if type(r) == tuple:
452
+ print([r[i].shape for i in range(len(r))])
453
+ else:
454
+ print(r.shape)
455
+ print(time.time() - st)
456
+
457
+ del model, qi, qa, di, da
458
+ torch.cuda.empty_cache()
459
+ gc.collect()
460
+ test_model()
461
+
462
+ # %% [code]
463
+ def configure_optimizers(network, optim_params, scheduler_params):
464
+ try:
465
+ optim_params = copy.copy(optim_params)
466
+ scheduler_params = copy.copy(scheduler_params)
467
+
468
+ optim_name = optim_params.pop('name')
469
+ scheduler_name = scheduler_params.pop('name')
470
+
471
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
472
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
473
+
474
+ if optimizer_cls is None:
475
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
476
+
477
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
478
+
479
+ scheduler = None
480
+ if scheduler_params and scheduler_cls: # Chỉ tạo scheduler nếu có tham số
481
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
482
+
483
+ return optimizer, scheduler
484
+
485
+ except KeyError as e:
486
+ raise ValueError(f"Missing {e} in config!!")
487
+
488
+ def freeze(self, model):
489
+ model.eval()
490
+ for param in model.parameters():
491
+ param.requires_grad = False
492
+
493
+ def unfreeze(self, model):
494
+ model.train()
495
+ for param in model.parameters():
496
+ param.requires_grad = True
497
+
498
+ def reduce_batch_size(loader, ratio=0.5):
499
+ new_bs = max(1, int(loader.batch_size * ratio))
500
+
501
+ shuffle = isinstance(loader.sampler, RandomSampler)
502
+
503
+ new_loader = DataLoader(
504
+ dataset=loader.dataset,
505
+ batch_size=new_bs,
506
+ shuffle=shuffle,
507
+ sampler=None if shuffle else loader.sampler,
508
+ num_workers=loader.num_workers,
509
+ collate_fn=loader.collate_fn,
510
+ pin_memory=loader.pin_memory,
511
+ drop_last=loader.drop_last,
512
+ timeout=loader.timeout,
513
+ worker_init_fn=loader.worker_init_fn,
514
+ multiprocessing_context=loader.multiprocessing_context,
515
+ generator=loader.generator,
516
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
517
+ persistent_workers=loader.persistent_workers,
518
+ pin_memory_device=loader.pin_memory_device
519
+ )
520
+
521
+ return new_loader
522
+
523
+ def list_to_tuple(x):
524
+ if isinstance(x, (list, tuple)):
525
+ return tuple(list_to_tuple(i) for i in x)
526
+ return x
527
+
528
+ def fmt(x):
529
+ if isinstance(x, float):
530
+ return round(x, 5)
531
+ if isinstance(x, dict):
532
+ return {k: fmt(v) for k, v in x.items()}
533
+ if isinstance(x, list):
534
+ return [fmt(v) for v in x]
535
+ return x
536
+
537
+ class ModelEmaV3Proxy(ModelEmaV3):
538
+ def __getattr__(self, name):
539
+ try:
540
+ return super().__getattr__(name)
541
+ except AttributeError:
542
+ return getattr(self.module, name)
543
+
544
+ class DataParallelProxy(nn.DataParallel):
545
+ def __getattr__(self, name):
546
+ try:
547
+ return super().__getattr__(name)
548
+ except AttributeError:
549
+ attr = getattr(self.module, name)
550
+
551
+ if callable(attr):
552
+ def wrapper(*args, **kwargs):
553
+ return self._parallel_apply_method(name, *args, **kwargs)
554
+ return wrapper
555
+
556
+ return attr
557
+
558
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
559
+ if not self.device_ids:
560
+ return getattr(self.module, method_name)(*inputs, **kwargs)
561
+
562
+ inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids)
563
+
564
+ replicas = self.replicate(self.module, self.device_ids)
565
+
566
+ outputs = self.parallel_apply(
567
+ [getattr(replica, method_name) for replica in replicas],
568
+ inputs_scattered,
569
+ kwargs_scattered
570
+ )
571
+
572
+ return self.gather(outputs, self.output_device)
573
+
574
+ class Trainer:
575
+ def __init__(
576
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
577
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
578
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
579
+ ):
580
+ self.ema_net = None
581
+
582
+ self.training_time = self._time_str_to_seconds(training_time)
583
+ self.mode = eval_mode
584
+ self.topk = topk
585
+ self.device = device
586
+ self.logging = logging if logging < epochs else 1
587
+ self.logging_file = logging_file
588
+ self.checkpoints_dir = checkpoints_dir
589
+ self.early_stopping = early_stopping
590
+ self.eval_from_ratio = eval_from_ratio
591
+ self.eval_every = eval_every
592
+ self.save_name = save_name
593
+ self.save_best = save_best
594
+ self.save_last = save_last
595
+ self.return_best = return_best
596
+ self.return_last = return_last
597
+ self.max_grad_norm = max_grad_norm
598
+ self.schedule_in_step = schedule_in_step
599
+ self.use_ema = use_ema
600
+ self.ema_from_ratio = ema_from_ratio
601
+ self.ema_decay = ema_decay
602
+
603
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
604
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
605
+
606
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, corpus_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, refresh_every=3):
607
+ if eval_fn is None:
608
+ if self.mode == "max":
609
+ eval_fn = lambda *x: -loss_fn(*x)
610
+ else:
611
+ eval_fn = lambda *x: loss_fn(*x)
612
+
613
+ if torch.cuda.device_count() > 1:
614
+ network = DataParallelProxy(network)
615
+ network = network.to(self.device)
616
+
617
+ if not start_training_time:
618
+ start_training_time = time.time()
619
+
620
+ start_ema = int(epochs * self.ema_from_ratio)
621
+ start_eval = int(epochs * self.eval_from_ratio)
622
+
623
+ if val_loader is None:
624
+ print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
625
+ else:
626
+ model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
627
+ start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
628
+ print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)
629
+
630
+ training_log = {}
631
+ for epoch in range(start_epoch, epochs+start_epoch):
632
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
633
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
634
+
635
+ try:
636
+ eval_net = self.ema_net if (self.use_ema and self.ema_net is not None) else network
637
+ if (epoch - start_epoch) % refresh_every == 0:
638
+ encoded_docs = self._get_encoded_docs(eval_net, corpus_loader)
639
+ print(f"[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Refresh Encoded Doc (refresh_every={refresh_every})!")
640
+ elif (epoch - start_epoch - start_eval) % self.eval_every == 0 and epoch - start_epoch >= start_eval:
641
+ encoded_docs = self._get_encoded_docs(eval_net, corpus_loader)
642
+ print(f"[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Refresh Encoded Doc (eval_every={self.eval_every})!")
643
+
644
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, encoded_docs)
645
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
646
+ logging_dict.update(train_loss_epoch_dict)
647
+
648
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
649
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, encoded_docs)
650
+ update = self._update_best_network(eval_net, val_score, epoch)
651
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
652
+ logging_dict.update(val_score_dict)
653
+ if not self.schedule_in_step and scheduler:
654
+ scheduler.step()
655
+
656
+ except RuntimeError as e:
657
+ if "out of memory" in str(e).lower():
658
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
659
+ torch.cuda.empty_cache()
660
+ gc.collect()
661
+ if torch.cuda.is_available():
662
+ torch.cuda.synchronize()
663
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
664
+
665
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
666
+ if val_loader is not None:
667
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
668
+
669
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
670
+ else:
671
+ raise
672
+
673
+ training_log[epoch] = logging_dict
674
+ if self.is_early_stopping(epoch):
675
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
676
+ break
677
+ if self.logging:
678
+ if epoch % self.logging == 0:
679
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
680
+ else:
681
+ print(f'{epoch}...', end=' ')
682
+
683
+ if self._at_time_limit(start_training_time):
684
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
685
+ break
686
+
687
+ if self.logging_file:
688
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
689
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
690
+ f.write(json.dumps(training_log))
691
+
692
+ if self.use_ema and self.ema_net is not None:
693
+ self._save_state_dict(self.ema_net.module)
694
+ else:
695
+ self._save_state_dict(network)
696
+ print(f'[Trainer CallBack] 📢 Kết thúc training.\n')
697
+
698
+ best_model, last_model = None, None
699
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
700
+ if self.return_best :
701
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
702
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
703
+ if self.return_last:
704
+ last_model = eval_net.state_dict()
705
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
706
+
707
+ del network
708
+ torch.cuda.empty_cache()
709
+ gc.collect()
710
+ return training_log, best_model, last_model
711
+
712
+ def _time_str_to_seconds(self, time_str):
713
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
714
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
715
+
716
+ def _update_best_network(self, network, val_score, epoch):
717
+ topk = max(1, self.topk)
718
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
719
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
720
+ if val_score in [x[0] for x in self.best_stage]:
721
+ return True
722
+ return False
723
+
724
+ def is_early_stopping(self, epoch):
725
+ if self.best_stage[0][1] is None:
726
+ return False
727
+ if not self.early_stopping:
728
+ return False
729
+ return epoch - self.best_stage[0][1] >= self.early_stopping
730
+
731
+ def _at_time_limit(self, start_training_time):
732
+ return time.time() - start_training_time >= self.training_time
733
+
734
+ def _save_state_dict(self, network):
735
+ if self.topk <= 0:
736
+ return
737
+
738
+ if self.save_best:
739
+ for r in range(self.topk):
740
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
741
+
742
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
743
+ if state_dict is None:
744
+ continue
745
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
746
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
747
+ if self.save_last:
748
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
749
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
750
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
751
+
752
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, encoded_docs):
753
+ network.train()
754
+ total_loss = 0
755
+ total_loss_dict = {}
756
+ for batch_idx, batch in enumerate(train_loader):
757
+ optimizer.zero_grad()
758
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
759
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, encoded_docs)
760
+
761
+ for k, v in loss_dict.items():
762
+ t = total_loss_dict.get(k, 0)
763
+ total_loss_dict[k] = t + v
764
+ self.grad_scaler.scale(loss).backward()
765
+ self.grad_scaler.unscale_(optimizer)
766
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
767
+ # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu...
768
+ self.grad_scaler.step(optimizer)
769
+ self.grad_scaler.update()
770
+ if self.schedule_in_step and scheduler:
771
+ scheduler.step()
772
+ if self.use_ema and self.ema_net is not None:
773
+ self.ema_net.update(network)
774
+ total_loss += loss
775
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
776
+
777
+ def _eval_epoch(self, network, val_loader, eval_fn, encoded_docs):
778
+ network.eval()
779
+ total_score = 0.0
780
+ total_score_dict = {}
781
+ object_lists = None # sẽ init sau
782
+
783
+ with torch.no_grad():
784
+ for batch_idx, batch in enumerate(val_loader):
785
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, encoded_docs)
786
+ total_score += score
787
+
788
+ for k, v in score_dict.items():
789
+ t = total_score_dict.get(k, 0)
790
+ total_score_dict[k] = t + v
791
+
792
+ if objects:
793
+ if object_lists is None:
794
+ object_lists = [[] for _ in range(len(objects))]
795
+
796
+ for i, obj in enumerate(objects):
797
+ object_lists[i].append(obj.detach())
798
+
799
+ if object_lists is not None:
800
+ object_arrays = [
801
+ torch.concat(obj_list, dim=0).cpu().numpy()
802
+ for obj_list in object_lists
803
+ ]
804
+ else:
805
+ object_arrays = []
806
+
807
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
808
+
809
+ def _get_encoded_docs(self, network, corpus_loader):
810
+ network.eval()
811
+ with torch.no_grad():
812
+ encoded_docs = []
813
+ for batch_idx, batch in enumerate(corpus_loader):
814
+ input_ids = batch['input_ids'].to(self.device)
815
+ attn_mask = batch['attn_mask'].to(self.device)
816
+ encoded_doc = network.embed_doc(input_ids, attn_mask)
817
+ if network.normalize:
818
+ encoded_doc = F.normalize(encoded_doc, dim=-1)
819
+ encoded_docs.append(encoded_doc)
820
+ encoded_docs = torch.concat(encoded_docs, dim=0).squeeze(1)
821
+ return encoded_docs
822
+
823
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, encoded_docs):
824
+ # Bạn cần override _cal_loss để tính loss
825
+ text_input_ids = batch['text_input_ids'].to(self.device)
826
+ text_attn_mask = batch['text_attn_mask'].to(self.device)
827
+ pos_idxes = batch['pos_idxes'].to(self.device)
828
+ pos_mask = batch['pos_mask'].to(self.device)
829
+ neg_idxes = batch['neg_idxes'].to(self.device)
830
+
831
+ encoded_text = network.embed_query(text_input_ids, text_attn_mask)
832
+ if network.normalize:
833
+ encoded_text = F.normalize(encoded_text, dim=-1)
834
+ encoded_pos = encoded_docs[pos_idxes]
835
+ encoded_neg = encoded_docs[neg_idxes]
836
+
837
+ docs = torch.cat([encoded_pos, encoded_neg], dim=1)
838
+ labels = torch.cat([
839
+ torch.where(pos_mask > 0, pos_mask, torch.full_like(pos_mask, -100)),
840
+ torch.zeros(encoded_pos.size(0), encoded_neg.size(1), device=encoded_pos.device)
841
+ ], dim=1)
842
+
843
+ B, N, D = docs.shape
844
+ perm = torch.argsort(torch.rand(B, N, device=docs.device), dim=1)
845
+ docs = docs.gather(1, perm.unsqueeze(-1).expand(-1, -1, D))
846
+ logits = network.doc_classify(encoded_text, docs)
847
+ labels = labels.gather(1, perm)
848
+
849
+ loss_dict = loss_fn(
850
+ encoded_text, encoded_pos, encoded_neg, pos_mask,
851
+ logits, labels
852
+ )
853
+ return loss_dict['total'], loss_dict
854
+
855
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, encoded_docs):
856
+ # Bạn cần override _cal_val_score để tính val score, list bên cạnh là để trả về y hay pred gì đó (nếu cần)
857
+ text_input_ids = batch['text_input_ids'].to(self.device)
858
+ text_attn_mask = batch['text_attn_mask'].to(self.device)
859
+ gt_pos_idxes = batch['gt_pos_idxes']
860
+
861
+ encoded_text = network.embed_query(text_input_ids, text_attn_mask)
862
+ if network.normalize:
863
+ encoded_text = F.normalize(encoded_text, dim=-1)
864
+
865
+ B, _ = encoded_text.shape
866
+ expand_encoded_docs = encoded_docs.unsqueeze(0).expand(B, -1, -1)
867
+ logits = network.doc_classify(encoded_text, expand_encoded_docs)
868
+ scores = torch.matmul(encoded_text, encoded_docs.T)
869
+
870
+ _, topk_indices = torch.topk(scores, k=10, dim=-1) # [B, K]
871
+ topk_logits = logits.gather(1, topk_indices) # [B, K]
872
+ pred_topk = [idx[logit > 0].tolist() for idx, logit in zip(topk_indices, topk_logits)]
873
+
874
+ pred_topk = list_to_tuple(pred_topk)
875
+ gt_pos_idxes = list_to_tuple(gt_pos_idxes)
876
+ score_dict = eval_fn(pred_topk, gt_pos_idxes)
877
+ return score_dict['f2'], score_dict, []
878
+
879
+ # %% [code]
880
+ def tokenize_to_parts(text, tokenizer, max_len, max_n_parts):
881
+ # Tokenize với overflow để chia thành nhiều đoạn
882
+ enc = tokenizer(
883
+ text,
884
+ max_length=max_len*max_n_parts,
885
+ truncation=True,
886
+ padding="max_length",
887
+ return_overflowing_tokens=True,
888
+ return_tensors="pt"
889
+ )
890
+
891
+ input_ids = enc["input_ids"].reshape(max_n_parts, max_len) # (n_parts, max_len)
892
+ attn_mask = enc["attention_mask"].reshape(max_n_parts, max_len) # (n_parts, max_len)
893
+
894
+ return input_ids, attn_mask
895
+
896
+ class LawRetrievalDataset(Dataset):
897
+ def __init__(self, all_data, using_idxes, corpus_dict, tokenizer, max_len, max_n_parts, n_negs):
898
+ super().__init__()
899
+
900
+ self.all_data = all_data
901
+ self.using_idxes = using_idxes
902
+ self.tokenizer = tokenizer
903
+ self.max_len = max_len
904
+ self.max_n_parts = max_n_parts
905
+ self.n_negs = n_negs
906
+
907
+ # ===== BUILD CORPUS =====
908
+ idx = 0
909
+ self.corpus_list = []
910
+ self.corpus_dict = {}
911
+
912
+ for doc_name, articles_dict in corpus_dict.items():
913
+ self.corpus_dict[doc_name] = {}
914
+ for article_idx, content in articles_dict.items():
915
+ self.corpus_list.append([doc_name, article_idx, content])
916
+ self.corpus_dict[doc_name][article_idx] = {
917
+ 'content': content,
918
+ 'idx': idx
919
+ }
920
+ idx += 1
921
+
922
+ def __len__(self):
923
+ return len(self.using_idxes)
924
+
925
+ # ===== ENCODE DOC =====
926
+ def _encode_contexts(self, idxes):
927
+ all_input_ids, all_attn_mask = [], []
928
+
929
+ for idx in idxes:
930
+ name, art, _ = self.corpus_list[idx]
931
+ corpus = self.corpus_dict[name][art]
932
+
933
+ if 'content_input_ids' in corpus:
934
+ content_input_ids = corpus['content_input_ids']
935
+ content_attn_mask = corpus['content_attn_mask']
936
+ else:
937
+ content = corpus['content']
938
+ content_input_ids, content_attn_mask = tokenize_to_parts(
939
+ content, self.tokenizer, self.max_len, self.max_n_parts
940
+ )
941
+ corpus['content_input_ids'] = content_input_ids
942
+ corpus['content_attn_mask'] = content_attn_mask
943
+
944
+ all_input_ids.append(content_input_ids)
945
+ all_attn_mask.append(content_attn_mask)
946
+
947
+ all_input_ids = torch.stack(all_input_ids)
948
+ all_attn_mask = torch.stack(all_attn_mask)
949
+
950
+ return all_input_ids, all_attn_mask
951
+
952
+ def __getitem__(self, idx):
953
+ ridx = self.using_idxes[idx]
954
+ data = self.all_data[ridx]
955
+
956
+ query_text = data['text']
957
+
958
+ text_input_ids, text_attn_mask = tokenize_to_parts(
959
+ query_text, self.tokenizer, self.max_len, 1
960
+ )
961
+
962
+ # ===== POS =====
963
+ gt_pos_idxes = []
964
+ hard_names = []
965
+ for law in data['relevant_law']:
966
+ name = law['doc']
967
+ art = law['art']
968
+ gt_pos_idxes.append(self.corpus_dict[name][art]['idx'])
969
+ if name not in hard_names:
970
+ hard_names.append(name)
971
+
972
+ pos_idxes = torch.tensor(gt_pos_idxes, dtype=torch.long)
973
+ pos_mask = torch.ones(len(pos_idxes))
974
+
975
+ # ===== NEG =====
976
+ hard_neg_idxes = []
977
+ for name in hard_names:
978
+ for content in self.corpus_dict[name].values():
979
+ if content['idx'] in gt_pos_idxes:
980
+ continue
981
+ hard_neg_idxes.append(content['idx'])
982
+
983
+ easy_neg_idxes = list(range(len(self.corpus_list)))
984
+ for i in gt_pos_idxes + hard_neg_idxes:
985
+ if i in easy_neg_idxes:
986
+ easy_neg_idxes.remove(i)
987
+
988
+ n_hards = min(len(hard_neg_idxes), self.n_negs // 2)
989
+ neg_idxes = random.sample(hard_neg_idxes, n_hards) + random.sample(easy_neg_idxes, self.n_negs - n_hards)
990
+ neg_idxes = torch.tensor(neg_idxes, dtype=torch.long)
991
+
992
+ return {
993
+ 'text_input_ids': text_input_ids,
994
+ 'text_attn_mask': text_attn_mask,
995
+ 'gt_pos_idxes': gt_pos_idxes,
996
+ 'pos_idxes': pos_idxes,
997
+ 'pos_mask': pos_mask,
998
+ 'neg_idxes': neg_idxes,
999
+ }
1000
+
1001
+ class CorpusDataset(Dataset):
1002
+ def __init__(self, corpus_dict, tokenizer, max_len, max_n_parts):
1003
+ super().__init__()
1004
+ self.tokenizer = tokenizer
1005
+ self.max_len = max_len
1006
+ self.max_n_parts = max_n_parts
1007
+
1008
+ idx = 0
1009
+ self.corpus_list = []
1010
+ self.corpus_dict = {}
1011
+ for doc_name, articles_dict in corpus_dict.items():
1012
+ self.corpus_dict[doc_name] = {}
1013
+ for article_idx, content in articles_dict.items():
1014
+ self.corpus_list.append([doc_name, article_idx, content])
1015
+ self.corpus_dict[doc_name][article_idx] = {'content': content, 'idx': idx}
1016
+ idx += 1
1017
+
1018
+ def __len__(self):
1019
+ return len(self.corpus_list)
1020
+
1021
+ def _encode_contexts(self, idxes):
1022
+ all_input_ids, all_attn_mask = [], []
1023
+ for idx in idxes:
1024
+ name = self.corpus_list[idx][0]
1025
+ art = self.corpus_list[idx][1]
1026
+ corpus = self.corpus_dict[name][art]
1027
+ if 'content_input_ids' in corpus and 'content_attn_mask' in corpus:
1028
+ content_input_ids = corpus['content_input_ids']
1029
+ content_attn_mask = corpus['content_attn_mask']
1030
+ else:
1031
+ content = corpus['content']
1032
+ content_input_ids, content_attn_mask = tokenize_to_parts(content, self.tokenizer, self.max_len, self.max_n_parts)
1033
+ corpus['content_input_ids'] = content_input_ids
1034
+ corpus['content_attn_mask'] = content_attn_mask
1035
+
1036
+ all_input_ids.append(content_input_ids)
1037
+ all_attn_mask.append(content_attn_mask)
1038
+
1039
+ all_input_ids = torch.stack(all_input_ids)
1040
+ all_attn_mask = torch.stack(all_attn_mask)
1041
+ return all_input_ids, all_attn_mask
1042
+
1043
+ def __getitem__(self, idx):
1044
+ input_ids, attn_mask = self._encode_contexts([idx])
1045
+
1046
+ return {
1047
+ 'input_ids': input_ids,
1048
+ 'attn_mask': attn_mask,
1049
+ }
1050
+
1051
+ def _pad_batch(tensor_list, pad_value=0):
1052
+ """
1053
+ tensor_list: list of tensors, mỗi tensor shape (Nk, max_n_parts, max_len)
1054
+ return: tensor shape (B, max_Nk, max_n_parts, max_len)
1055
+ """
1056
+ max_Nk = max(t.size(0) for t in tensor_list)
1057
+
1058
+ padded = []
1059
+ for t in tensor_list:
1060
+ Nk = t.size(0)
1061
+
1062
+ if Nk < max_Nk:
1063
+ pad_shape = (max_Nk - Nk, *t.shape[1:])
1064
+ pad_tensor = t.new_full(pad_shape, pad_value)
1065
+ t = torch.cat([t, pad_tensor], dim=0)
1066
+
1067
+ padded.append(t)
1068
+
1069
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
1070
+
1071
+ def collate_fn(batch):
1072
+ text_input_ids = torch.stack([b["text_input_ids"] for b in batch])
1073
+ text_attn_mask = torch.stack([b["text_attn_mask"] for b in batch])
1074
+ gt_pos_idxes = [b["gt_pos_idxes"] for b in batch]
1075
+ neg_idxes = torch.stack([b["neg_idxes"] for b in batch])
1076
+
1077
+ pos_idxes = [b["pos_idxes"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1078
+ pos_mask = [b["pos_mask"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1079
+
1080
+ # pad theo Nk
1081
+ pos_idxes = _pad_batch(pos_idxes, pad_value=0).squeeze(-1).squeeze(-1)
1082
+ pos_mask = _pad_batch(pos_mask, pad_value=0).squeeze(-1).squeeze(-1)
1083
+
1084
+ return {
1085
+ 'text_input_ids': text_input_ids,
1086
+ 'text_attn_mask': text_attn_mask,
1087
+ 'gt_pos_idxes': gt_pos_idxes,
1088
+ 'pos_idxes': pos_idxes,
1089
+ 'pos_mask': pos_mask,
1090
+ 'neg_idxes': neg_idxes,
1091
+ }
1092
+
1093
+ # %% [code]
1094
+ def encode_corpus(state_dicts, network, corpus_loader, device):
1095
+ if torch.cuda.device_count() > 1:
1096
+ network = DataParallelProxy(network)
1097
+ network.to(device)
1098
+ network.eval()
1099
+
1100
+ all_model_embs = []
1101
+ for i, state_dict in enumerate(state_dicts):
1102
+ # ===== load model =====
1103
+ if torch.cuda.device_count() > 1:
1104
+ network.module.load_state_dict(state_dict)
1105
+ else:
1106
+ network.load_state_dict(state_dict)
1107
+
1108
+ encoded_docs = []
1109
+
1110
+ with torch.no_grad():
1111
+ for batch in corpus_loader:
1112
+ input_ids = batch['input_ids'].to(device)
1113
+ attn_mask = batch['attn_mask'].to(device)
1114
+
1115
+ encoded_doc = network.embed_doc(input_ids, attn_mask)
1116
+ if network.normalize:
1117
+ encoded_doc = F.normalize(encoded_doc, dim=-1)
1118
+
1119
+ encoded_docs.append(encoded_doc)
1120
+
1121
+ encoded_docs = torch.concat(encoded_docs, dim=0).squeeze(1) # [N, D]
1122
+ all_model_embs.append(encoded_docs)
1123
+
1124
+ # ===== ensemble =====
1125
+ # stack → [M, N, D]
1126
+ all_model_embs = torch.stack(all_model_embs, dim=0)
1127
+ final_embs = all_model_embs.mean(dim=0) # [N, D]
1128
+
1129
+ return final_embs
1130
+
1131
+ def test(state_dicts, network, test_loader, device, eval_fn, encoded_docs, topks=[5, 10, 15]):
1132
+ if torch.cuda.device_count() > 1:
1133
+ network = DataParallelProxy(network)
1134
+ network.to(device)
1135
+ network.eval()
1136
+
1137
+ per_model_scores = []
1138
+ max_k = max(topks)
1139
+
1140
+ all_scores = []
1141
+ all_logits = []
1142
+ all_gt_pos_idxes = []
1143
+ with torch.no_grad():
1144
+ for batch in test_loader:
1145
+ text_input_ids = batch['text_input_ids'].to(device)
1146
+ text_attn_mask = batch['text_attn_mask'].to(device)
1147
+ gt_pos_idxes = batch['gt_pos_idxes']
1148
+ all_gt_pos_idxes.extend(gt_pos_idxes)
1149
+
1150
+ list_encoded_texts = []
1151
+ list_logits = []
1152
+
1153
+ for state_dict in state_dicts:
1154
+ # ===== load model =====
1155
+ if torch.cuda.device_count() > 1:
1156
+ network.module.load_state_dict(state_dict)
1157
+ else:
1158
+ network.load_state_dict(state_dict)
1159
+
1160
+
1161
+ encoded_text = network.embed_query(text_input_ids, text_attn_mask)
1162
+ B, _ = encoded_text.shape
1163
+ expand_encoded_docs = encoded_docs.unsqueeze(0).expand(B, -1, -1)
1164
+ logits = network.doc_classify(encoded_text, expand_encoded_docs)
1165
+
1166
+ list_encoded_texts.append(encoded_text)
1167
+ list_logits.append(logits)
1168
+
1169
+ ensemble_encoded_text = torch.stack(list_encoded_texts, dim=0).mean(dim=0)
1170
+ ensemble_logits = torch.stack(list_logits, dim=0).mean(dim=0)
1171
+ scores = torch.matmul(ensemble_encoded_text, encoded_docs.T) # B, M
1172
+ all_scores.append(scores)
1173
+ all_logits.append(ensemble_logits)
1174
+
1175
+ all_scores = torch.concat(all_scores, dim=0) # N, M
1176
+ all_logits = torch.concat(all_logits, dim=0) # N, M
1177
+
1178
+ _, topk_indices = torch.topk(all_scores, k=10, dim=-1) # [B, K]
1179
+ topk_logits = all_logits.gather(1, topk_indices) # [B, K]
1180
+ pred_topk_full = [idx[logit > 0].tolist() for idx, logit in zip(topk_indices, topk_logits)]
1181
+
1182
+ pred_topk_full = list_to_tuple(pred_topk_full)
1183
+ all_gt_pos_idxes = list_to_tuple(all_gt_pos_idxes)
1184
+
1185
+ final_score = {}
1186
+ for k in topks:
1187
+ pred_topk_k = [p[:k] for p in pred_topk_full]
1188
+ final_score[k] = eval_fn(pred_topk_k, all_gt_pos_idxes)
1189
+
1190
+ return final_score
1191
+
1192
+ # %% [code]
1193
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
1194
+ data_train = json.load(f)
1195
+
1196
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
1197
+ data_test = json.load(f)
1198
+
1199
+ with open(f'{test_dir}/corpus.json', "r", encoding="utf-8") as f:
1200
+ data_corpus = json.load(f)
1201
+
1202
+ print('Train:', len(data_train))
1203
+ print('Test:', len(data_test))
1204
+ print('Corpus:', len(data_corpus))
1205
+
1206
+ # %% [code]
1207
+ # trigger_types = sorted(list(set([e['label'] for d in data_train + data_test for e in d['issues']]))) # NBR : Neighbor relation
1208
+ # bio_trigger_types = ['O'] + [f'{prefix}-{trg}' for trg in trigger_types for prefix in ['B', 'I']]
1209
+ # trigger_label2id = {l: i for i, l in enumerate(bio_trigger_types)}
1210
+ # trigger_id2label = {i: l for l, i in trigger_label2id.items()}
1211
+
1212
+ # argument_types = sorted(list(set([a['role'] for d in data_train + data_test for e in d['issues'] for a in e['arguments']])))
1213
+ # bio_argument_types = ['O'] + [f'{prefix}-{arg}' for arg in argument_types for prefix in ['B', 'I']]
1214
+ # argument_label2id = {l: i for i, l in enumerate(bio_argument_types)}
1215
+ # argument_id2label = {i: l for l, i in argument_label2id.items()}
1216
+
1217
+ # label2id = {
1218
+ # 'Trg': trigger_label2id,
1219
+ # 'Arg': argument_label2id,
1220
+ # }
1221
+
1222
+ # id2label = {
1223
+ # 'Trg': trigger_id2label,
1224
+ # 'Arg': argument_id2label,
1225
+ # }
1226
+
1227
+ # %% [code]
1228
+ # zero_events_idxes = []
1229
+ # for idx, d in enumerate(data_train):
1230
+ # if len(d['issues']) == 0:
1231
+ # zero_events_idxes.append(idx)
1232
+
1233
+ # n_zero_events_samples = len(zero_events_idxes)
1234
+ # n_has_events_samples = len(data_train) - n_zero_events_samples
1235
+
1236
+ # random.seed(42)
1237
+ # k = min(int(n_has_events_samples * zero_events_rate), len(zero_events_idxes))
1238
+ # sampled_zero_events_idxes = random.sample(zero_events_idxes, k)
1239
+
1240
+ # new_data_train = []
1241
+ # for idx, d in enumerate(data_train):
1242
+ # if len(d['issues']) == 0:
1243
+ # if idx in sampled_zero_events_idxes:
1244
+ # new_data_train.append(d)
1245
+ # else:
1246
+ # new_data_train.append(d)
1247
+ # data_train = new_data_train
1248
+
1249
+ # print('Train:', len(data_train))
1250
+
1251
+ # %% [code]
1252
+ if debug_only:
1253
+ data_train = data_train[:20]
1254
+ data_test = data_test[:20]
1255
+
1256
+ print('Train:', len(data_train))
1257
+ print('Test:', len(data_test))
1258
+
1259
+ # %% [code]
1260
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
1261
+
1262
+ # %% [code]
1263
+ print('Experiment name:', state_dict_save_name)
1264
+
1265
+ # %% [code]
1266
+ if not test_only:
1267
+ full_idxes = np.array(range(len(data_train)))
1268
+ training_logs, best_models, last_models = [], [], []
1269
+ start_training_time = time.time()
1270
+ for seed in SEEDS:
1271
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
1272
+ generator = torch.Generator()
1273
+ generator.manual_seed(seed)
1274
+
1275
+ corpusset = CorpusDataset(data_corpus, tokenizer, **corpus_memory_params)
1276
+ corpus_loader = DataLoader(corpusset, generator=generator, **val_loader_params)
1277
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
1278
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
1279
+ continue
1280
+ set_seed(seed)
1281
+
1282
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
1283
+
1284
+ trainset = LawRetrievalDataset(data_train, train_idxes, data_corpus, tokenizer, **train_memory_params)
1285
+ valset = LawRetrievalDataset(data_train, val_idxes, data_corpus, tokenizer, **val_memory_params)
1286
+
1287
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
1288
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1289
+
1290
+ my_model = EncodeModel(
1291
+ **model_params
1292
+ )
1293
+ total_params = sum(p.numel() for p in my_model.parameters())
1294
+ print(f"Total params: {total_params:,}")
1295
+
1296
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
1297
+ encoder_params = set(map(id, my_model.encoder.parameters()))
1298
+ other_params = [
1299
+ p for p in my_model.parameters()
1300
+ if id(p) not in encoder_params
1301
+ ]
1302
+ optimizer = optim.AdamW([
1303
+ {"params": my_model.encoder.parameters(), "lr": 2e-5},
1304
+ {"params": other_params}
1305
+ ], lr=5e-4)
1306
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
1307
+
1308
+ loss_fn = CustomLoss(
1309
+ **loss_func_params
1310
+ )
1311
+ eval_fn = CustomEvalFn(**eval_func_params)
1312
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
1313
+ trainer = Trainer(**trainer_params)
1314
+
1315
+ print(f'Start Training Fold {fold_idx}...')
1316
+ training_log, best_model, last_model = trainer.fit(
1317
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, corpus_loader, eval_fn,
1318
+ start_epoch=1, start_training_time=start_training_time, refresh_every=2,
1319
+ )
1320
+
1321
+ training_logs.append(training_log)
1322
+ best_models.append(best_model)
1323
+ last_models.append(last_model)
1324
+
1325
+ # %% [code]
1326
+ def load_all_state_dicts(folder):
1327
+ files = []
1328
+
1329
+ for file in os.listdir(folder):
1330
+ if file.endswith(".pt") or file.endswith(".pth"):
1331
+ m = re.search(r"f(\d+)", file) # tìm f<số>
1332
+ if m:
1333
+ fold = int(m.group(1))
1334
+ files.append((fold, file))
1335
+
1336
+ # sort theo fold
1337
+ files.sort(key=lambda x: x[0])
1338
+
1339
+ state_dicts = []
1340
+ for fold, file in files:
1341
+ path = os.path.join(folder, file)
1342
+ print(f"Loading fold {fold}: {file}")
1343
+ state_dict = torch.load(path, map_location="cpu")
1344
+ state_dicts.append(state_dict)
1345
+
1346
+ return state_dicts
1347
+
1348
+ if test_only:
1349
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
1350
+ get_ipython().system('rm -rf .cache .gitattributes')
1351
+
1352
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
1353
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
1354
+
1355
+ # %% [code]
1356
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
1357
+ testset = LawRetrievalDataset(data_test, range(len(data_test)), data_corpus, tokenizer, **val_memory_params)
1358
+ generator = torch.Generator()
1359
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1360
+ eval_fn = CustomEvalFn(**eval_func_params)
1361
+ my_model = EncodeModel(
1362
+ **model_params
1363
+ )
1364
+ total_params = sum(p.numel() for p in my_model.parameters())
1365
+ print(f"Total params: {total_params:,}")
1366
+
1367
+ # %% [code]
1368
+ start_time = time.time()
1369
+ encoded_docs = encode_corpus(best_models, my_model, corpus_loader, device)
1370
+ best_score = test(best_models, my_model, test_loader, device, eval_fn, encoded_docs)
1371
+
1372
+ encoded_docs = encode_corpus(last_models, my_model, corpus_loader, device)
1373
+ last_score = test(last_models, my_model, test_loader, device, eval_fn, encoded_docs)
1374
+
1375
+ result_test = {"Best model": best_score, "Last model": last_score}
1376
+
1377
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
1378
+ json.dump(result_test, f, ensure_ascii=False, indent=2)
1379
+
1380
+ print('Test:', time.time() - start_time, 's --> Done!')
1381
+
1382
+ # %% [code]
1383
+ def dict_to_df(data):
1384
+ row_tuples = []
1385
+ row_values = []
1386
+
1387
+ # ===== lấy model đầu tiên =====
1388
+ first_model = next(iter(data.values()))
1389
+
1390
+ # ===== eval keys =====
1391
+ eval_keys = list(first_model.keys())
1392
+
1393
+ # ===== tự lấy metrics =====
1394
+ first_eval = next(iter(first_model.values()))
1395
+ metrics = list(first_eval.keys())
1396
+
1397
+ for eval_key in eval_keys:
1398
+ row_tuples.append(eval_key)
1399
+
1400
+ row = {}
1401
+
1402
+ for model_name, model_data in data.items():
1403
+ for metric in metrics:
1404
+ row[(model_name, metric)] = model_data[eval_key][metric]
1405
+
1406
+ row_values.append(row)
1407
+
1408
+ # ===== DataFrame =====
1409
+ df = pd.DataFrame(row_values)
1410
+
1411
+ # ===== MultiIndex columns =====
1412
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
1413
+
1414
+ # ===== Index =====
1415
+ df.index = pd.Index(
1416
+ row_tuples,
1417
+ name="evaluation"
1418
+ )
1419
+
1420
+ # ===== Sort =====
1421
+ sort_keys = []
1422
+
1423
+ for model_name in data.keys():
1424
+ for metric in ["f1", "f2", "mrr", "recall", "precision"]:
1425
+ key = (model_name, metric)
1426
+
1427
+ if key in df.columns:
1428
+ sort_keys.append(key)
1429
+
1430
+ if sort_keys:
1431
+ df = df.sort_values(
1432
+ by=sort_keys,
1433
+ ascending=False
1434
+ )
1435
+
1436
+ return df
1437
+
1438
+ result_test_df = dict_to_df(result_test)
1439
+ result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
1440
+ result_test_df
1441
+
1442
+ # %% [code]
1443
+ key = ("Best model", "f2")
1444
+ result_test_df_best = result_test_df.sort_values(by=key, ascending=False).groupby(level="evaluation").head(1)
1445
+ result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
1446
+ result_test_df_best
1447
+
1448
+ # %% [code]
1449
+ def get_avg_best_score(logs):
1450
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
1451
+
1452
+ def get_avg_log(logs, epochs):
1453
+ avg_log = {}
1454
+
1455
+ for epoch in range(1, epochs + 1):
1456
+ val_score = 0.0
1457
+ train_loss = 0.0
1458
+ n_eval = 0
1459
+
1460
+ for idx in range(len(logs)):
1461
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
1462
+ if log is None:
1463
+ continue
1464
+
1465
+ val_score += log.get('val_score', 0.0)
1466
+ train_loss += log.get('train_loss', 0.0)
1467
+ n_eval += 1
1468
+
1469
+ if n_eval == 0:
1470
+ continue
1471
+
1472
+ avg_log[epoch] = {
1473
+ 'train_loss': train_loss / n_eval,
1474
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
1475
+ }
1476
+
1477
+ return avg_log
1478
+
1479
+ def parse_label_key(label: str):
1480
+ try:
1481
+ first = float(label.split('_', 1)[0]) # số đầu: trước dấu _
1482
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
1483
+ return first, last
1484
+ except:
1485
+ return (0, 0)
1486
+
1487
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
1488
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
1489
+
1490
+ # ===== Plot Train Loss =====
1491
+ for name, log in logs_dict.items():
1492
+ epochs = sorted(log.keys())
1493
+ train_loss = [log[e]['train_loss'] for e in epochs]
1494
+ axes[0].plot(epochs, train_loss, label=name)
1495
+
1496
+ axes[0].set_xlabel('Epoch')
1497
+ axes[0].set_ylabel('Train Loss')
1498
+ axes[0].set_title('Training Loss')
1499
+ axes[0].grid(True)
1500
+
1501
+ # ===== Plot Validation Score =====
1502
+ for name, log in logs_dict.items():
1503
+ epochs = sorted(log.keys())
1504
+ val_score = [log[e]['val_score'] for e in epochs]
1505
+ axes[1].plot(epochs, val_score, label=name)
1506
+
1507
+ axes[1].set_xlabel('Epoch')
1508
+ axes[1].set_ylabel('Validation Score')
1509
+ axes[1].set_title('Validation Score')
1510
+ axes[1].grid(True)
1511
+
1512
+ # ===== Shared Legend =====
1513
+ handles, labels = axes[0].get_legend_handles_labels()
1514
+ pairs = list(zip(handles, labels))
1515
+ pairs_sorted = sorted(
1516
+ pairs,
1517
+ key=lambda x: parse_label_key(x[1])
1518
+ )
1519
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
1520
+
1521
+ axes[0].legend(
1522
+ handles_sorted,
1523
+ labels_sorted,
1524
+ loc='center left',
1525
+ bbox_to_anchor=(1.01, 0.5),
1526
+ borderaxespad=0.
1527
+ )
1528
+
1529
+ plt.tight_layout(rect=[0, 0, 1, 1])
1530
+
1531
+ if save_path is not None:
1532
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
1533
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
1534
+
1535
+ plt.show()
1536
+
1537
+ # %% [code]
1538
+ if not test_only:
1539
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*lr*.json"], ignore_patterns=[])
1540
+ get_ipython().system('rm -rf .cache .gitattributes')
1541
+
1542
+ # %% [code]
1543
+ if not test_only:
1544
+ experiments = {}
1545
+ for experiment in os.listdir(pretrained_dir):
1546
+ experiment_logs = []
1547
+ try:
1548
+ for seed in SEEDS:
1549
+ for fold_idx in range(nfolds):
1550
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
1551
+ experiment_log = json.load(f)
1552
+ experiment_logs.append(experiment_log)
1553
+ except:
1554
+ pass
1555
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
1556
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
1557
+
1558
+ # %% [code]
1559
+ if not test_only:
1560
+ score = get_avg_best_score(training_logs)
1561
+ state_dict_save_name, score
1562
+
1563
+ # %% [code]
1564
+ if not test_only:
1565
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
1566
+
2_lr_new_structure_3/lasts/2_lr_new_structure_3_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:643f356716dd0e2760894547685c981a61810796565a137085651020b3d1047c
3
+ size 543744384
2_lr_new_structure_3/logs/2_lr_new_structure_3_log_plot.jpg ADDED

Git LFS Details

  • SHA256: 288fcfb57e3bc7b436dcb2654545fdffed8d00bb6ec06572c254e75e44c68e19
  • Pointer size: 131 Bytes
  • Size of remote file: 435 kB
2_lr_new_structure_3/logs/2_lr_new_structure_3_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 5.891900539398193, "total": 5.891900410221572, "contrastive_loss": 4.222291672110159, "triplet_loss": 0.25020903010033446, "loss_bce": 0.1433428250826322}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 5.549773693084717, "total": 5.549773659594481, "contrastive_loss": 3.921339360367893, "triplet_loss": 0.24895484949832775, "loss_bce": 0.12804412841796875}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 5.3650312423706055, "total": 5.365031060566472, "contrastive_loss": 3.622427949937291, "triplet_loss": 0.2648411371237458, "loss_bce": 0.13025954893998876}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 5.099754333496094, "total": 5.099754307979724, "contrastive_loss": 3.425837181882316, "triplet_loss": 0.25668896321070234, "loss_bce": 0.1219835185685684}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 5.0689697265625, "total": 5.0689697265625, "contrastive_loss": 3.3710027075930182, "triplet_loss": 0.2614966555183946, "loss_bce": 0.12091292824633544}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 4.85378360748291, "total": 4.853783853077968, "contrastive_loss": 3.2030066040447323, "triplet_loss": 0.2556438127090301, "loss_bce": 0.11649164308273673}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 4.8149566650390625, "total": 4.814956511940845, "contrastive_loss": 3.10197892715301, "triplet_loss": 0.26714046822742477, "loss_bce": 0.11678731401628475}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 4.540054798126221, "total": 4.540054984714674, "contrastive_loss": 2.9078405884197323, "triplet_loss": 0.256061872909699, "loss_bce": 0.11249140353505827, "val_score": 0.02790384188955116, "best_score": 0.02790384188955116, "new_best_model": true, "precision": 0.015286210317460318, "recall": 0.05818452380952381, "f2": 0.02790384188955116, "mrr": 0.05385491071428571}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 4.362818717956543, "total": 4.362818689250627, "contrastive_loss": 2.770563540251359, "triplet_loss": 0.2516722408026756, "loss_bce": 0.10969526871390965, "val_score": 0.02735548993960554, "best_score": 0.02790384188955116, "new_best_model": false, "precision": 0.0158051421957672, "recall": 0.05901785714285714, "f2": 0.02735548993960554, "mrr": 0.05484945436507937}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 4.209826469421387, "total": 4.209826619330059, "contrastive_loss": 2.650550459539611, "triplet_loss": 0.24853678929765885, "loss_bce": 0.1072417677046862, "val_score": 0.02780185949317252, "best_score": 0.02790384188955116, "new_best_model": false, "precision": 0.016415757275132285, "recall": 0.05755952380952381, "f2": 0.02780185949317252, "mrr": 0.053689980158730166}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 4.083454608917236, "total": 4.083454859296614, "contrastive_loss": 2.552323051120924, "triplet_loss": 0.24456521739130435, "loss_bce": 0.10527374752389149, "val_score": 0.028257527071514604, "best_score": 0.028257527071514604, "new_best_model": true, "precision": 0.018149553571428584, "recall": 0.05630952380952381, "f2": 0.028257527071514604, "mrr": 0.05316369047619048}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 3.970719337463379, "total": 3.9707194554765888, "contrastive_loss": 2.4660395491482023, "triplet_loss": 0.24101170568561872, "loss_bce": 0.10363423146531733, "val_score": 0.029152937006498292, "best_score": 0.029152937006498292, "new_best_model": true, "precision": 0.020628637566137573, "recall": 0.05610119047619047, "f2": 0.029152937006498292, "mrr": 0.05351438492063493}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 3.8807168006896973, "total": 3.880716891591764, "contrastive_loss": 2.3979204362850126, "triplet_loss": 0.23683110367892976, "loss_bce": 0.10228407741789036, "val_score": 0.030467162896573814, "best_score": 0.030467162896573814, "new_best_model": true, "precision": 0.022719494047619044, "recall": 0.05526785714285714, "f2": 0.030467162896573814, "mrr": 0.05323660714285715}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 3.819249391555786, "total": 3.8192493859740804, "contrastive_loss": 2.3496944976092182, "triplet_loss": 0.23411371237458195, "loss_bce": 0.10161845101959331, "val_score": 0.03235597141888681, "best_score": 0.03235597141888681, "new_best_model": true, "precision": 0.02615203373015872, "recall": 0.05574404761904762, "f2": 0.03235597141888681, "mrr": 0.053869047619047615}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 3.776521921157837, "total": 3.776521918765677, "contrastive_loss": 2.317281155283236, "triplet_loss": 0.2334866220735786, "loss_bce": 0.10143355302587401, "val_score": 0.0334980782114911, "best_score": 0.0334980782114911, "new_best_model": true, "precision": 0.028797619047619048, "recall": 0.05574404761904762, "f2": 0.0334980782114911, "mrr": 0.052983630952380956}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 3.759141683578491, "total": 3.7591418007943145, "contrastive_loss": 2.2985774521843645, "triplet_loss": 0.23327759197324416, "loss_bce": 0.1015322758601262, "val_score": 0.03338747714355853, "best_score": 0.0334980782114911, "new_best_model": false, "precision": 0.03056944444444444, "recall": 0.054077380952380946, "f2": 0.03338747714355853, "mrr": 0.05168154761904762}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 3.768730640411377, "total": 3.7687306483852425, "contrastive_loss": 2.301013959291388, "triplet_loss": 0.23494983277591974, "loss_bce": 0.10186893884154866, "val_score": 0.03286704195956774, "best_score": 0.0334980782114911, "new_best_model": false, "precision": 0.03215773809523809, "recall": 0.05178571428571429, "f2": 0.03286704195956774, "mrr": 0.04980654761904762}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 3.7879393100738525, "total": 3.7879393714726173, "contrastive_loss": 2.309556635725857, "triplet_loss": 0.23745819397993312, "loss_bce": 0.10264405598209853, "val_score": 0.03207755656543768, "best_score": 0.0334980782114911, "new_best_model": false, "precision": 0.034419642857142864, "recall": 0.05053571428571429, "f2": 0.03207755656543768, "mrr": 0.048869047619047624}}
2_lr_new_structure_3/r1s/2_lr_new_structure_3_s26092004_f0_r1_vs0.03350_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3378f252656a2bccc86e56eedfedaab2dbd87c2df81eb2fc334ab7b5b90d523
3
+ size 543746088
2_lr_new_structure_3/results/2_lr_new_structure_3_test.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Best model": {
3
+ "5": {
4
+ "precision": 0.004813002008032126,
5
+ "recall": 0.012801204819277108,
6
+ "f2": 0.00504633239588152,
7
+ "mrr": 0.009801706827309237
8
+ },
9
+ "10": {
10
+ "precision": 0.0043710556511761325,
11
+ "recall": 0.01355421686746988,
12
+ "f2": 0.004826933173779982,
13
+ "mrr": 0.009902108433734938
14
+ },
15
+ "15": {
16
+ "precision": 0.0043710556511761325,
17
+ "recall": 0.01355421686746988,
18
+ "f2": 0.004826933173779982,
19
+ "mrr": 0.009902108433734938
20
+ }
21
+ },
22
+ "Last model": {
23
+ "5": {
24
+ "precision": 0.010724146586345377,
25
+ "recall": 0.028237951807228916,
26
+ "f2": 0.009614348087346155,
27
+ "mrr": 0.02331827309236948
28
+ },
29
+ "10": {
30
+ "precision": 0.009912118712947025,
31
+ "recall": 0.03049698795180723,
32
+ "f2": 0.009648437895541064,
33
+ "mrr": 0.02366788582903041
34
+ },
35
+ "15": {
36
+ "precision": 0.009912118712947025,
37
+ "recall": 0.03049698795180723,
38
+ "f2": 0.009648437895541064,
39
+ "mrr": 0.02366788582903041
40
+ }
41
+ }
42
+ }
2_lr_new_structure_3/results/2_lr_new_structure_3_test_df.xlsx ADDED
Binary file (5.4 kB). View file
 
2_lr_new_structure_3/results/2_lr_new_structure_3_test_df_best.xlsx ADDED
Binary file (5.4 kB). View file