SS3M commited on
Commit
23010f5
·
verified ·
1 Parent(s): d4f59f9

Upload 5_lr_up_lr_6's state dict

Browse files
.gitattributes CHANGED
@@ -68,3 +68,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
68
  1_lr_add_info_5/logs/1_lr_add_info_5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
69
  6_4_actions_8/logs/6_4_actions_8_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
70
  8_entities_add_weight_keep_2_learn_all_10/logs/8_entities_add_weight_keep_2_learn_all_10_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
 
 
68
  1_lr_add_info_5/logs/1_lr_add_info_5_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
69
  6_4_actions_8/logs/6_4_actions_8_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
70
  8_entities_add_weight_keep_2_learn_all_10/logs/8_entities_add_weight_keep_2_learn_all_10_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
71
+ 5_lr_up_lr_6/logs/5_lr_up_lr_6_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
5_lr_up_lr_6/5_lr_up_lr_6.py ADDED
@@ -0,0 +1,1558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch]')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+
19
+ from sklearn.metrics import f1_score
20
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
21
+ from scipy.spatial.transform import Rotation as R
22
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
23
+ from sklearn.metrics import precision_recall_fscore_support
24
+ from timm.utils import ModelEmaV3
25
+ import timm
26
+
27
+ import os
28
+ import gc
29
+ import json
30
+ from pathlib import Path
31
+ import pickle
32
+ from tqdm.auto import tqdm
33
+ import copy
34
+ import numpy as np
35
+ import pandas as pd
36
+ import polars as pl
37
+ from PIL import Image
38
+ import time
39
+ from tqdm import tqdm
40
+ from matplotlib import pyplot as plt
41
+ import seaborn as sns
42
+ from multiprocessing import Manager as MemoryManager
43
+ from functools import lru_cache
44
+ import shutil
45
+ import glob
46
+ import cv2
47
+ import random
48
+ import re
49
+ import joblib
50
+ import math
51
+ from huggingface_hub import HfApi, snapshot_download
52
+ import evaluate
53
+ from underthesea import word_tokenize as vi_tokenize_tool
54
+ import spacy
55
+ en_tokenize_tool = spacy.load("en_core_web_sm")
56
+ from collections import defaultdict, Counter
57
+
58
+ # %% [code]
59
+ # Global config
60
+ SEEDS = [26092004]
61
+ topk = 1
62
+ nfolds = 5
63
+ only_fold_idx = 0
64
+ test_only = 0
65
+ debug_only = 0
66
+
67
+ # Config thư mục
68
+ dataset = 'kltn/raw' # vhe, bkee, casie, kltn/only_issues, kltn/only_actions, kltn/raw
69
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Thư mục chứa file train, val, test
70
+ train_dir = f'{root_dir}'
71
+ # val_dir = f'{root_dir}/val'
72
+ test_dir = f'{root_dir}'
73
+
74
+ # Config checkpoints
75
+
76
+ # Config training
77
+ epochs = 18 if not debug_only else 2
78
+ batch_size = 32
79
+ device = "cuda" if torch.cuda.is_available() else "cpu"
80
+ # # Thêm biến toàn cục nào đó vào đây
81
+ repo_name = 'SS3M/kltn-experiments'
82
+ state_dict_save_name = "5_lr_up_lr_6"
83
+ checkpoints_dir = state_dict_save_name
84
+ pretrained_dir = "/kaggle/working"
85
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
86
+
87
+ backbone_model_name = "bert-base-uncased" if dataset == "casie" else "vinai/phobert-base"
88
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset == "casie" else vi_tokenize_tool(text)
89
+ max_len_dict = {
90
+ 'kltn/raw': 256,
91
+ 'kltn/only_issues': 52,
92
+ 'kltn/only_actions': 69,
93
+ 'vhe': 51,
94
+ 'bkee': 62,
95
+ 'casie': 40,
96
+ }
97
+ zero_events_rate_dict = {
98
+ 'kltn/raw': 1000,
99
+ 'kltn/only_issues': 0,
100
+ 'kltn/only_actions': 0.2,
101
+ 'vhe': 1000, # mean keep all zero-events samples
102
+ 'bkee': 1000,
103
+ 'casie': 1000,
104
+ }
105
+
106
+ max_len = max_len_dict[dataset]
107
+ max_n_parts = 2
108
+ max_span_len = 14
109
+ n_negs = 5 * 20
110
+ zero_events_rate = zero_events_rate_dict[dataset]
111
+
112
+ # Trainer
113
+ trainer_params = {
114
+ "training_time": "00:11:30:00",
115
+ "eval_mode": "max",
116
+ "topk": topk,
117
+ "save_name": state_dict_save_name,
118
+ "save_best": True,
119
+ "save_last": True,
120
+ "device": device,
121
+ "logging": True,
122
+ "logging_file": True,
123
+ "checkpoints_dir": checkpoints_dir,
124
+ "early_stopping": 30,
125
+ "eval_from_ratio": 0.4,
126
+ "eval_every": 1,
127
+ "schedule_in_step": False,
128
+ "use_ema": True,
129
+ "ema_from_ratio": 0.3,
130
+ "ema_decay": 0.9995,
131
+ "max_grad_norm": 200.0,
132
+ "return_best": True,
133
+ "return_last": True,
134
+ }
135
+
136
+ # Memory
137
+ train_memory_params = {
138
+ 'max_len': max_len,
139
+ 'max_n_parts': max_n_parts,
140
+ 'n_negs': n_negs,
141
+ }
142
+ val_memory_params = {
143
+ 'max_len': max_len,
144
+ 'max_n_parts': max_n_parts,
145
+ 'n_negs': n_negs,
146
+ }
147
+ corpus_memory_params = {
148
+ 'max_len': max_len,
149
+ 'max_n_parts': max_n_parts,
150
+ }
151
+
152
+ # Data Loader
153
+ def seed_worker(worker_id):
154
+ worker_seed = torch.initial_seed() % 2**32
155
+ np.random.seed(worker_seed)
156
+ random.seed(worker_seed)
157
+
158
+ train_loader_params = {
159
+ 'batch_size': batch_size,
160
+ 'shuffle': True,
161
+ 'pin_memory':True,
162
+ 'num_workers': 2,
163
+ 'drop_last': False,
164
+ 'worker_init_fn': seed_worker,
165
+ 'persistent_workers': False,
166
+ }
167
+ val_loader_params = {
168
+ 'batch_size': batch_size,
169
+ 'shuffle': False,
170
+ 'pin_memory':True,
171
+ 'num_workers': 1,
172
+ 'drop_last': False,
173
+ 'worker_init_fn': seed_worker,
174
+ 'persistent_workers': False,
175
+ }
176
+
177
+ # Model
178
+ model_params = {
179
+ 'backbone_name': backbone_model_name,
180
+ 'projection_dim': 256,
181
+ 'normalize': True,
182
+ }
183
+
184
+ # Loss Func
185
+ loss_func_params = {
186
+ 'lambda_contrastive': 1.0,
187
+ 'lambda_triplet': 5.0,
188
+ }
189
+ eval_func_params = {}
190
+
191
+ # Optim
192
+ optim_params = {
193
+ 'name': 'AdamW',
194
+ 'lr': 1e-4,
195
+ 'weight_decay': 1e-4,
196
+ }
197
+ scheduler_params = {
198
+ 'name': 'CosineAnnealingLR',
199
+ 'T_max': 20, # Số epoch để hoàn thành một chu kỳ giảm LR
200
+ 'eta_min': 1e-6 # Learning rate nhỏ nhất trong chu kỳ
201
+ }
202
+
203
+ # %% [code]
204
+ def set_seed(seed=42):
205
+ random.seed(seed)
206
+ np.random.seed(seed)
207
+ torch.manual_seed(seed)
208
+ torch.cuda.manual_seed(seed)
209
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
210
+ torch.use_deterministic_algorithms(False)
211
+ torch.backends.cudnn.deterministic = True
212
+ torch.backends.cudnn.benchmark = False
213
+ os.environ['PYTHONHASHSEED'] = str(seed)
214
+
215
+ # %% [code]
216
+ class CustomLoss(nn.Module):
217
+ def __init__(
218
+ self,
219
+ temperature=0.05,
220
+ margin=0.2,
221
+ lambda_contrastive=1.0,
222
+ lambda_triplet=0.5,
223
+ ):
224
+ super().__init__()
225
+
226
+ self.temperature = temperature
227
+ self.margin = margin
228
+
229
+ self.lambda_contrastive = lambda_contrastive
230
+ self.lambda_triplet = lambda_triplet
231
+
232
+ def forward(
233
+ self,
234
+ encoded_text, encoded_pos, encoded_neg, pos_mask,
235
+ ):
236
+ loss_contrastive = self.multi_pos_contrastive_loss(encoded_text, encoded_pos, encoded_neg, pos_mask)
237
+ loss_triplet = self.hardest_triplet_loss(encoded_text, encoded_pos, encoded_neg, pos_mask)
238
+
239
+ total_loss = (
240
+ self.lambda_contrastive * loss_contrastive +
241
+ self.lambda_triplet * loss_triplet
242
+ )
243
+
244
+ return {
245
+ "total": total_loss,
246
+ "contrastive_loss": loss_contrastive,
247
+ "triplet_loss": loss_triplet,
248
+ }
249
+
250
+ def multi_pos_contrastive_loss(self, q, pos, neg, pos_mask):
251
+ B, P, D = pos.shape
252
+ N = neg.shape[1]
253
+
254
+ # ===== concat docs =====
255
+ docs = torch.cat([pos, neg], dim=1) # [B, P+N, D]
256
+
257
+ # ===== similarity =====
258
+ logits = torch.matmul(q.unsqueeze(1), docs.transpose(1, 2)).squeeze(1)
259
+ logits = logits / self.temperature # [B, P+N]
260
+
261
+ # ===== labels =====
262
+ labels = torch.zeros_like(logits)
263
+ labels[:, :P] = pos_mask # chỉ pos hợp lệ
264
+
265
+ # ===== log-softmax =====
266
+ log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
267
+
268
+ # ===== normalize theo số pos thật =====
269
+ pos_count = pos_mask.sum(dim=1).clamp(min=1)
270
+
271
+ loss = -(labels * log_prob).sum(dim=1) / pos_count
272
+
273
+ return loss.mean()
274
+
275
+ def hardest_triplet_loss(self, q, pos, neg, pos_mask):
276
+ # ===== similarity =====
277
+ pos_sim = torch.matmul(q.unsqueeze(1), pos.transpose(1, 2)).squeeze(1) # [B, P]
278
+ neg_sim = torch.matmul(q.unsqueeze(1), neg.transpose(1, 2)).squeeze(1) # [B, N]
279
+
280
+ # ===== mask pos =====
281
+ pos_sim_masked = pos_sim.clone()
282
+ pos_sim_masked[pos_mask == 0] = float('inf') # loại pad
283
+
284
+ # ===== hardest =====
285
+ hardest_pos = pos_sim_masked.min(dim=1).values
286
+ hardest_neg = neg_sim.max(dim=1).values
287
+
288
+ # ===== loss =====
289
+ loss = F.relu(self.margin + hardest_neg - hardest_pos)
290
+
291
+ return loss.mean()
292
+
293
+ # %% [code]
294
+ class CustomEvalFn(nn.Module):
295
+ def __init__(self):
296
+ super().__init__()
297
+
298
+ def forward(self, pred_topk, real_topk):
299
+ """
300
+ pred_topk: List[List[int]] shape [B, K]
301
+ real_topk: List[List[int]] shape [B, Ki]
302
+ """
303
+
304
+ B = len(pred_topk)
305
+
306
+ total_recall = 0.0
307
+ total_precision = 0.0
308
+ total_mrr = 0.0
309
+ total_f2 = 0.0
310
+
311
+ for i in range(B):
312
+ preds = pred_topk[i]
313
+ pred_set = set(preds)
314
+
315
+ gts = set(real_topk[i])
316
+
317
+ # ===== Recall@K =====
318
+ hit = any(p in gts for p in preds)
319
+ total_recall += 1.0 if hit else 0.0
320
+
321
+ # ===== MRR =====
322
+ rr = 0.0
323
+ for rank, p in enumerate(preds, start=1):
324
+ if p in gts:
325
+ rr = 1.0 / rank
326
+ break
327
+
328
+ total_mrr += rr
329
+
330
+ # ===== Precision / Recall / F2 =====
331
+ tp = len(pred_set & gts)
332
+
333
+ precision = tp / len(pred_set) if len(pred_set) > 0 else 0.0
334
+ recall_f = tp / len(gts) if len(gts) > 0 else 0.0
335
+
336
+ total_precision += precision
337
+
338
+ beta2 = 2 ** 2
339
+
340
+ if precision + recall_f > 0:
341
+ f2 = (1 + beta2) * precision * recall_f / (
342
+ beta2 * precision + recall_f
343
+ )
344
+ else:
345
+ f2 = 0.0
346
+
347
+ total_f2 += f2
348
+
349
+ recall = total_recall / B
350
+ precision = total_precision / B
351
+ mrr = total_mrr / B
352
+ f2 = total_f2 / B
353
+
354
+ return {
355
+ "precision": precision,
356
+ "recall": recall,
357
+ "f2": f2,
358
+ "mrr": mrr,
359
+ }
360
+
361
+ # %% [code]
362
+ class MLP(nn.Module):
363
+ def __init__(self, in_size, hid_size, out_size):
364
+ super().__init__()
365
+ self.mlp = nn.Sequential(
366
+ nn.Linear(in_size, hid_size),
367
+ nn.ReLU(),
368
+ nn.Linear(hid_size, out_size)
369
+ )
370
+
371
+ def forward(self, x):
372
+ return self.mlp(x)
373
+
374
+ class EncodeModel(nn.Module):
375
+ def __init__(self, backbone_name, projection_dim, normalize):
376
+ super().__init__()
377
+
378
+ self.encoder = AutoModel.from_pretrained(backbone_name)
379
+ hidden_size = self.encoder.config.hidden_size
380
+
381
+ self.proj = MLP(hidden_size, hidden_size, projection_dim)
382
+ self.info_proj = MLP(hidden_size, hidden_size, projection_dim)
383
+
384
+ self.normalize = normalize
385
+
386
+ def embed_query(self, input_ids, attention_mask):
387
+ B, n_parts, L = input_ids.shape
388
+
389
+ input_ids = input_ids.view(-1, L)
390
+ attention_mask = attention_mask.view(-1, L)
391
+
392
+ outputs = self.encoder(
393
+ input_ids=input_ids,
394
+ attention_mask=attention_mask
395
+ )
396
+
397
+ hidden = outputs.last_hidden_state
398
+ hidden = hidden.view(B, n_parts, L, -1)
399
+
400
+ cls = hidden[:, :, 0].mean(dim=1)
401
+
402
+ embed = self.proj(cls)
403
+
404
+ return embed
405
+
406
+ def embed_info(self, input_ids, attention_mask):
407
+ B, n_parts, L = input_ids.shape
408
+
409
+ input_ids = input_ids.view(-1, L)
410
+ attention_mask = attention_mask.view(-1, L)
411
+
412
+ outputs = self.encoder(
413
+ input_ids=input_ids,
414
+ attention_mask=attention_mask
415
+ )
416
+
417
+ hidden = outputs.last_hidden_state
418
+ hidden = hidden.view(B, n_parts, L, -1)
419
+
420
+ cls = hidden[:, :, 0].mean(dim=1)
421
+
422
+ embed = self.info_proj(cls)
423
+
424
+ return embed
425
+
426
+ def embed_doc(self, input_ids, attention_mask):
427
+ B, K, n_parts, L = input_ids.shape
428
+
429
+ input_ids = input_ids.view(-1, L)
430
+ attention_mask = attention_mask.view(-1, L)
431
+
432
+ outputs = self.encoder(
433
+ input_ids=input_ids,
434
+ attention_mask=attention_mask
435
+ )
436
+
437
+ hidden = outputs.last_hidden_state
438
+ hidden = hidden.view(B, K, n_parts, L, -1)
439
+
440
+ cls = hidden[:, :, :, 0].mean(dim=2)
441
+
442
+ embed = self.proj(cls)
443
+
444
+ return embed
445
+
446
+ def forward(
447
+ self,
448
+ text_input_ids,
449
+ text_attn_mask,
450
+ info_input_ids,
451
+ info_attn_mask,
452
+ doc_input_ids,
453
+ doc_attn_mask
454
+ ):
455
+ q_emb = self.embed_query(text_input_ids, text_attn_mask)
456
+ i_emb = self.embed_info(info_input_ids, info_attn_mask)
457
+ d_emb = self.embed_doc(doc_input_ids, doc_attn_mask)
458
+
459
+ if self.normalize:
460
+ q_emb = F.normalize(q_emb, dim=-1)
461
+ i_emb = F.normalize(i_emb, dim=-1)
462
+ d_emb = F.normalize(d_emb, dim=-1)
463
+
464
+ q_emb = q_emb + 0.1 * i_emb
465
+
466
+ if self.normalize:
467
+ q_emb = F.normalize(q_emb, dim=-1)
468
+
469
+ return q_emb, d_emb
470
+
471
+ def test_model():
472
+ model = nn.DataParallel(EncodeModel('vinai/phobert-base', 256, True)).to(device)
473
+ model.eval()
474
+
475
+ bz = 32
476
+ vocab_size = 1000
477
+ qi = torch.randint(0, vocab_size, (bz, 1, 256)).to(device)
478
+ qa = torch.ones(bz, 1, 256).to(device)
479
+ di = torch.randint(0, vocab_size, (bz, 5, 2, 256)).to(device)
480
+ da = torch.ones(bz, 5, 2, 256).to(device)
481
+
482
+ st = time.time()
483
+
484
+ with torch.no_grad():
485
+ r = model(qi, qa, qi, qa, di, da)
486
+
487
+ if type(r) == tuple:
488
+ print([r[i].shape for i in range(len(r))])
489
+ else:
490
+ print(r.shape)
491
+ print(time.time() - st)
492
+
493
+ del model, qi, qa, di, da
494
+ torch.cuda.empty_cache()
495
+ gc.collect()
496
+ test_model()
497
+
498
+ # %% [code]
499
+ def configure_optimizers(network, optim_params, scheduler_params):
500
+ try:
501
+ optim_params = copy.copy(optim_params)
502
+ scheduler_params = copy.copy(scheduler_params)
503
+
504
+ optim_name = optim_params.pop('name')
505
+ scheduler_name = scheduler_params.pop('name')
506
+
507
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
508
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
509
+
510
+ if optimizer_cls is None:
511
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
512
+
513
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
514
+
515
+ scheduler = None
516
+ if scheduler_params and scheduler_cls: # Chỉ tạo scheduler nếu có tham số
517
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
518
+
519
+ return optimizer, scheduler
520
+
521
+ except KeyError as e:
522
+ raise ValueError(f"Missing {e} in config!!")
523
+
524
+ def freeze(self, model):
525
+ model.eval()
526
+ for param in model.parameters():
527
+ param.requires_grad = False
528
+
529
+ def unfreeze(self, model):
530
+ model.train()
531
+ for param in model.parameters():
532
+ param.requires_grad = True
533
+
534
+ def reduce_batch_size(loader, ratio=0.5):
535
+ new_bs = max(1, int(loader.batch_size * ratio))
536
+
537
+ shuffle = isinstance(loader.sampler, RandomSampler)
538
+
539
+ new_loader = DataLoader(
540
+ dataset=loader.dataset,
541
+ batch_size=new_bs,
542
+ shuffle=shuffle,
543
+ sampler=None if shuffle else loader.sampler,
544
+ num_workers=loader.num_workers,
545
+ collate_fn=loader.collate_fn,
546
+ pin_memory=loader.pin_memory,
547
+ drop_last=loader.drop_last,
548
+ timeout=loader.timeout,
549
+ worker_init_fn=loader.worker_init_fn,
550
+ multiprocessing_context=loader.multiprocessing_context,
551
+ generator=loader.generator,
552
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
553
+ persistent_workers=loader.persistent_workers,
554
+ pin_memory_device=loader.pin_memory_device
555
+ )
556
+
557
+ return new_loader
558
+
559
+ def list_to_tuple(x):
560
+ if isinstance(x, (list, tuple)):
561
+ return tuple(list_to_tuple(i) for i in x)
562
+ return x
563
+
564
+ def fmt(x):
565
+ if isinstance(x, float):
566
+ return round(x, 5)
567
+ if isinstance(x, dict):
568
+ return {k: fmt(v) for k, v in x.items()}
569
+ if isinstance(x, list):
570
+ return [fmt(v) for v in x]
571
+ return x
572
+
573
+ class ModelEmaV3Proxy(ModelEmaV3):
574
+ def __getattr__(self, name):
575
+ try:
576
+ return super().__getattr__(name)
577
+ except AttributeError:
578
+ return getattr(self.module, name)
579
+
580
+ class DataParallelProxy(nn.DataParallel):
581
+ def __getattr__(self, name):
582
+ try:
583
+ return super().__getattr__(name)
584
+ except AttributeError:
585
+ attr = getattr(self.module, name)
586
+
587
+ if callable(attr):
588
+ def wrapper(*args, **kwargs):
589
+ return self._parallel_apply_method(name, *args, **kwargs)
590
+ return wrapper
591
+
592
+ return attr
593
+
594
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
595
+ if not self.device_ids:
596
+ return getattr(self.module, method_name)(*inputs, **kwargs)
597
+
598
+ inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids)
599
+
600
+ replicas = self.replicate(self.module, self.device_ids)
601
+
602
+ outputs = self.parallel_apply(
603
+ [getattr(replica, method_name) for replica in replicas],
604
+ inputs_scattered,
605
+ kwargs_scattered
606
+ )
607
+
608
+ return self.gather(outputs, self.output_device)
609
+
610
+ class Trainer:
611
+ def __init__(
612
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
613
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
614
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
615
+ ):
616
+ self.ema_net = None
617
+
618
+ self.training_time = self._time_str_to_seconds(training_time)
619
+ self.mode = eval_mode
620
+ self.topk = topk
621
+ self.device = device
622
+ self.logging = logging if logging < epochs else 1
623
+ self.logging_file = logging_file
624
+ self.checkpoints_dir = checkpoints_dir
625
+ self.early_stopping = early_stopping
626
+ self.eval_from_ratio = eval_from_ratio
627
+ self.eval_every = eval_every
628
+ self.save_name = save_name
629
+ self.save_best = save_best
630
+ self.save_last = save_last
631
+ self.return_best = return_best
632
+ self.return_last = return_last
633
+ self.max_grad_norm = max_grad_norm
634
+ self.schedule_in_step = schedule_in_step
635
+ self.use_ema = use_ema
636
+ self.ema_from_ratio = ema_from_ratio
637
+ self.ema_decay = ema_decay
638
+
639
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
640
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
641
+
642
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, corpus_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, refresh_every=3):
643
+ if eval_fn is None:
644
+ if self.mode == "max":
645
+ eval_fn = lambda *x: -loss_fn(*x)
646
+ else:
647
+ eval_fn = lambda *x: loss_fn(*x)
648
+
649
+ if torch.cuda.device_count() > 1:
650
+ network = DataParallelProxy(network)
651
+ network = network.to(self.device)
652
+
653
+ if not start_training_time:
654
+ start_training_time = time.time()
655
+
656
+ start_ema = int(epochs * self.ema_from_ratio)
657
+ start_eval = int(epochs * self.eval_from_ratio)
658
+
659
+ if val_loader is None:
660
+ print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
661
+ else:
662
+ model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
663
+ start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
664
+ print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)
665
+
666
+ training_log = {}
667
+ for epoch in range(start_epoch, epochs+start_epoch):
668
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
669
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
670
+
671
+ try:
672
+ eval_net = self.ema_net if (self.use_ema and self.ema_net is not None) else network
673
+ if (epoch - start_epoch) % refresh_every == 0:
674
+ encoded_docs = self._get_encoded_docs(eval_net, corpus_loader)
675
+ print(f"[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Refresh Encoded Doc (refresh_every={refresh_every})!")
676
+ elif (epoch - start_epoch - start_eval) % self.eval_every == 0 and epoch - start_epoch >= start_eval:
677
+ encoded_docs = self._get_encoded_docs(eval_net, corpus_loader)
678
+ print(f"[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Refresh Encoded Doc (eval_every={self.eval_every})!")
679
+
680
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, encoded_docs)
681
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
682
+ logging_dict.update(train_loss_epoch_dict)
683
+
684
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
685
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, encoded_docs)
686
+ update = self._update_best_network(eval_net, val_score, epoch)
687
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
688
+ logging_dict.update(val_score_dict)
689
+ if not self.schedule_in_step and scheduler:
690
+ scheduler.step()
691
+
692
+ except RuntimeError as e:
693
+ if "out of memory" in str(e).lower():
694
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
695
+ torch.cuda.empty_cache()
696
+ gc.collect()
697
+ if torch.cuda.is_available():
698
+ torch.cuda.synchronize()
699
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
700
+
701
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
702
+ if val_loader is not None:
703
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
704
+
705
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
706
+ else:
707
+ raise
708
+
709
+ training_log[epoch] = logging_dict
710
+ if self.is_early_stopping(epoch):
711
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
712
+ break
713
+ if self.logging:
714
+ if epoch % self.logging == 0:
715
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
716
+ else:
717
+ print(f'{epoch}...', end=' ')
718
+
719
+ if self._at_time_limit(start_training_time):
720
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
721
+ break
722
+
723
+ if self.logging_file:
724
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
725
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
726
+ f.write(json.dumps(training_log))
727
+
728
+ if self.use_ema and self.ema_net is not None:
729
+ self._save_state_dict(self.ema_net.module)
730
+ else:
731
+ self._save_state_dict(network)
732
+ print(f'[Trainer CallBack] 📢 Kết thúc training.\n')
733
+
734
+ best_model, last_model = None, None
735
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
736
+ if self.return_best :
737
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
738
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
739
+ if self.return_last:
740
+ last_model = eval_net.state_dict()
741
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
742
+
743
+ del network
744
+ torch.cuda.empty_cache()
745
+ gc.collect()
746
+ return training_log, best_model, last_model
747
+
748
+ def _time_str_to_seconds(self, time_str):
749
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
750
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
751
+
752
+ def _update_best_network(self, network, val_score, epoch):
753
+ topk = max(1, self.topk)
754
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
755
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
756
+ if val_score in [x[0] for x in self.best_stage]:
757
+ return True
758
+ return False
759
+
760
+ def is_early_stopping(self, epoch):
761
+ if self.best_stage[0][1] is None:
762
+ return False
763
+ if not self.early_stopping:
764
+ return False
765
+ return epoch - self.best_stage[0][1] >= self.early_stopping
766
+
767
+ def _at_time_limit(self, start_training_time):
768
+ return time.time() - start_training_time >= self.training_time
769
+
770
+ def _save_state_dict(self, network):
771
+ if self.topk <= 0:
772
+ return
773
+
774
+ if self.save_best:
775
+ for r in range(self.topk):
776
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
777
+
778
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
779
+ if state_dict is None:
780
+ continue
781
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
782
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
783
+ if self.save_last:
784
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
785
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
786
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
787
+
788
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, encoded_docs):
789
+ network.train()
790
+ total_loss = 0
791
+ total_loss_dict = {}
792
+ for batch_idx, batch in enumerate(train_loader):
793
+ optimizer.zero_grad()
794
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
795
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, encoded_docs)
796
+
797
+ for k, v in loss_dict.items():
798
+ t = total_loss_dict.get(k, 0)
799
+ total_loss_dict[k] = t + v
800
+ self.grad_scaler.scale(loss).backward()
801
+ self.grad_scaler.unscale_(optimizer)
802
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
803
+ # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu...
804
+ self.grad_scaler.step(optimizer)
805
+ self.grad_scaler.update()
806
+ if self.schedule_in_step and scheduler:
807
+ scheduler.step()
808
+ if self.use_ema and self.ema_net is not None:
809
+ self.ema_net.update(network)
810
+ total_loss += loss
811
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
812
+
813
+ def _eval_epoch(self, network, val_loader, eval_fn, encoded_docs):
814
+ network.eval()
815
+ total_score = 0.0
816
+ total_score_dict = {}
817
+ object_lists = None # sẽ init sau
818
+
819
+ with torch.no_grad():
820
+ for batch_idx, batch in enumerate(val_loader):
821
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, encoded_docs)
822
+ total_score += score
823
+
824
+ for k, v in score_dict.items():
825
+ t = total_score_dict.get(k, 0)
826
+ total_score_dict[k] = t + v
827
+
828
+ if objects:
829
+ if object_lists is None:
830
+ object_lists = [[] for _ in range(len(objects))]
831
+
832
+ for i, obj in enumerate(objects):
833
+ object_lists[i].append(obj.detach())
834
+
835
+ if object_lists is not None:
836
+ object_arrays = [
837
+ torch.concat(obj_list, dim=0).cpu().numpy()
838
+ for obj_list in object_lists
839
+ ]
840
+ else:
841
+ object_arrays = []
842
+
843
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
844
+
845
+ def _get_encoded_docs(self, network, corpus_loader):
846
+ network.eval()
847
+ with torch.no_grad():
848
+ encoded_docs = []
849
+ for batch_idx, batch in enumerate(corpus_loader):
850
+ input_ids = batch['input_ids'].to(self.device)
851
+ attn_mask = batch['attn_mask'].to(self.device)
852
+ encoded_doc = network.embed_doc(input_ids, attn_mask)
853
+ if network.normalize:
854
+ encoded_doc = F.normalize(encoded_doc, dim=-1)
855
+ encoded_docs.append(encoded_doc)
856
+ encoded_docs = torch.concat(encoded_docs, dim=0).squeeze(1)
857
+ return encoded_docs
858
+
859
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, encoded_docs):
860
+ # Bạn cần override _cal_loss để tính loss
861
+ text_input_ids = batch['text_input_ids'].to(self.device)
862
+ text_attn_mask = batch['text_attn_mask'].to(self.device)
863
+ info_input_ids = batch['info_input_ids'].to(self.device)
864
+ info_attn_mask = batch['info_attn_mask'].to(self.device)
865
+ pos_idxes = batch['pos_idxes'].to(self.device)
866
+ pos_mask = batch['pos_mask'].to(self.device)
867
+ neg_idxes = batch['neg_idxes'].to(self.device)
868
+
869
+ encoded_text = network.embed_query(text_input_ids, text_attn_mask)
870
+ encoded_info = network.embed_info(info_input_ids, info_attn_mask)
871
+ encoded_text = encoded_text + encoded_info
872
+
873
+ if network.normalize:
874
+ encoded_text = F.normalize(encoded_text, dim=-1)
875
+
876
+ encoded_pos = encoded_docs[pos_idxes]
877
+ encoded_neg = encoded_docs[neg_idxes]
878
+
879
+ loss_dict = loss_fn(
880
+ encoded_text, encoded_pos, encoded_neg, pos_mask,
881
+ )
882
+ return loss_dict['total'], loss_dict
883
+
884
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, encoded_docs):
885
+ # Bạn cần override _cal_val_score để tính val score, list bên cạnh là để trả về y hay pred gì đó (nếu cần)
886
+ text_input_ids = batch['text_input_ids'].to(self.device)
887
+ text_attn_mask = batch['text_attn_mask'].to(self.device)
888
+ info_input_ids = batch['info_input_ids'].to(self.device)
889
+ info_attn_mask = batch['info_attn_mask'].to(self.device)
890
+ gt_pos_idxes = batch['gt_pos_idxes']
891
+
892
+ encoded_text = network.embed_query(text_input_ids, text_attn_mask)
893
+ encoded_info = network.embed_info(info_input_ids, info_attn_mask)
894
+ encoded_text = encoded_text + encoded_info
895
+
896
+ if network.normalize:
897
+ encoded_text = F.normalize(encoded_text, dim=-1)
898
+ scores = torch.matmul(encoded_text, encoded_docs.T)
899
+
900
+ topk_scores, topk_indices = torch.topk(scores, k=10, dim=-1) # [B, K]
901
+ pred_topk = [idx[logit > 0].tolist() for idx, logit in zip(topk_indices, topk_scores)]
902
+
903
+ pred_topk = list_to_tuple(pred_topk)
904
+ gt_pos_idxes = list_to_tuple(gt_pos_idxes)
905
+ score_dict = eval_fn(pred_topk, gt_pos_idxes)
906
+ return score_dict['f2'], score_dict, []
907
+
908
+ # %% [code]
909
+ def tokenize_to_parts(text, tokenizer, max_len, max_n_parts):
910
+ # Tokenize với overflow để chia thành nhiều đoạn
911
+ enc = tokenizer(
912
+ text,
913
+ max_length=max_len*max_n_parts,
914
+ truncation=True,
915
+ padding="max_length",
916
+ return_overflowing_tokens=True,
917
+ return_tensors="pt"
918
+ )
919
+
920
+ input_ids = enc["input_ids"].reshape(max_n_parts, max_len) # (n_parts, max_len)
921
+ attn_mask = enc["attention_mask"].reshape(max_n_parts, max_len) # (n_parts, max_len)
922
+
923
+ return input_ids, attn_mask
924
+
925
+ class LawRetrievalDataset(Dataset):
926
+ def __init__(self, all_data, using_idxes, corpus_dict, tokenizer, max_len, max_n_parts, n_negs):
927
+ super().__init__()
928
+
929
+ self.all_data = all_data
930
+ self.using_idxes = using_idxes
931
+ self.tokenizer = tokenizer
932
+ self.max_len = max_len
933
+ self.max_n_parts = max_n_parts
934
+ self.n_negs = n_negs
935
+
936
+ # ===== BUILD CORPUS =====
937
+ idx = 0
938
+ self.corpus_list = []
939
+ self.corpus_dict = {}
940
+
941
+ for doc_name, articles_dict in corpus_dict.items():
942
+ self.corpus_dict[doc_name] = {}
943
+ for article_idx, content in articles_dict.items():
944
+ self.corpus_list.append([doc_name, article_idx, content])
945
+ self.corpus_dict[doc_name][article_idx] = {
946
+ 'content': content,
947
+ 'idx': idx
948
+ }
949
+ idx += 1
950
+
951
+ def __len__(self):
952
+ return len(self.using_idxes)
953
+
954
+ def __getitem__(self, idx):
955
+ ridx = self.using_idxes[idx]
956
+ data = self.all_data[ridx]
957
+
958
+ query_text = data['text']
959
+
960
+ text_input_ids, text_attn_mask = tokenize_to_parts(
961
+ query_text, self.tokenizer, self.max_len, 1
962
+ )
963
+
964
+ info = ['[entities]']
965
+ for entity in data['entities']:
966
+ info.extend([f'[{entity["label"]}]', entity["name"]])
967
+ info += ['[actions]']
968
+ for action in data['actions']:
969
+ info.extend([f'[{action["label"]}]', action["trigger"]])
970
+ for arg in action['arguments']:
971
+ info.extend([f'[{arg["role"]}]', arg["argument"]])
972
+ info += ['[issues]']
973
+ for issue in data['issues']:
974
+ info.extend([f'[{issue["label"]}]', issue["trigger"]])
975
+ for arg in issue['arguments']:
976
+ info.extend([f'[{arg["role"]}]', arg["argument"]])
977
+ info_input_ids, info_attn_mask = tokenize_to_parts(
978
+ ' '.join(info), self.tokenizer, self.max_len, 1
979
+ )
980
+
981
+ # ===== POS =====
982
+ gt_pos_idxes = []
983
+ hard_names = []
984
+ for law in data['relevant_law']:
985
+ name = law['doc']
986
+ art = law['art']
987
+ gt_pos_idxes.append(self.corpus_dict[name][art]['idx'])
988
+ if name not in hard_names:
989
+ hard_names.append(name)
990
+
991
+ pos_idxes = torch.tensor(gt_pos_idxes, dtype=torch.long)
992
+ pos_mask = torch.ones(len(pos_idxes))
993
+
994
+ # ===== NEG =====
995
+ hard_neg_idxes = []
996
+ for name in hard_names:
997
+ for content in self.corpus_dict[name].values():
998
+ if content['idx'] in gt_pos_idxes:
999
+ continue
1000
+ hard_neg_idxes.append(content['idx'])
1001
+
1002
+ easy_neg_idxes = list(range(len(self.corpus_list)))
1003
+ for i in gt_pos_idxes + hard_neg_idxes:
1004
+ if i in easy_neg_idxes:
1005
+ easy_neg_idxes.remove(i)
1006
+
1007
+ n_hards = min(len(hard_neg_idxes), self.n_negs // 2)
1008
+ neg_idxes = random.sample(hard_neg_idxes, n_hards) + random.sample(easy_neg_idxes, self.n_negs - n_hards)
1009
+ neg_idxes = torch.tensor(neg_idxes, dtype=torch.long)
1010
+
1011
+ return {
1012
+ 'text_input_ids': text_input_ids,
1013
+ 'text_attn_mask': text_attn_mask,
1014
+ 'info_input_ids': text_input_ids,
1015
+ 'info_attn_mask': text_attn_mask,
1016
+ 'gt_pos_idxes': gt_pos_idxes,
1017
+ 'pos_idxes': pos_idxes,
1018
+ 'pos_mask': pos_mask,
1019
+ 'neg_idxes': neg_idxes,
1020
+ }
1021
+
1022
+ class CorpusDataset(Dataset):
1023
+ def __init__(self, corpus_dict, tokenizer, max_len, max_n_parts):
1024
+ super().__init__()
1025
+ self.tokenizer = tokenizer
1026
+ self.max_len = max_len
1027
+ self.max_n_parts = max_n_parts
1028
+
1029
+ idx = 0
1030
+ self.corpus_list = []
1031
+ self.corpus_dict = {}
1032
+ for doc_name, articles_dict in corpus_dict.items():
1033
+ self.corpus_dict[doc_name] = {}
1034
+ for article_idx, content in articles_dict.items():
1035
+ self.corpus_list.append([doc_name, article_idx, content])
1036
+ self.corpus_dict[doc_name][article_idx] = {'content': content, 'idx': idx}
1037
+ idx += 1
1038
+
1039
+ def __len__(self):
1040
+ return len(self.corpus_list)
1041
+
1042
+ def _encode_contexts(self, idxes):
1043
+ all_input_ids, all_attn_mask = [], []
1044
+ for idx in idxes:
1045
+ name = self.corpus_list[idx][0]
1046
+ art = self.corpus_list[idx][1]
1047
+ corpus = self.corpus_dict[name][art]
1048
+ if 'content_input_ids' in corpus and 'content_attn_mask' in corpus:
1049
+ content_input_ids = corpus['content_input_ids']
1050
+ content_attn_mask = corpus['content_attn_mask']
1051
+ else:
1052
+ content = corpus['content']
1053
+ content_input_ids, content_attn_mask = tokenize_to_parts(content, self.tokenizer, self.max_len, self.max_n_parts)
1054
+ corpus['content_input_ids'] = content_input_ids
1055
+ corpus['content_attn_mask'] = content_attn_mask
1056
+
1057
+ all_input_ids.append(content_input_ids)
1058
+ all_attn_mask.append(content_attn_mask)
1059
+
1060
+ all_input_ids = torch.stack(all_input_ids)
1061
+ all_attn_mask = torch.stack(all_attn_mask)
1062
+ return all_input_ids, all_attn_mask
1063
+
1064
+ def __getitem__(self, idx):
1065
+ input_ids, attn_mask = self._encode_contexts([idx])
1066
+
1067
+ return {
1068
+ 'input_ids': input_ids,
1069
+ 'attn_mask': attn_mask,
1070
+ }
1071
+
1072
+ def _pad_batch(tensor_list, pad_value=0):
1073
+ """
1074
+ tensor_list: list of tensors, mỗi tensor shape (Nk, max_n_parts, max_len)
1075
+ return: tensor shape (B, max_Nk, max_n_parts, max_len)
1076
+ """
1077
+ max_Nk = max(t.size(0) for t in tensor_list)
1078
+
1079
+ padded = []
1080
+ for t in tensor_list:
1081
+ Nk = t.size(0)
1082
+
1083
+ if Nk < max_Nk:
1084
+ pad_shape = (max_Nk - Nk, *t.shape[1:])
1085
+ pad_tensor = t.new_full(pad_shape, pad_value)
1086
+ t = torch.cat([t, pad_tensor], dim=0)
1087
+
1088
+ padded.append(t)
1089
+
1090
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
1091
+
1092
+ def collate_fn(batch):
1093
+ text_input_ids = torch.stack([b["text_input_ids"] for b in batch])
1094
+ text_attn_mask = torch.stack([b["text_attn_mask"] for b in batch])
1095
+ info_input_ids = torch.stack([b["info_input_ids"] for b in batch])
1096
+ info_attn_mask = torch.stack([b["info_attn_mask"] for b in batch])
1097
+ gt_pos_idxes = [b["gt_pos_idxes"] for b in batch]
1098
+ neg_idxes = torch.stack([b["neg_idxes"] for b in batch])
1099
+
1100
+ pos_idxes = [b["pos_idxes"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1101
+ pos_mask = [b["pos_mask"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1102
+
1103
+ # pad theo Nk
1104
+ pos_idxes = _pad_batch(pos_idxes, pad_value=0).squeeze(-1).squeeze(-1)
1105
+ pos_mask = _pad_batch(pos_mask, pad_value=0).squeeze(-1).squeeze(-1)
1106
+
1107
+ return {
1108
+ 'text_input_ids': text_input_ids,
1109
+ 'text_attn_mask': text_attn_mask,
1110
+ 'info_input_ids': info_input_ids,
1111
+ 'info_attn_mask': info_attn_mask,
1112
+ 'gt_pos_idxes': gt_pos_idxes,
1113
+ 'pos_idxes': pos_idxes,
1114
+ 'pos_mask': pos_mask,
1115
+ 'neg_idxes': neg_idxes,
1116
+ }
1117
+
1118
+ # %% [code]
1119
+ def encode_corpus(state_dicts, network, corpus_loader, device):
1120
+ if torch.cuda.device_count() > 1:
1121
+ network = DataParallelProxy(network)
1122
+ network.to(device)
1123
+ network.eval()
1124
+
1125
+ all_model_embs = []
1126
+ for i, state_dict in enumerate(state_dicts):
1127
+ # ===== load model =====
1128
+ if torch.cuda.device_count() > 1:
1129
+ network.module.load_state_dict(state_dict)
1130
+ else:
1131
+ network.load_state_dict(state_dict)
1132
+
1133
+ encoded_docs = []
1134
+
1135
+ with torch.no_grad():
1136
+ for batch in corpus_loader:
1137
+ input_ids = batch['input_ids'].to(device)
1138
+ attn_mask = batch['attn_mask'].to(device)
1139
+
1140
+ encoded_doc = network.embed_doc(input_ids, attn_mask)
1141
+ if network.normalize:
1142
+ encoded_doc = F.normalize(encoded_doc, dim=-1)
1143
+
1144
+ encoded_docs.append(encoded_doc)
1145
+
1146
+ encoded_docs = torch.concat(encoded_docs, dim=0).squeeze(1) # [N, D]
1147
+ all_model_embs.append(encoded_docs)
1148
+
1149
+ # ===== ensemble =====
1150
+ # stack → [M, N, D]
1151
+ all_model_embs = torch.stack(all_model_embs, dim=0)
1152
+ final_embs = all_model_embs.mean(dim=0) # [N, D]
1153
+
1154
+ return final_embs
1155
+
1156
+ def test(state_dicts, network, test_loader, device, eval_fn, encoded_docs, topks=[5, 10, 15]):
1157
+ if torch.cuda.device_count() > 1:
1158
+ network = DataParallelProxy(network)
1159
+ network.to(device)
1160
+ network.eval()
1161
+
1162
+ per_model_scores = []
1163
+ max_k = max(topks)
1164
+
1165
+ all_scores = []
1166
+ all_logits = []
1167
+ all_gt_pos_idxes = []
1168
+ with torch.no_grad():
1169
+ for batch in test_loader:
1170
+ text_input_ids = batch['text_input_ids'].to(device)
1171
+ text_attn_mask = batch['text_attn_mask'].to(device)
1172
+ info_input_ids = batch['info_input_ids'].to(device)
1173
+ info_attn_mask = batch['info_attn_mask'].to(device)
1174
+ gt_pos_idxes = batch['gt_pos_idxes']
1175
+ all_gt_pos_idxes.extend(gt_pos_idxes)
1176
+
1177
+ list_encoded_texts = []
1178
+ list_logits = []
1179
+
1180
+ for state_dict in state_dicts:
1181
+ # ===== load model =====
1182
+ if torch.cuda.device_count() > 1:
1183
+ network.module.load_state_dict(state_dict)
1184
+ else:
1185
+ network.load_state_dict(state_dict)
1186
+
1187
+ encoded_text = network.embed_query(text_input_ids, text_attn_mask)
1188
+ encoded_info = network.embed_info(info_input_ids, info_attn_mask)
1189
+
1190
+ if network.normalize:
1191
+ encoded_text = F.normalize(encoded_text, dim=-1)
1192
+ encoded_info = F.normalize(encoded_info, dim=-1)
1193
+ encoded_text = encoded_text + 0.1 * encoded_info
1194
+
1195
+ if network.normalize:
1196
+ encoded_text = F.normalize(encoded_text, dim=-1)
1197
+
1198
+ list_encoded_texts.append(encoded_text)
1199
+
1200
+ ensemble_encoded_text = torch.stack(list_encoded_texts, dim=0).mean(dim=0)
1201
+ scores = torch.matmul(ensemble_encoded_text, encoded_docs.T) # B, M
1202
+ all_scores.append(scores)
1203
+
1204
+ all_scores = torch.concat(all_scores, dim=0) # N, M
1205
+
1206
+ topk_scores, topk_indices = torch.topk(all_scores, k=10, dim=-1) # [B, K]
1207
+ pred_topk_full = [idx[logit > 0].tolist() for idx, logit in zip(topk_indices, topk_scores)]
1208
+
1209
+ pred_topk_full = list_to_tuple(pred_topk_full)
1210
+ all_gt_pos_idxes = list_to_tuple(all_gt_pos_idxes)
1211
+
1212
+ final_score = {}
1213
+ for k in topks:
1214
+ pred_topk_k = [p[:k] for p in pred_topk_full]
1215
+ final_score[k] = eval_fn(pred_topk_k, all_gt_pos_idxes)
1216
+
1217
+ return final_score
1218
+
1219
+ # %% [code]
1220
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
1221
+ data_train = json.load(f)
1222
+
1223
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
1224
+ data_test = json.load(f)
1225
+
1226
+ with open(f'{test_dir}/corpus.json', "r", encoding="utf-8") as f:
1227
+ data_corpus = json.load(f)
1228
+
1229
+ print('Train:', len(data_train))
1230
+ print('Test:', len(data_test))
1231
+ print('Corpus:', len(data_corpus))
1232
+
1233
+ # %% [code]
1234
+ entity_types = sorted(list(set([e['label'] for d in data_train + data_test for e in d['entities']])))
1235
+ action_trigger_types = sorted(list(set([e['label'] for d in data_train + data_test for e in d['actions']])))
1236
+ issue_trigger_types = sorted(list(set([e['label'] for d in data_train + data_test for e in d['issues']])))
1237
+ argument_types = sorted(list(set([a['role'] for d in data_train + data_test for e in d['issues'] for a in e['arguments']])))
1238
+
1239
+ # %% [code]
1240
+ if debug_only:
1241
+ data_train = data_train[:20]
1242
+ data_test = data_test[:20]
1243
+
1244
+ print('Train:', len(data_train))
1245
+ print('Test:', len(data_test))
1246
+
1247
+ # %% [code]
1248
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
1249
+ special_tokens = ['[entities]', '[actions]', '[issues]'] + [f'[{t}]' for t in entity_types + action_trigger_types + issue_trigger_types + argument_types]
1250
+ tokenizer.add_special_tokens({'additional_special_tokens': sorted(list(set(special_tokens)))})
1251
+
1252
+ # %% [code]
1253
+ print('Experiment name:', state_dict_save_name)
1254
+
1255
+ # %% [code]
1256
+ if not test_only:
1257
+ full_idxes = np.array(range(len(data_train)))
1258
+ training_logs, best_models, last_models = [], [], []
1259
+ start_training_time = time.time()
1260
+ for seed in SEEDS:
1261
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
1262
+ generator = torch.Generator()
1263
+ generator.manual_seed(seed)
1264
+
1265
+ corpusset = CorpusDataset(data_corpus, tokenizer, **corpus_memory_params)
1266
+ corpus_loader = DataLoader(corpusset, generator=generator, **val_loader_params)
1267
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
1268
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
1269
+ continue
1270
+ set_seed(seed)
1271
+
1272
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
1273
+
1274
+ trainset = LawRetrievalDataset(data_train, train_idxes, data_corpus, tokenizer, **train_memory_params)
1275
+ valset = LawRetrievalDataset(data_train, val_idxes, data_corpus, tokenizer, **val_memory_params)
1276
+
1277
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
1278
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1279
+
1280
+ my_model = EncodeModel(
1281
+ **model_params
1282
+ )
1283
+ my_model.encoder.resize_token_embeddings(len(tokenizer))
1284
+ total_params = sum(p.numel() for p in my_model.parameters())
1285
+ print(f"Total params: {total_params:,}")
1286
+
1287
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
1288
+ encoder_params = set(map(id, my_model.encoder.parameters()))
1289
+ other_params = [
1290
+ p for p in my_model.parameters()
1291
+ if id(p) not in encoder_params
1292
+ ]
1293
+ optimizer = optim.AdamW([
1294
+ {"params": my_model.encoder.parameters(), "lr": 5e-4},
1295
+ {"params": other_params}
1296
+ ], lr=5e-4)
1297
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
1298
+
1299
+ loss_fn = CustomLoss(
1300
+ **loss_func_params
1301
+ )
1302
+ eval_fn = CustomEvalFn(**eval_func_params)
1303
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
1304
+ trainer = Trainer(**trainer_params)
1305
+
1306
+ print(f'Start Training Fold {fold_idx}...')
1307
+ training_log, best_model, last_model = trainer.fit(
1308
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, corpus_loader, eval_fn,
1309
+ start_epoch=1, start_training_time=start_training_time, refresh_every=2,
1310
+ )
1311
+
1312
+ training_logs.append(training_log)
1313
+ best_models.append(best_model)
1314
+ last_models.append(last_model)
1315
+
1316
+ # %% [code]
1317
+ def load_all_state_dicts(folder):
1318
+ files = []
1319
+
1320
+ for file in os.listdir(folder):
1321
+ if file.endswith(".pt") or file.endswith(".pth"):
1322
+ m = re.search(r"f(\d+)", file) # tìm f<số>
1323
+ if m:
1324
+ fold = int(m.group(1))
1325
+ files.append((fold, file))
1326
+
1327
+ # sort theo fold
1328
+ files.sort(key=lambda x: x[0])
1329
+
1330
+ state_dicts = []
1331
+ for fold, file in files:
1332
+ path = os.path.join(folder, file)
1333
+ print(f"Loading fold {fold}: {file}")
1334
+ state_dict = torch.load(path, map_location="cpu")
1335
+ state_dicts.append(state_dict)
1336
+
1337
+ return state_dicts
1338
+
1339
+ if test_only:
1340
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
1341
+ get_ipython().system('rm -rf .cache .gitattributes')
1342
+
1343
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
1344
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
1345
+
1346
+ # %% [code]
1347
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
1348
+ testset = LawRetrievalDataset(data_test, range(len(data_test)), data_corpus, tokenizer, **val_memory_params)
1349
+ generator = torch.Generator()
1350
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1351
+ eval_fn = CustomEvalFn(**eval_func_params)
1352
+ my_model = EncodeModel(
1353
+ **model_params
1354
+ )
1355
+ my_model.encoder.resize_token_embeddings(len(tokenizer))
1356
+ total_params = sum(p.numel() for p in my_model.parameters())
1357
+ print(f"Total params: {total_params:,}")
1358
+
1359
+ # %% [code]
1360
+ start_time = time.time()
1361
+ encoded_docs = encode_corpus(best_models, my_model, corpus_loader, device)
1362
+ best_score = test(best_models, my_model, test_loader, device, eval_fn, encoded_docs)
1363
+
1364
+ encoded_docs = encode_corpus(last_models, my_model, corpus_loader, device)
1365
+ last_score = test(last_models, my_model, test_loader, device, eval_fn, encoded_docs)
1366
+
1367
+ result_test = {"Best model": best_score, "Last model": last_score}
1368
+
1369
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
1370
+ json.dump(result_test, f, ensure_ascii=False, indent=2)
1371
+
1372
+ print('Test:', time.time() - start_time, 's --> Done!')
1373
+
1374
+ # %% [code]
1375
+ def dict_to_df(data):
1376
+ row_tuples = []
1377
+ row_values = []
1378
+
1379
+ # ===== lấy model đầu tiên =====
1380
+ first_model = next(iter(data.values()))
1381
+
1382
+ # ===== eval keys =====
1383
+ eval_keys = list(first_model.keys())
1384
+
1385
+ # ===== tự lấy metrics =====
1386
+ first_eval = next(iter(first_model.values()))
1387
+ metrics = list(first_eval.keys())
1388
+
1389
+ for eval_key in eval_keys:
1390
+ row_tuples.append(eval_key)
1391
+
1392
+ row = {}
1393
+
1394
+ for model_name, model_data in data.items():
1395
+ for metric in metrics:
1396
+ row[(model_name, metric)] = model_data[eval_key][metric]
1397
+
1398
+ row_values.append(row)
1399
+
1400
+ # ===== DataFrame =====
1401
+ df = pd.DataFrame(row_values)
1402
+
1403
+ # ===== MultiIndex columns =====
1404
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
1405
+
1406
+ # ===== Index =====
1407
+ df.index = pd.Index(
1408
+ row_tuples,
1409
+ name="evaluation"
1410
+ )
1411
+
1412
+ # ===== Sort =====
1413
+ sort_keys = []
1414
+
1415
+ for model_name in data.keys():
1416
+ for metric in ["f1", "f2", "mrr", "recall", "precision"]:
1417
+ key = (model_name, metric)
1418
+
1419
+ if key in df.columns:
1420
+ sort_keys.append(key)
1421
+
1422
+ if sort_keys:
1423
+ df = df.sort_values(
1424
+ by=sort_keys,
1425
+ ascending=False
1426
+ )
1427
+
1428
+ return df
1429
+
1430
+ result_test_df = dict_to_df(result_test)
1431
+ result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
1432
+ result_test_df
1433
+
1434
+ # %% [code]
1435
+ key = ("Best model", "f2")
1436
+ result_test_df_best = result_test_df.sort_values(by=key, ascending=False).groupby(level="evaluation").head(1)
1437
+ result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
1438
+ result_test_df_best
1439
+
1440
+ # %% [code]
1441
+ def get_avg_best_score(logs):
1442
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
1443
+
1444
+ def get_avg_log(logs, epochs):
1445
+ avg_log = {}
1446
+
1447
+ for epoch in range(1, epochs + 1):
1448
+ val_score = 0.0
1449
+ train_loss = 0.0
1450
+ n_eval = 0
1451
+
1452
+ for idx in range(len(logs)):
1453
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
1454
+ if log is None:
1455
+ continue
1456
+
1457
+ val_score += log.get('val_score', 0.0)
1458
+ train_loss += log.get('train_loss', 0.0)
1459
+ n_eval += 1
1460
+
1461
+ if n_eval == 0:
1462
+ continue
1463
+
1464
+ avg_log[epoch] = {
1465
+ 'train_loss': train_loss / n_eval,
1466
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
1467
+ }
1468
+
1469
+ return avg_log
1470
+
1471
+ def parse_label_key(label: str):
1472
+ try:
1473
+ first = float(label.split('_', 1)[0]) # số đầu: trước dấu _
1474
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
1475
+ return first, last
1476
+ except:
1477
+ return (0, 0)
1478
+
1479
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
1480
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
1481
+
1482
+ # ===== Plot Train Loss =====
1483
+ for name, log in logs_dict.items():
1484
+ epochs = sorted(log.keys())
1485
+ train_loss = [log[e]['train_loss'] for e in epochs]
1486
+ axes[0].plot(epochs, train_loss, label=name)
1487
+
1488
+ axes[0].set_xlabel('Epoch')
1489
+ axes[0].set_ylabel('Train Loss')
1490
+ axes[0].set_title('Training Loss')
1491
+ axes[0].grid(True)
1492
+
1493
+ # ===== Plot Validation Score =====
1494
+ for name, log in logs_dict.items():
1495
+ epochs = sorted(log.keys())
1496
+ val_score = [log[e]['val_score'] for e in epochs]
1497
+ axes[1].plot(epochs, val_score, label=name)
1498
+
1499
+ axes[1].set_xlabel('Epoch')
1500
+ axes[1].set_ylabel('Validation Score')
1501
+ axes[1].set_title('Validation Score')
1502
+ axes[1].grid(True)
1503
+
1504
+ # ===== Shared Legend =====
1505
+ handles, labels = axes[0].get_legend_handles_labels()
1506
+ pairs = list(zip(handles, labels))
1507
+ pairs_sorted = sorted(
1508
+ pairs,
1509
+ key=lambda x: parse_label_key(x[1])
1510
+ )
1511
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
1512
+
1513
+ axes[0].legend(
1514
+ handles_sorted,
1515
+ labels_sorted,
1516
+ loc='center left',
1517
+ bbox_to_anchor=(1.01, 0.5),
1518
+ borderaxespad=0.
1519
+ )
1520
+
1521
+ plt.tight_layout(rect=[0, 0, 1, 1])
1522
+
1523
+ if save_path is not None:
1524
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
1525
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
1526
+
1527
+ plt.show()
1528
+
1529
+ # %% [code]
1530
+ if not test_only:
1531
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*lr*.json"], ignore_patterns=[])
1532
+ get_ipython().system('rm -rf .cache .gitattributes')
1533
+
1534
+ # %% [code]
1535
+ if not test_only:
1536
+ experiments = {}
1537
+ for experiment in os.listdir(pretrained_dir):
1538
+ experiment_logs = []
1539
+ try:
1540
+ for seed in SEEDS:
1541
+ for fold_idx in range(nfolds):
1542
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
1543
+ experiment_log = json.load(f)
1544
+ experiment_logs.append(experiment_log)
1545
+ except:
1546
+ pass
1547
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
1548
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
1549
+
1550
+ # %% [code]
1551
+ if not test_only:
1552
+ score = get_avg_best_score(training_logs)
1553
+ state_dict_save_name, score
1554
+
1555
+ # %% [code]
1556
+ if not test_only:
1557
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
1558
+
5_lr_up_lr_6/lasts/5_lr_up_lr_6_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba8d56e59d51e1ad70eef1aeaf490feab6a1e3444be712ad5d101a64051c922e
3
+ size 546476760
5_lr_up_lr_6/logs/5_lr_up_lr_6_log_plot.jpg ADDED

Git LFS Details

  • SHA256: c5f0912f9c5058cf4df94b8008717ddb470e7c0ab8aacad8e55fcb17e5dab356
  • Pointer size: 131 Bytes
  • Size of remote file: 505 kB
5_lr_up_lr_6/logs/5_lr_up_lr_6_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [0.0005, 0.0005], "train_loss": 5.792922019958496, "total": 5.792922208141722, "contrastive_loss": 4.634641513377926, "triplet_loss": 0.24289297658862877}, "2": {"lr": [0.0004969282409784868, 0.0004969282409784868], "train_loss": 5.726807594299316, "total": 5.726807457148829, "contrastive_loss": 4.611471424932065, "triplet_loss": 0.2334866220735786}, "3": {"lr": [0.0004877886008156408, 0.0004877886008156408], "train_loss": 5.654202938079834, "total": 5.654202974759615, "contrastive_loss": 4.65382329117893, "triplet_loss": 0.19147157190635453}, "4": {"lr": [0.00047280612778499774, 0.00047280612778499774], "train_loss": 5.654203414916992, "total": 5.65420338302153, "contrastive_loss": 4.6538686082514635, "triplet_loss": 0.19147157190635453}, "5": {"lr": [0.00045234974009654937, 0.00045234974009654937], "train_loss": 5.654292106628418, "total": 5.654291975857023, "contrastive_loss": 4.65389514527592, "triplet_loss": 0.19147157190635453}, "6": {"lr": [0.00042692314190604356, 0.00042692314190604356], "train_loss": 5.654357433319092, "total": 5.6543577060252925, "contrastive_loss": 4.653764093201296, "triplet_loss": 0.19147157190635453}, "7": {"lr": [0.00039715242044697206, 0.00039715242044697206], "train_loss": 5.654600620269775, "total": 5.654600621864549, "contrastive_loss": 4.653791855011496, "triplet_loss": 0.19147157190635453}, "8": {"lr": [0.00036377062968501693, 0.00036377062968501693], "train_loss": 5.654788970947266, "total": 5.654789238869147, "contrastive_loss": 4.653832272941053, "triplet_loss": 0.19147157190635453, "val_score": 0.0016072584822584824, "best_score": 0.0016072584822584824, "new_best_model": true, "precision": 0.0006250000000000002, "recall": 0.00625, "f2": 0.0016072584822584824, "mrr": 0.0025850694444444437}, "9": {"lr": [0.0003275997400965494, 0.0003275997400965494], "train_loss": 5.654257774353027, "total": 5.654257681856187, "contrastive_loss": 4.653806552440426, "triplet_loss": 0.19147157190635453, "val_score": 0.0003422040922040922, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.00016666666666666666, "recall": 0.0016666666666666668, "f2": 0.0003422040922040922, "mrr": 0.00033556547619047623}, "10": {"lr": [0.00028953039902753766, 0.00028953039902753766], "train_loss": 5.653966426849365, "total": 5.653966591110995, "contrastive_loss": 4.653811451583403, "triplet_loss": 0.1913670568561873, "val_score": 0.0006956006956006955, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.0003333333333333334, "recall": 0.003125, "f2": 0.0006956006956006955, "mrr": 0.0005031415343915344}, "11": {"lr": [0.0002505, 0.0002505], "train_loss": 5.653940200805664, "total": 5.653940054086538, "contrastive_loss": 4.653813492892977, "triplet_loss": 0.19147157190635453, "val_score": 0.0009017487142487143, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.00039583333333333343, "recall": 0.00375, "f2": 0.0009017487142487143, "mrr": 0.001205109126984127}, "12": {"lr": [0.00021146960097246246, 0.00021146960097246246], "train_loss": 5.654030799865723, "total": 5.654030688231606, "contrastive_loss": 4.653771441915761, "triplet_loss": 0.19147157190635453, "val_score": 0.0013901492026492028, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.0006458333333333337, "recall": 0.00625, "f2": 0.0013901492026492028, "mrr": 0.0015205026455026457}, "13": {"lr": [0.00017340025990345064, 0.00017340025990345064], "train_loss": 5.654173374176025, "total": 5.654173579901756, "contrastive_loss": 4.653814717678721, "triplet_loss": 0.1913670568561873, "val_score": 0.0008313330188330187, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.0003958333333333335, "recall": 0.003958333333333334, "f2": 0.0008313330188330187, "mrr": 0.000707010582010582}, "14": {"lr": [0.00013722937031498307, 0.00013722937031498307], "train_loss": 5.654489517211914, "total": 5.654489574623746, "contrastive_loss": 4.653799203725962, "triplet_loss": 0.1913670568561873, "val_score": 0.0015971875346875346, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.0007291666666666671, "recall": 0.006875, "f2": 0.0015971875346875346, "mrr": 0.002682787698412698}, "15": {"lr": [0.00010384757955302797, 0.00010384757955302797], "train_loss": 5.654581546783447, "total": 5.654581433554557, "contrastive_loss": 4.653864117370401, "triplet_loss": 0.19147157190635453, "val_score": 0.000866355866355866, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.00041666666666666675, "recall": 0.004166666666666667, "f2": 0.000866355866355866, "mrr": 0.0021487268518518518}, "16": {"lr": [7.40768580939564e-05, 7.40768580939564e-05], "train_loss": 5.654132843017578, "total": 5.654132753710284, "contrastive_loss": 4.653907801395276, "triplet_loss": 0.1913670568561873, "val_score": 0.000542119917119917, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.00025000000000000006, "recall": 0.0025, "f2": 0.000542119917119917, "mrr": 0.0012953869047619049}, "17": {"lr": [4.865025990345063e-05, 4.865025990345063e-05], "train_loss": 5.654181480407715, "total": 5.6541813368781355, "contrastive_loss": 4.653891062656773, "triplet_loss": 0.19147157190635453, "val_score": 0.0009113802863802865, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.0004166666666666667, "recall": 0.004166666666666667, "f2": 0.0009113802863802865, "mrr": 0.0019849537037037036}, "18": {"lr": [2.8193872215002235e-05, 2.8193872215002235e-05], "train_loss": 5.65456485748291, "total": 5.654565103077968, "contrastive_loss": 4.653933930157818, "triplet_loss": 0.19147157190635453, "val_score": 0.0007158813408813409, "best_score": 0.0016072584822584824, "new_best_model": false, "precision": 0.00029166666666666675, "recall": 0.002916666666666667, "f2": 0.0007158813408813409, "mrr": 0.0012632275132275135}}
5_lr_up_lr_6/r1s/5_lr_up_lr_6_s26092004_f0_r1_vs0.00161_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b75a09ed812dd5994be37e3ece925da21c66a9d2aa3f3579ec254bdf8b1ce5e4
3
+ size 546478464
5_lr_up_lr_6/results/5_lr_up_lr_6_test.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Best model": {
3
+ "5": {
4
+ "precision": 0.00030120481927710846,
5
+ "recall": 0.0015060240963855422,
6
+ "f2": 0.00046890860143872194,
7
+ "mrr": 0.00048318273092369475
8
+ },
9
+ "10": {
10
+ "precision": 0.0002635542168674699,
11
+ "recall": 0.0026355421686746986,
12
+ "f2": 0.0006595855616939954,
13
+ "mrr": 0.0006557479919678715
14
+ },
15
+ "15": {
16
+ "precision": 0.0002635542168674699,
17
+ "recall": 0.0026355421686746986,
18
+ "f2": 0.0006595855616939954,
19
+ "mrr": 0.0006557479919678715
20
+ }
21
+ },
22
+ "Last model": {
23
+ "5": {
24
+ "precision": 0.0004518072289156626,
25
+ "recall": 0.002259036144578313,
26
+ "f2": 0.0008085455374612,
27
+ "mrr": 0.0006337851405622489
28
+ },
29
+ "10": {
30
+ "precision": 0.000602409638554217,
31
+ "recall": 0.006024096385542169,
32
+ "f2": 0.0014563331581403868,
33
+ "mrr": 0.0011450564161407533
34
+ },
35
+ "15": {
36
+ "precision": 0.000602409638554217,
37
+ "recall": 0.006024096385542169,
38
+ "f2": 0.0014563331581403868,
39
+ "mrr": 0.0011450564161407533
40
+ }
41
+ }
42
+ }
5_lr_up_lr_6/results/5_lr_up_lr_6_test_df.xlsx ADDED
Binary file (5.38 kB). View file
 
5_lr_up_lr_6/results/5_lr_up_lr_6_test_df_best.xlsx ADDED
Binary file (5.38 kB). View file