SS3M commited on
Commit
b130d8a
·
verified ·
1 Parent(s): 7659e50

Upload 15_40_negs_19's state dict

Browse files
.gitattributes CHANGED
@@ -40,3 +40,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
40
  10_lr_attn_chunk_pool_instead_mean_12/logs/10_lr_attn_chunk_pool_instead_mean_12_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
41
  11_12_lr_14/logs/11_12_lr_14_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
42
  11_12_13_lr_17/logs/11_12_13_lr_17_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
 
 
40
  10_lr_attn_chunk_pool_instead_mean_12/logs/10_lr_attn_chunk_pool_instead_mean_12_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
41
  11_12_lr_14/logs/11_12_lr_14_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
42
  11_12_13_lr_17/logs/11_12_13_lr_17_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
43
+ 15_40_negs_19/logs/15_40_negs_19_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
15_40_negs_19/15_40_negs_19.py ADDED
@@ -0,0 +1,1583 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch]')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+
19
+ from sklearn.metrics import f1_score
20
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
21
+ from scipy.spatial.transform import Rotation as R
22
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
23
+ from sklearn.metrics import precision_recall_fscore_support
24
+ from timm.utils import ModelEmaV3
25
+ import timm
26
+
27
+ import os
28
+ import gc
29
+ import json
30
+ from pathlib import Path
31
+ import pickle
32
+ from tqdm.auto import tqdm
33
+ import copy
34
+ import numpy as np
35
+ import pandas as pd
36
+ import polars as pl
37
+ from PIL import Image
38
+ import time
39
+ from tqdm import tqdm
40
+ from matplotlib import pyplot as plt
41
+ import seaborn as sns
42
+ from multiprocessing import Manager as MemoryManager
43
+ from functools import lru_cache
44
+ import shutil
45
+ import glob
46
+ import cv2
47
+ import random
48
+ import re
49
+ import joblib
50
+ import math
51
+ from huggingface_hub import HfApi, snapshot_download
52
+ import evaluate
53
+ from underthesea import word_tokenize as vi_tokenize_tool
54
+ import spacy
55
+ en_tokenize_tool = spacy.load("en_core_web_sm")
56
+ from collections import defaultdict, Counter
57
+
58
+ # %% [code]
59
+ file_path = "/kaggle/input/notebooks/sambui22022517/kltn-lr-bm25/bm25_scores.npy"
60
+
61
+ bm25_scores = np.load(file_path, allow_pickle=True)
62
+ print(bm25_scores.shape)
63
+
64
+ # %% [code]
65
+ # Global config
66
+ SEEDS = [26092004]
67
+ topk = 1
68
+ nfolds = 5
69
+ only_fold_idx = None
70
+ test_only = 0
71
+ debug_only = 0
72
+
73
+ # Config thư mục
74
+ dataset = 'kltn/raw' # vhe, bkee, casie, kltn/only_issues, kltn/only_actions, kltn/raw
75
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Thư mục chứa file train, val, test
76
+ train_dir = f'{root_dir}'
77
+ # val_dir = f'{root_dir}/val'
78
+ test_dir = f'{root_dir}'
79
+
80
+ # Config checkpoints
81
+
82
+ # Config training
83
+ epochs = 18 if not debug_only else 2
84
+ batch_size = 32
85
+ device = "cuda" if torch.cuda.is_available() else "cpu"
86
+ # # Thêm biến toàn cục nào đó vào đây
87
+ repo_name = 'SS3M/kltn-lr-experiments'
88
+ state_dict_save_name = "15_40_negs_19"
89
+ checkpoints_dir = state_dict_save_name
90
+ pretrained_dir = "/kaggle/working"
91
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
92
+
93
+ backbone_model_name = "bert-base-uncased" if dataset == "casie" else "vinai/phobert-base"
94
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset == "casie" else vi_tokenize_tool(text)
95
+ max_len_dict = {
96
+ 'kltn/raw': 256,
97
+ 'kltn/only_issues': 52,
98
+ 'kltn/only_actions': 69,
99
+ 'vhe': 51,
100
+ 'bkee': 62,
101
+ 'casie': 40,
102
+ }
103
+ zero_events_rate_dict = {
104
+ 'kltn/raw': 1000,
105
+ 'kltn/only_issues': 0,
106
+ 'kltn/only_actions': 0.2,
107
+ 'vhe': 1000, # mean keep all zero-events samples
108
+ 'bkee': 1000,
109
+ 'casie': 1000,
110
+ }
111
+
112
+ max_len = max_len_dict[dataset]
113
+ max_n_parts = 2
114
+ max_span_len = 14
115
+ n_negs = 5 * 8
116
+ zero_events_rate = zero_events_rate_dict[dataset]
117
+
118
+ # Trainer
119
+ trainer_params = {
120
+ "training_time": "00:11:30:00",
121
+ "eval_mode": "max",
122
+ "topk": topk,
123
+ "save_name": state_dict_save_name,
124
+ "save_best": True,
125
+ "save_last": True,
126
+ "device": device,
127
+ "logging": True,
128
+ "logging_file": True,
129
+ "checkpoints_dir": checkpoints_dir,
130
+ "early_stopping": 30,
131
+ "eval_from_ratio": 0.4,
132
+ "eval_every": 1,
133
+ "schedule_in_step": False,
134
+ "use_ema": True,
135
+ "ema_from_ratio": 0.3,
136
+ "ema_decay": 0.9995,
137
+ "max_grad_norm": 200.0,
138
+ "return_best": True,
139
+ "return_last": True,
140
+ }
141
+
142
+ # Memory
143
+ train_memory_params = {
144
+ 'max_len': max_len,
145
+ 'max_n_parts': max_n_parts,
146
+ 'n_negs': n_negs,
147
+ }
148
+ val_memory_params = {
149
+ 'max_len': max_len,
150
+ 'max_n_parts': max_n_parts,
151
+ 'n_negs': n_negs,
152
+ }
153
+ corpus_memory_params = {
154
+ 'max_len': max_len,
155
+ 'max_n_parts': max_n_parts,
156
+ }
157
+
158
+ # Data Loader
159
+ def seed_worker(worker_id):
160
+ worker_seed = torch.initial_seed() % 2**32
161
+ np.random.seed(worker_seed)
162
+ random.seed(worker_seed)
163
+
164
+ train_loader_params = {
165
+ 'batch_size': batch_size,
166
+ 'shuffle': True,
167
+ 'pin_memory':True,
168
+ 'num_workers': 2,
169
+ 'drop_last': False,
170
+ 'worker_init_fn': seed_worker,
171
+ 'persistent_workers': False,
172
+ }
173
+ val_loader_params = {
174
+ 'batch_size': batch_size,
175
+ 'shuffle': False,
176
+ 'pin_memory':True,
177
+ 'num_workers': 1,
178
+ 'drop_last': False,
179
+ 'worker_init_fn': seed_worker,
180
+ 'persistent_workers': False,
181
+ }
182
+
183
+ # Model
184
+ model_params = {
185
+ 'backbone_name': backbone_model_name,
186
+ 'projection_dim': 256,
187
+ 'normalize': True,
188
+ }
189
+
190
+ # Loss Func
191
+ loss_func_params = {
192
+ 'lambda_contrastive': 1.0,
193
+ 'lambda_triplet': 0.5,
194
+ }
195
+ eval_func_params = {}
196
+
197
+ # Optim
198
+ optim_params = {
199
+ 'name': 'AdamW',
200
+ 'lr': 1e-4,
201
+ 'weight_decay': 1e-4,
202
+ }
203
+ scheduler_params = {
204
+ 'name': 'CosineAnnealingLR',
205
+ 'T_max': 20, # Số epoch để hoàn thành một chu kỳ giảm LR
206
+ 'eta_min': 1e-6 # Learning rate nhỏ nhất trong chu kỳ
207
+ }
208
+
209
+ # %% [code]
210
+ def set_seed(seed=42):
211
+ random.seed(seed)
212
+ np.random.seed(seed)
213
+ torch.manual_seed(seed)
214
+ torch.cuda.manual_seed(seed)
215
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
216
+ torch.use_deterministic_algorithms(False)
217
+ torch.backends.cudnn.deterministic = True
218
+ torch.backends.cudnn.benchmark = False
219
+ os.environ['PYTHONHASHSEED'] = str(seed)
220
+
221
+ # %% [code]
222
+ class CustomLoss(nn.Module):
223
+ def __init__(
224
+ self,
225
+ temperature=0.05,
226
+ margin=0.2,
227
+ lambda_contrastive=1.0,
228
+ lambda_triplet=0.5,
229
+ ):
230
+ super().__init__()
231
+
232
+ self.temperature = temperature
233
+ self.margin = margin
234
+
235
+ self.lambda_contrastive = lambda_contrastive
236
+ self.lambda_triplet = lambda_triplet
237
+
238
+ def forward(
239
+ self,
240
+ encoded_text,
241
+ encoded_pos,
242
+ encoded_neg,
243
+ pos_mask
244
+ ):
245
+ loss_contrastive = self.multi_pos_contrastive_loss(encoded_text, encoded_pos, encoded_neg, pos_mask)
246
+ loss_triplet = self.hardest_triplet_loss(encoded_text, encoded_pos, encoded_neg, pos_mask)
247
+
248
+ total_loss = (
249
+ self.lambda_contrastive * loss_contrastive +
250
+ self.lambda_triplet * loss_triplet
251
+ )
252
+
253
+ return {
254
+ "total": total_loss,
255
+ "contrastive_loss": loss_contrastive,
256
+ "triplet_loss": loss_triplet,
257
+ }
258
+
259
+ def multi_pos_contrastive_loss(self, q, pos, neg, pos_mask):
260
+ B, P, D = pos.shape
261
+ N = neg.shape[1]
262
+
263
+ # ===== concat docs =====
264
+ docs = torch.cat([pos, neg], dim=1) # [B, P+N, D]
265
+
266
+ # ===== similarity =====
267
+ logits = torch.matmul(q.unsqueeze(1), docs.transpose(1, 2)).squeeze(1)
268
+ logits = logits / self.temperature # [B, P+N]
269
+
270
+ # ===== labels =====
271
+ labels = torch.zeros_like(logits)
272
+ labels[:, :P] = pos_mask # chỉ pos hợp lệ
273
+
274
+ # ===== log-softmax =====
275
+ log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
276
+
277
+ # ===== normalize theo số pos thật =====
278
+ pos_count = pos_mask.sum(dim=1).clamp(min=1)
279
+
280
+ loss = -(labels * log_prob).sum(dim=1) / pos_count
281
+
282
+ return loss.mean()
283
+
284
+ def hardest_triplet_loss(self, q, pos, neg, pos_mask):
285
+ # ===== similarity =====
286
+ pos_sim = torch.matmul(q.unsqueeze(1), pos.transpose(1, 2)).squeeze(1) # [B, P]
287
+ neg_sim = torch.matmul(q.unsqueeze(1), neg.transpose(1, 2)).squeeze(1) # [B, N]
288
+
289
+ # ===== mask pos =====
290
+ pos_sim_masked = pos_sim.clone()
291
+ pos_sim_masked[pos_mask == 0] = float('inf') # loại pad
292
+
293
+ # ===== hardest =====
294
+ hardest_pos = pos_sim_masked.min(dim=1).values
295
+ hardest_neg = neg_sim.max(dim=1).values
296
+
297
+ # ===== loss =====
298
+ loss = F.relu(self.margin + hardest_neg - hardest_pos)
299
+
300
+ return loss.mean()
301
+
302
+ # %% [code]
303
+ class CustomEvalFn(nn.Module):
304
+ def __init__(self):
305
+ super().__init__()
306
+
307
+ def forward(self, pred_topk, real_topk):
308
+ """
309
+ pred_topk: List[List[int]] shape [B, K]
310
+ real_topk: List[List[int]] shape [B, Ki]
311
+ """
312
+
313
+ B = len(pred_topk)
314
+
315
+ total_recall = 0.0
316
+ total_map = 0.0
317
+ total_mrp = 0.0
318
+
319
+ for i in range(B):
320
+ preds = pred_topk[i]
321
+ gts = set(real_topk[i])
322
+
323
+ # ===== Recall@K =====
324
+ hit = any(p in gts for p in preds)
325
+ total_recall += 1.0 if hit else 0.0
326
+
327
+ # ===== AP =====
328
+ num_hits = 0
329
+ ap = 0.0
330
+
331
+ for rank, p in enumerate(preds, start=1):
332
+ if p in gts:
333
+ num_hits += 1
334
+ precision_at_rank = num_hits / rank
335
+ ap += precision_at_rank
336
+
337
+ if len(gts) > 0:
338
+ ap /= len(gts)
339
+
340
+ total_map += ap
341
+
342
+ # ===== R-Precision =====
343
+ r = len(gts)
344
+
345
+ if r > 0:
346
+ top_r = preds[:r]
347
+
348
+ tp_r = sum(p in gts for p in top_r)
349
+
350
+ rp = tp_r / r
351
+ else:
352
+ rp = 0.0
353
+
354
+ total_mrp += rp
355
+
356
+ recall = total_recall / B
357
+ mAP = total_map / B
358
+ mRP = total_mrp / B
359
+
360
+ return {
361
+ "recall": recall,
362
+ "mAP": mAP,
363
+ "mRP": mRP,
364
+ }
365
+
366
+ # %% [code]
367
+ class EncodeModel(nn.Module):
368
+ def __init__(
369
+ self,
370
+ backbone_name,
371
+ projection_dim,
372
+ normalize=True,
373
+ dropout=0.1
374
+ ):
375
+ super().__init__()
376
+
377
+ self.encoder = AutoModel.from_pretrained(backbone_name)
378
+
379
+ hidden_size = self.encoder.config.hidden_size
380
+
381
+ # mạnh hơn single linear
382
+ self.proj = nn.Sequential(
383
+ nn.Linear(hidden_size, hidden_size),
384
+ nn.GELU(),
385
+ nn.Dropout(dropout),
386
+ nn.Linear(hidden_size, projection_dim)
387
+ )
388
+
389
+ # residual projection
390
+ self.residual_proj = (
391
+ nn.Identity()
392
+ if hidden_size == projection_dim
393
+ else nn.Linear(hidden_size, projection_dim)
394
+ )
395
+
396
+ self.normalize = normalize
397
+
398
+ def mean_pooling(self, hidden, attention_mask):
399
+ mask = attention_mask.unsqueeze(-1).float() # [N, L, 1]
400
+ pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)
401
+
402
+ return pooled
403
+
404
+ def forward(self, input_ids, attention_mask, is_query=True):
405
+
406
+ if is_query:
407
+ # [B, n_parts, L]
408
+ B, n_parts, L = input_ids.shape
409
+
410
+ input_ids = input_ids.view(-1, L)
411
+ attention_mask = attention_mask.view(-1, L)
412
+
413
+ outputs = self.encoder(
414
+ input_ids=input_ids,
415
+ attention_mask=attention_mask
416
+ )
417
+
418
+ hidden = outputs.last_hidden_state # [B*n_parts, L, H]
419
+
420
+ chunk_repr = self.mean_pooling(hidden, attention_mask) # [B*n_parts, H]
421
+ chunk_repr = chunk_repr.view(B, n_parts, -1) # [B, n_parts, H]
422
+ pooled = chunk_repr.mean(dim=1)
423
+ # [B, H]
424
+
425
+ else:
426
+ # [B, K, n_parts, L]
427
+ B, K, n_parts, L = input_ids.shape
428
+
429
+ input_ids = input_ids.view(-1, L)
430
+ attention_mask = attention_mask.view(-1, L)
431
+
432
+ outputs = self.encoder(
433
+ input_ids=input_ids,
434
+ attention_mask=attention_mask
435
+ )
436
+
437
+ hidden = outputs.last_hidden_state
438
+
439
+ chunk_repr = self.mean_pooling(hidden, attention_mask) # [B*K*n_parts, H]
440
+ chunk_repr = chunk_repr.view(B, K, n_parts, -1) # [B, K, n_parts, H]
441
+ pooled = chunk_repr.mean(dim=2) # [B, K, H]
442
+
443
+ # residual MLP projection
444
+ emb = self.proj(pooled) + self.residual_proj(pooled)
445
+
446
+ if self.normalize:
447
+ emb = F.normalize(emb, dim=-1)
448
+
449
+ return emb
450
+
451
+ def test_model():
452
+ model = nn.DataParallel(EncodeModel('vinai/phobert-base', 256, True)).to(device)
453
+ model.eval()
454
+
455
+ bz = 32
456
+ vocab_size = 1000
457
+ qi = torch.randint(0, vocab_size, (bz, 1, 256)).to(device)
458
+ qa = torch.ones(bz, 1, 256).to(device)
459
+ di = torch.randint(0, vocab_size, (bz, 5, 2, 256)).to(device)
460
+ da = torch.ones(bz, 5, 2, 256).to(device)
461
+
462
+ st = time.time()
463
+ with torch.no_grad():
464
+ encoded_text = model(qi, qa, is_query=True)
465
+ encoded_pos = model(di, da, is_query=False)
466
+ encoded_neg = model(di, da, is_query=False)
467
+ print(encoded_text.shape, encoded_pos.shape, encoded_neg.shape)
468
+ print(time.time() - st)
469
+
470
+ del model, qi, qa, di, da, encoded_text, encoded_pos, encoded_neg
471
+ torch.cuda.empty_cache()
472
+ gc.collect()
473
+ test_model()
474
+
475
+ # %% [code]
476
+ def configure_optimizers(network, optim_params, scheduler_params):
477
+ try:
478
+ optim_params = copy.copy(optim_params)
479
+ scheduler_params = copy.copy(scheduler_params)
480
+
481
+ optim_name = optim_params.pop('name')
482
+ scheduler_name = scheduler_params.pop('name')
483
+
484
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
485
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
486
+
487
+ if optimizer_cls is None:
488
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
489
+
490
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
491
+
492
+ scheduler = None
493
+ if scheduler_params and scheduler_cls: # Chỉ tạo scheduler nếu có tham số
494
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
495
+
496
+ return optimizer, scheduler
497
+
498
+ except KeyError as e:
499
+ raise ValueError(f"Missing {e} in config!!")
500
+
501
+ def freeze(self, model):
502
+ model.eval()
503
+ for param in model.parameters():
504
+ param.requires_grad = False
505
+
506
+ def unfreeze(self, model):
507
+ model.train()
508
+ for param in model.parameters():
509
+ param.requires_grad = True
510
+
511
+ def reduce_batch_size(loader, ratio=0.5):
512
+ new_bs = max(1, int(loader.batch_size * ratio))
513
+
514
+ shuffle = isinstance(loader.sampler, RandomSampler)
515
+
516
+ new_loader = DataLoader(
517
+ dataset=loader.dataset,
518
+ batch_size=new_bs,
519
+ shuffle=shuffle,
520
+ sampler=None if shuffle else loader.sampler,
521
+ num_workers=loader.num_workers,
522
+ collate_fn=loader.collate_fn,
523
+ pin_memory=loader.pin_memory,
524
+ drop_last=loader.drop_last,
525
+ timeout=loader.timeout,
526
+ worker_init_fn=loader.worker_init_fn,
527
+ multiprocessing_context=loader.multiprocessing_context,
528
+ generator=loader.generator,
529
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
530
+ persistent_workers=loader.persistent_workers,
531
+ pin_memory_device=loader.pin_memory_device
532
+ )
533
+
534
+ return new_loader
535
+
536
+ def list_to_tuple(x):
537
+ if isinstance(x, (list, tuple)):
538
+ return tuple(list_to_tuple(i) for i in x)
539
+ return x
540
+
541
+ def fmt(x):
542
+ if isinstance(x, float):
543
+ return round(x, 5)
544
+ if isinstance(x, dict):
545
+ return {k: fmt(v) for k, v in x.items()}
546
+ if isinstance(x, list):
547
+ return [fmt(v) for v in x]
548
+ return x
549
+
550
+ class ModelEmaV3Proxy(ModelEmaV3):
551
+ def __getattr__(self, name):
552
+ try:
553
+ return super().__getattr__(name)
554
+ except AttributeError:
555
+ return getattr(self.module, name)
556
+
557
+ class DataParallelProxy(nn.DataParallel):
558
+ def __getattr__(self, name):
559
+ try:
560
+ return super().__getattr__(name)
561
+ except AttributeError:
562
+ attr = getattr(self.module, name)
563
+
564
+ if callable(attr):
565
+ def wrapper(*args, **kwargs):
566
+ return self._parallel_apply_method(name, *args, **kwargs)
567
+ return wrapper
568
+
569
+ return attr
570
+
571
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
572
+ if not self.device_ids:
573
+ return getattr(self.module, method_name)(*inputs, **kwargs)
574
+
575
+ inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids)
576
+
577
+ replicas = self.replicate(self.module, self.device_ids)
578
+
579
+ outputs = self.parallel_apply(
580
+ [getattr(replica, method_name) for replica in replicas],
581
+ inputs_scattered,
582
+ kwargs_scattered
583
+ )
584
+
585
+ return self.gather(outputs, self.output_device)
586
+
587
+ class Trainer:
588
+ def __init__(
589
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
590
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
591
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
592
+ ):
593
+ self.ema_net = None
594
+
595
+ self.training_time = self._time_str_to_seconds(training_time)
596
+ self.mode = eval_mode
597
+ self.topk = topk
598
+ self.device = device
599
+ self.logging = logging if logging < epochs else 1
600
+ self.logging_file = logging_file
601
+ self.checkpoints_dir = checkpoints_dir
602
+ self.early_stopping = early_stopping
603
+ self.eval_from_ratio = eval_from_ratio
604
+ self.eval_every = eval_every
605
+ self.save_name = save_name
606
+ self.save_best = save_best
607
+ self.save_last = save_last
608
+ self.return_best = return_best
609
+ self.return_last = return_last
610
+ self.max_grad_norm = max_grad_norm
611
+ self.schedule_in_step = schedule_in_step
612
+ self.use_ema = use_ema
613
+ self.ema_from_ratio = ema_from_ratio
614
+ self.ema_decay = ema_decay
615
+
616
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
617
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
618
+
619
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, corpus_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, refresh_every=3):
620
+ if eval_fn is None:
621
+ if self.mode == "max":
622
+ eval_fn = lambda *x: -loss_fn(*x)
623
+ else:
624
+ eval_fn = lambda *x: loss_fn(*x)
625
+
626
+ if torch.cuda.device_count() > 1:
627
+ network = DataParallelProxy(network)
628
+ network = network.to(self.device)
629
+
630
+ if not start_training_time:
631
+ start_training_time = time.time()
632
+
633
+ start_ema = int(epochs * self.ema_from_ratio)
634
+ start_eval = int(epochs * self.eval_from_ratio)
635
+
636
+ if val_loader is None:
637
+ print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
638
+ else:
639
+ model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
640
+ start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
641
+ print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)
642
+
643
+ training_log = {}
644
+ for epoch in range(start_epoch, epochs+start_epoch):
645
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
646
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
647
+
648
+ try:
649
+ eval_net = self.ema_net if (self.use_ema and self.ema_net is not None) else network
650
+ if (epoch - start_epoch) % refresh_every == 0:
651
+ encoded_docs = self._get_encoded_docs(eval_net, corpus_loader)
652
+ print(f"[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Refresh Encoded Doc (refresh_every={refresh_every})!")
653
+
654
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, encoded_docs)
655
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
656
+ logging_dict.update(train_loss_epoch_dict)
657
+
658
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
659
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, encoded_docs)
660
+ update = self._update_best_network(eval_net, val_score, epoch)
661
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
662
+ logging_dict.update(val_score_dict)
663
+ if not self.schedule_in_step and scheduler:
664
+ scheduler.step()
665
+
666
+ except RuntimeError as e:
667
+ if "out of memory" in str(e).lower():
668
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
669
+ torch.cuda.empty_cache()
670
+ gc.collect()
671
+ if torch.cuda.is_available():
672
+ torch.cuda.synchronize()
673
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
674
+
675
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
676
+ if val_loader is not None:
677
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
678
+
679
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
680
+ else:
681
+ raise
682
+
683
+ training_log[epoch] = logging_dict
684
+ if self.is_early_stopping(epoch):
685
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
686
+ break
687
+ if self.logging:
688
+ if epoch % self.logging == 0:
689
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
690
+ else:
691
+ print(f'{epoch}...', end=' ')
692
+
693
+ if self._at_time_limit(start_training_time):
694
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
695
+ break
696
+
697
+ if self.logging_file:
698
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
699
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
700
+ f.write(json.dumps(training_log))
701
+
702
+ if self.use_ema and self.ema_net is not None:
703
+ self._save_state_dict(self.ema_net.module)
704
+ else:
705
+ self._save_state_dict(network)
706
+ print(f'[Trainer CallBack] 📢 Kết thúc training.\n')
707
+
708
+ best_model, last_model = None, None
709
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
710
+ if self.return_best :
711
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
712
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
713
+ if self.return_last:
714
+ last_model = eval_net.state_dict()
715
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
716
+
717
+ del network
718
+ torch.cuda.empty_cache()
719
+ gc.collect()
720
+ return training_log, best_model, last_model
721
+
722
+ def _time_str_to_seconds(self, time_str):
723
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
724
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
725
+
726
+ def _update_best_network(self, network, val_score, epoch):
727
+ topk = max(1, self.topk)
728
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
729
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
730
+ if val_score in [x[0] for x in self.best_stage]:
731
+ return True
732
+ return False
733
+
734
+ def is_early_stopping(self, epoch):
735
+ if self.best_stage[0][1] is None:
736
+ return False
737
+ if not self.early_stopping:
738
+ return False
739
+ return epoch - self.best_stage[0][1] >= self.early_stopping
740
+
741
+ def _at_time_limit(self, start_training_time):
742
+ return time.time() - start_training_time >= self.training_time
743
+
744
+ def _save_state_dict(self, network):
745
+ if self.topk <= 0:
746
+ return
747
+
748
+ if self.save_best:
749
+ for r in range(self.topk):
750
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
751
+
752
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
753
+ if state_dict is None:
754
+ continue
755
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
756
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
757
+ if self.save_last:
758
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
759
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
760
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
761
+
762
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, encoded_docs):
763
+ network.train()
764
+ total_loss = 0
765
+ total_loss_dict = {}
766
+ for batch_idx, batch in enumerate(train_loader):
767
+ optimizer.zero_grad()
768
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
769
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, encoded_docs)
770
+
771
+ for k, v in loss_dict.items():
772
+ t = total_loss_dict.get(k, 0)
773
+ total_loss_dict[k] = t + v
774
+ self.grad_scaler.scale(loss).backward()
775
+ self.grad_scaler.unscale_(optimizer)
776
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
777
+ # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu...
778
+ self.grad_scaler.step(optimizer)
779
+ self.grad_scaler.update()
780
+ if self.schedule_in_step and scheduler:
781
+ scheduler.step()
782
+ if self.use_ema and self.ema_net is not None:
783
+ self.ema_net.update(network)
784
+ total_loss += loss
785
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
786
+
787
+ def _eval_epoch(self, network, val_loader, eval_fn, encoded_docs):
788
+ network.eval()
789
+ total_score = 0.0
790
+ total_score_dict = {}
791
+ object_lists = None # sẽ init sau
792
+
793
+ with torch.no_grad():
794
+ for batch_idx, batch in enumerate(val_loader):
795
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, encoded_docs)
796
+ total_score += score
797
+
798
+ for k, v in score_dict.items():
799
+ t = total_score_dict.get(k, 0)
800
+ total_score_dict[k] = t + v
801
+
802
+ if objects:
803
+ if object_lists is None:
804
+ object_lists = [[] for _ in range(len(objects))]
805
+
806
+ for i, obj in enumerate(objects):
807
+ object_lists[i].append(obj.detach())
808
+
809
+ if object_lists is not None:
810
+ object_arrays = [
811
+ torch.concat(obj_list, dim=0).cpu().numpy()
812
+ for obj_list in object_lists
813
+ ]
814
+ else:
815
+ object_arrays = []
816
+
817
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
818
+
819
+ def _get_encoded_docs(self, network, corpus_loader):
820
+ network.eval()
821
+ with torch.no_grad():
822
+ encoded_docs = []
823
+ for batch_idx, batch in enumerate(corpus_loader):
824
+ input_ids = batch['input_ids'].to(self.device)
825
+ attn_mask = batch['attn_mask'].to(self.device)
826
+ encoded_doc = network(input_ids, attn_mask, is_query=False)
827
+ encoded_docs.append(encoded_doc)
828
+ encoded_docs = torch.concat(encoded_docs, dim=0).squeeze(1)
829
+ return encoded_docs
830
+
831
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, encoded_docs):
832
+ # Bạn cần override _cal_loss để tính loss
833
+ text_input_ids = batch['text_input_ids'].to(self.device)
834
+ text_attn_mask = batch['text_attn_mask'].to(self.device)
835
+ pos_idxes = batch['pos_idxes'].to(self.device)
836
+ pos_mask = batch['pos_mask'].to(self.device)
837
+ neg_idxes = batch['neg_idxes'].to(self.device)
838
+
839
+ encoded_text = network(text_input_ids, text_attn_mask, is_query=True)
840
+ encoded_pos = encoded_docs[pos_idxes]
841
+ encoded_neg = encoded_docs[neg_idxes]
842
+
843
+ loss_dict = loss_fn(encoded_text, encoded_pos, encoded_neg, pos_mask)
844
+ return loss_dict['total'], loss_dict
845
+
846
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, encoded_docs):
847
+ # Bạn cần override _cal_val_score để tính val score, list bên cạnh là để trả về y hay pred gì đó (nếu cần)
848
+ text_input_ids = batch['text_input_ids'].to(self.device)
849
+ text_attn_mask = batch['text_attn_mask'].to(self.device)
850
+ gt_pos_idxes = batch['gt_pos_idxes']
851
+ encoded_text = network(text_input_ids, text_attn_mask, is_query=True)
852
+
853
+ scores = torch.matmul(encoded_text, encoded_docs.T)
854
+ topk_scores, topk_indices = torch.topk(scores, k=10)
855
+ pred_topk = [
856
+ idx[score > 0].tolist()
857
+ for score, idx in zip(topk_scores, topk_indices)
858
+ ]
859
+
860
+ pred_topk = list_to_tuple(pred_topk)
861
+ gt_pos_idxes = list_to_tuple(gt_pos_idxes)
862
+ score_dict = eval_fn(pred_topk, gt_pos_idxes)
863
+ return score_dict['recall'], score_dict, []
864
+
865
+ # %% [code]
866
+ def tokenize_to_parts(text, tokenizer, max_len, max_n_parts):
867
+ # Tokenize với overflow để chia thành nhiều đoạn
868
+ enc = tokenizer(
869
+ text,
870
+ max_length=max_len*max_n_parts,
871
+ truncation=True,
872
+ padding="max_length",
873
+ return_overflowing_tokens=True,
874
+ return_tensors="pt"
875
+ )
876
+
877
+ input_ids = enc["input_ids"].reshape(max_n_parts, max_len) # (n_parts, max_len)
878
+ attn_mask = enc["attention_mask"].reshape(max_n_parts, max_len) # (n_parts, max_len)
879
+
880
+ return input_ids, attn_mask
881
+
882
+ class LawRetrievalDataset(Dataset):
883
+ def __init__(self, all_data, using_idxes, corpus_dict, tokenizer, max_len, max_n_parts, n_negs):
884
+ super().__init__()
885
+
886
+ self.all_data = all_data
887
+ self.using_idxes = using_idxes
888
+ self.tokenizer = tokenizer
889
+ self.max_len = max_len
890
+ self.max_n_parts = max_n_parts
891
+ self.n_negs = n_negs
892
+
893
+ # ===== BUILD CORPUS =====
894
+ idx = 0
895
+ self.corpus_list = []
896
+ self.corpus_dict = {}
897
+
898
+ for doc_name, articles_dict in corpus_dict.items():
899
+ self.corpus_dict[doc_name] = {}
900
+ for article_idx, content in articles_dict.items():
901
+ self.corpus_list.append([doc_name, article_idx, content])
902
+ self.corpus_dict[doc_name][article_idx] = {
903
+ 'content': content,
904
+ 'idx': idx
905
+ }
906
+ idx += 1
907
+
908
+ def __len__(self):
909
+ return len(self.using_idxes)
910
+
911
+ # ===== ENCODE DOC =====
912
+ def _encode_contexts(self, idxes):
913
+ all_input_ids, all_attn_mask = [], []
914
+
915
+ for idx in idxes:
916
+ name, art, _ = self.corpus_list[idx]
917
+ corpus = self.corpus_dict[name][art]
918
+
919
+ if 'content_input_ids' in corpus:
920
+ content_input_ids = corpus['content_input_ids']
921
+ content_attn_mask = corpus['content_attn_mask']
922
+ else:
923
+ content = corpus['content']
924
+ content_input_ids, content_attn_mask = tokenize_to_parts(
925
+ content, self.tokenizer, self.max_len, self.max_n_parts
926
+ )
927
+ corpus['content_input_ids'] = content_input_ids
928
+ corpus['content_attn_mask'] = content_attn_mask
929
+
930
+ all_input_ids.append(content_input_ids)
931
+ all_attn_mask.append(content_attn_mask)
932
+
933
+ all_input_ids = torch.stack(all_input_ids)
934
+ all_attn_mask = torch.stack(all_attn_mask)
935
+
936
+ return all_input_ids, all_attn_mask
937
+
938
+ def __getitem__(self, idx):
939
+ ridx = self.using_idxes[idx]
940
+ data = self.all_data[ridx]
941
+
942
+ query_text = data['text']
943
+
944
+ text_input_ids, text_attn_mask = tokenize_to_parts(
945
+ query_text, self.tokenizer, self.max_len, 1
946
+ )
947
+
948
+ # ===== POS =====
949
+ gt_pos_idxes = []
950
+ hard_names = []
951
+ for law in data['relevant_law']:
952
+ name = law['doc']
953
+ art = law['art']
954
+ gt_pos_idxes.append(self.corpus_dict[name][art]['idx'])
955
+ if name not in hard_names:
956
+ hard_names.append(name)
957
+
958
+ pos_idxes = torch.tensor(gt_pos_idxes, dtype=torch.long)
959
+ pos_mask = torch.ones(len(pos_idxes))
960
+
961
+ # ===== NEG =====
962
+ hard_neg_idxes = []
963
+ for name in hard_names:
964
+ for content in self.corpus_dict[name].values():
965
+ if content['idx'] in gt_pos_idxes:
966
+ continue
967
+ hard_neg_idxes.append(content['idx'])
968
+
969
+ easy_neg_idxes = list(range(len(self.corpus_list)))
970
+ for i in gt_pos_idxes + hard_neg_idxes:
971
+ if i in easy_neg_idxes:
972
+ easy_neg_idxes.remove(i)
973
+
974
+ n_hards = min(len(hard_neg_idxes), self.n_negs // 2)
975
+ neg_idxes = random.sample(hard_neg_idxes, n_hards) + random.sample(easy_neg_idxes, self.n_negs - n_hards)
976
+ neg_idxes = torch.tensor(neg_idxes, dtype=torch.long)
977
+
978
+ return {
979
+ 'text_input_ids': text_input_ids,
980
+ 'text_attn_mask': text_attn_mask,
981
+ 'gt_pos_idxes': gt_pos_idxes,
982
+ 'pos_idxes': pos_idxes,
983
+ 'pos_mask': pos_mask,
984
+ 'neg_idxes': neg_idxes,
985
+ }
986
+
987
+ class CorpusDataset(Dataset):
988
+ def __init__(self, corpus_dict, tokenizer, max_len, max_n_parts):
989
+ super().__init__()
990
+ self.tokenizer = tokenizer
991
+ self.max_len = max_len
992
+ self.max_n_parts = max_n_parts
993
+
994
+ idx = 0
995
+ self.corpus_list = []
996
+ self.corpus_dict = {}
997
+ for doc_name, articles_dict in corpus_dict.items():
998
+ self.corpus_dict[doc_name] = {}
999
+ for article_idx, content in articles_dict.items():
1000
+ self.corpus_list.append([doc_name, article_idx, content])
1001
+ self.corpus_dict[doc_name][article_idx] = {'content': content, 'idx': idx}
1002
+ idx += 1
1003
+
1004
+ def __len__(self):
1005
+ return len(self.corpus_list)
1006
+
1007
+ def _encode_contexts(self, idxes):
1008
+ all_input_ids, all_attn_mask = [], []
1009
+ for idx in idxes:
1010
+ name = self.corpus_list[idx][0]
1011
+ art = self.corpus_list[idx][1]
1012
+ corpus = self.corpus_dict[name][art]
1013
+ if 'content_input_ids' in corpus and 'content_attn_mask' in corpus:
1014
+ content_input_ids = corpus['content_input_ids']
1015
+ content_attn_mask = corpus['content_attn_mask']
1016
+ else:
1017
+ content = corpus['content']
1018
+ content_input_ids, content_attn_mask = tokenize_to_parts(content, self.tokenizer, self.max_len, self.max_n_parts)
1019
+ corpus['content_input_ids'] = content_input_ids
1020
+ corpus['content_attn_mask'] = content_attn_mask
1021
+
1022
+ all_input_ids.append(content_input_ids)
1023
+ all_attn_mask.append(content_attn_mask)
1024
+
1025
+ all_input_ids = torch.stack(all_input_ids)
1026
+ all_attn_mask = torch.stack(all_attn_mask)
1027
+ return all_input_ids, all_attn_mask
1028
+
1029
+ def __getitem__(self, idx):
1030
+ input_ids, attn_mask = self._encode_contexts([idx])
1031
+
1032
+ return {
1033
+ 'input_ids': input_ids,
1034
+ 'attn_mask': attn_mask,
1035
+ }
1036
+
1037
+ def _pad_batch(tensor_list, pad_value=0):
1038
+ """
1039
+ tensor_list: list of tensors, mỗi tensor shape (Nk, max_n_parts, max_len)
1040
+ return: tensor shape (B, max_Nk, max_n_parts, max_len)
1041
+ """
1042
+ max_Nk = max(t.size(0) for t in tensor_list)
1043
+
1044
+ padded = []
1045
+ for t in tensor_list:
1046
+ Nk = t.size(0)
1047
+
1048
+ if Nk < max_Nk:
1049
+ pad_shape = (max_Nk - Nk, *t.shape[1:])
1050
+ pad_tensor = t.new_full(pad_shape, pad_value)
1051
+ t = torch.cat([t, pad_tensor], dim=0)
1052
+
1053
+ padded.append(t)
1054
+
1055
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
1056
+
1057
+ def collate_fn(batch):
1058
+ text_input_ids = torch.stack([b["text_input_ids"] for b in batch])
1059
+ text_attn_mask = torch.stack([b["text_attn_mask"] for b in batch])
1060
+ gt_pos_idxes = [b["gt_pos_idxes"] for b in batch]
1061
+ neg_idxes = torch.stack([b["neg_idxes"] for b in batch])
1062
+
1063
+ pos_idxes = [b["pos_idxes"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1064
+ pos_mask = [b["pos_mask"].unsqueeze(-1).unsqueeze(-1) for b in batch]
1065
+
1066
+ # pad theo Nk
1067
+ pos_idxes = _pad_batch(pos_idxes, pad_value=0).squeeze(-1).squeeze(-1)
1068
+ pos_mask = _pad_batch(pos_mask, pad_value=0).squeeze(-1).squeeze(-1)
1069
+
1070
+ return {
1071
+ 'text_input_ids': text_input_ids,
1072
+ 'text_attn_mask': text_attn_mask,
1073
+ 'gt_pos_idxes': gt_pos_idxes,
1074
+ 'pos_idxes': pos_idxes,
1075
+ 'pos_mask': pos_mask,
1076
+ 'neg_idxes': neg_idxes,
1077
+ }
1078
+
1079
+ # %% [code]
1080
+ def encode_corpus(state_dicts, network, corpus_loader, device):
1081
+ if torch.cuda.device_count() > 1:
1082
+ network = nn.DataParallel(network)
1083
+ network.to(device)
1084
+ network.eval()
1085
+
1086
+ all_model_embs = []
1087
+ for i, state_dict in enumerate(state_dicts):
1088
+ # ===== load model =====
1089
+ if torch.cuda.device_count() > 1:
1090
+ network.module.load_state_dict(state_dict)
1091
+ else:
1092
+ network.load_state_dict(state_dict)
1093
+
1094
+ encoded_docs = []
1095
+
1096
+ with torch.no_grad():
1097
+ for batch in corpus_loader:
1098
+ input_ids = batch['input_ids'].to(device)
1099
+ attn_mask = batch['attn_mask'].to(device)
1100
+
1101
+ emb = network(input_ids, attn_mask, is_query=False) # [B, 1, D] hoặc [B, D]
1102
+
1103
+ encoded_docs.append(emb)
1104
+
1105
+ encoded_docs = torch.concat(encoded_docs, dim=0).squeeze(1) # [N, D]
1106
+ all_model_embs.append(encoded_docs)
1107
+
1108
+ # ===== ensemble =====
1109
+ # stack → [M, N, D]
1110
+ all_model_embs = torch.stack(all_model_embs, dim=0)
1111
+ final_embs = all_model_embs.mean(dim=0) # [N, D]
1112
+
1113
+ return final_embs
1114
+
1115
+ def test(state_dicts, network, test_loader, device, eval_fn, encoded_docs, bm25_scores, topks=[5, 10, 15]):
1116
+ if torch.cuda.device_count() > 1:
1117
+ network = nn.DataParallel(network)
1118
+ network.to(device)
1119
+ network.eval()
1120
+
1121
+ per_model_scores = []
1122
+ max_k = max(topks)
1123
+
1124
+ all_scores = []
1125
+ all_gt_pos_idxes = []
1126
+ with torch.no_grad():
1127
+ for batch in test_loader:
1128
+ text_input_ids = batch['text_input_ids'].to(device)
1129
+ text_attn_mask = batch['text_attn_mask'].to(device)
1130
+ gt_pos_idxes = batch['gt_pos_idxes']
1131
+ all_gt_pos_idxes.extend(gt_pos_idxes)
1132
+
1133
+ list_encoded_texts = []
1134
+
1135
+ for state_dict in state_dicts:
1136
+ # ===== load model =====
1137
+ if torch.cuda.device_count() > 1:
1138
+ network.module.load_state_dict(state_dict)
1139
+ else:
1140
+ network.load_state_dict(state_dict)
1141
+
1142
+ encoded_text = network(text_input_ids, text_attn_mask, is_query=True)
1143
+ list_encoded_texts.append(encoded_text)
1144
+
1145
+ ensemble_encoded_text = torch.stack(list_encoded_texts, dim=0).mean(dim=0)
1146
+ scores = torch.matmul(ensemble_encoded_text, encoded_docs.T) # B, M
1147
+ all_scores.append(scores)
1148
+
1149
+ all_scores = torch.concat(all_scores, dim=0) # (N, M)
1150
+
1151
+ bm25_scores = torch.tensor(bm25_scores).to(all_scores.device) # (N, M)
1152
+ bm25_scores = bm25_scores.float()
1153
+
1154
+ # min-max về [0, 1]
1155
+ bm25_scores = (
1156
+ bm25_scores - bm25_scores.min(dim=-1, keepdim=True).values
1157
+ ) / (
1158
+ bm25_scores.max(dim=-1, keepdim=True).values
1159
+ - bm25_scores.min(dim=-1, keepdim=True).values
1160
+ + 1e-8
1161
+ )
1162
+ # đổi sang [-1, 1]
1163
+ bm25_scores = bm25_scores * 2 - 1
1164
+
1165
+ all_gt_pos_idxes = list_to_tuple(all_gt_pos_idxes)
1166
+
1167
+ final_score = {}
1168
+ for weight in [0, 0.25, 0.5, 0.75, 1]:
1169
+
1170
+ # score cuối
1171
+ final_scores = all_scores + weight * bm25_scores
1172
+
1173
+ # topk lớn nhất
1174
+ topk_scores, topk_indices = torch.topk(final_scores, k=max_k)
1175
+
1176
+ pred_topk_full = [
1177
+ idx[score > 0].tolist()
1178
+ for score, idx in zip(topk_scores, topk_indices)
1179
+ ]
1180
+
1181
+ pred_topk_full = list_to_tuple(pred_topk_full)
1182
+
1183
+ final_score[weight] = {}
1184
+
1185
+ for k in topks:
1186
+ pred_topk_k = [p[:k] for p in pred_topk_full]
1187
+ final_score[weight][k] = eval_fn(pred_topk_k, all_gt_pos_idxes)
1188
+
1189
+ return final_score
1190
+
1191
+ # %% [code]
1192
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
1193
+ data_train = json.load(f)
1194
+
1195
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
1196
+ data_test = json.load(f)
1197
+
1198
+ with open(f'{test_dir}/corpus.json', "r", encoding="utf-8") as f:
1199
+ data_corpus = json.load(f)
1200
+
1201
+ print('Train:', len(data_train))
1202
+ print('Test:', len(data_test))
1203
+ print('Corpus:', len(data_corpus))
1204
+
1205
+ # %% [code]
1206
+ # trigger_types = sorted(list(set([e['label'] for d in data_train + data_test for e in d['issues']]))) # NBR : Neighbor relation
1207
+ # bio_trigger_types = ['O'] + [f'{prefix}-{trg}' for trg in trigger_types for prefix in ['B', 'I']]
1208
+ # trigger_label2id = {l: i for i, l in enumerate(bio_trigger_types)}
1209
+ # trigger_id2label = {i: l for l, i in trigger_label2id.items()}
1210
+
1211
+ # argument_types = sorted(list(set([a['role'] for d in data_train + data_test for e in d['issues'] for a in e['arguments']])))
1212
+ # bio_argument_types = ['O'] + [f'{prefix}-{arg}' for arg in argument_types for prefix in ['B', 'I']]
1213
+ # argument_label2id = {l: i for i, l in enumerate(bio_argument_types)}
1214
+ # argument_id2label = {i: l for l, i in argument_label2id.items()}
1215
+
1216
+ # label2id = {
1217
+ # 'Trg': trigger_label2id,
1218
+ # 'Arg': argument_label2id,
1219
+ # }
1220
+
1221
+ # id2label = {
1222
+ # 'Trg': trigger_id2label,
1223
+ # 'Arg': argument_id2label,
1224
+ # }
1225
+
1226
+ # %% [code]
1227
+ # zero_events_idxes = []
1228
+ # for idx, d in enumerate(data_train):
1229
+ # if len(d['issues']) == 0:
1230
+ # zero_events_idxes.append(idx)
1231
+
1232
+ # n_zero_events_samples = len(zero_events_idxes)
1233
+ # n_has_events_samples = len(data_train) - n_zero_events_samples
1234
+
1235
+ # random.seed(42)
1236
+ # k = min(int(n_has_events_samples * zero_events_rate), len(zero_events_idxes))
1237
+ # sampled_zero_events_idxes = random.sample(zero_events_idxes, k)
1238
+
1239
+ # new_data_train = []
1240
+ # for idx, d in enumerate(data_train):
1241
+ # if len(d['issues']) == 0:
1242
+ # if idx in sampled_zero_events_idxes:
1243
+ # new_data_train.append(d)
1244
+ # else:
1245
+ # new_data_train.append(d)
1246
+ # data_train = new_data_train
1247
+
1248
+ # print('Train:', len(data_train))
1249
+
1250
+ # %% [code]
1251
+ if debug_only:
1252
+ data_train = data_train[:20]
1253
+ data_test = data_test[:20]
1254
+
1255
+ print('Train:', len(data_train))
1256
+ print('Test:', len(data_test))
1257
+
1258
+ # %% [code]
1259
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
1260
+
1261
+ # %% [code]
1262
+ print('Experiment name:', state_dict_save_name)
1263
+
1264
+ # %% [code]
1265
+ if not test_only:
1266
+ full_idxes = np.array(range(len(data_train)))
1267
+ training_logs, best_models, last_models = [], [], []
1268
+ start_training_time = time.time()
1269
+ for seed in SEEDS:
1270
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
1271
+ generator = torch.Generator()
1272
+ generator.manual_seed(seed)
1273
+
1274
+ corpusset = CorpusDataset(data_corpus, tokenizer, **corpus_memory_params)
1275
+ corpus_loader = DataLoader(corpusset, generator=generator, **val_loader_params)
1276
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
1277
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
1278
+ continue
1279
+ set_seed(seed)
1280
+
1281
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
1282
+
1283
+ trainset = LawRetrievalDataset(data_train, train_idxes, data_corpus, tokenizer, **train_memory_params)
1284
+ valset = LawRetrievalDataset(data_train, val_idxes, data_corpus, tokenizer, **val_memory_params)
1285
+
1286
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
1287
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1288
+
1289
+ my_model = EncodeModel(
1290
+ **model_params
1291
+ )
1292
+ total_params = sum(p.numel() for p in my_model.parameters())
1293
+ print(f"Total params: {total_params:,}")
1294
+
1295
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
1296
+ encoder_params = set(map(id, my_model.encoder.parameters()))
1297
+ other_params = [
1298
+ p for p in my_model.parameters()
1299
+ if id(p) not in encoder_params
1300
+ ]
1301
+ optimizer = optim.AdamW([
1302
+ {"params": my_model.encoder.parameters(), "lr": 2e-5},
1303
+ {"params": other_params}
1304
+ ], lr=5e-4)
1305
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
1306
+
1307
+ loss_fn = CustomLoss(
1308
+ **loss_func_params
1309
+ )
1310
+ eval_fn = CustomEvalFn(**eval_func_params)
1311
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
1312
+ trainer = Trainer(**trainer_params)
1313
+
1314
+ print(f'Start Training Fold {fold_idx}...')
1315
+ training_log, best_model, last_model = trainer.fit(
1316
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, corpus_loader, eval_fn,
1317
+ start_epoch=1, start_training_time=start_training_time, refresh_every=2,
1318
+ )
1319
+
1320
+ training_logs.append(training_log)
1321
+ best_models.append(best_model)
1322
+ last_models.append(last_model)
1323
+
1324
+ # %% [code]
1325
+ def load_all_state_dicts(folder):
1326
+ files = []
1327
+
1328
+ for file in os.listdir(folder):
1329
+ if file.endswith(".pt") or file.endswith(".pth"):
1330
+ m = re.search(r"f(\d+)", file) # tìm f<số>
1331
+ if m:
1332
+ fold = int(m.group(1))
1333
+ files.append((fold, file))
1334
+
1335
+ # sort theo fold
1336
+ files.sort(key=lambda x: x[0])
1337
+
1338
+ state_dicts = []
1339
+ for fold, file in files:
1340
+ path = os.path.join(folder, file)
1341
+ print(f"Loading fold {fold}: {file}")
1342
+ state_dict = torch.load(path, map_location="cpu")
1343
+ state_dicts.append(state_dict)
1344
+
1345
+ return state_dicts
1346
+
1347
+ if test_only:
1348
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
1349
+ get_ipython().system('rm -rf .cache .gitattributes')
1350
+
1351
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
1352
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
1353
+
1354
+ # %% [code]
1355
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
1356
+ testset = LawRetrievalDataset(data_test, range(len(data_test)), data_corpus, tokenizer, **val_memory_params)
1357
+ generator = torch.Generator()
1358
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
1359
+ eval_fn = CustomEvalFn(**eval_func_params)
1360
+ my_model = EncodeModel(
1361
+ **model_params
1362
+ )
1363
+ total_params = sum(p.numel() for p in my_model.parameters())
1364
+ print(f"Total params: {total_params:,}")
1365
+
1366
+ # %% [code]
1367
+ start_time = time.time()
1368
+ encoded_docs = encode_corpus(best_models, my_model, corpus_loader, device)
1369
+ best_score = test(best_models, my_model, test_loader, device, eval_fn, encoded_docs, bm25_scores)
1370
+
1371
+ encoded_docs = encode_corpus(last_models, my_model, corpus_loader, device)
1372
+ last_score = test(last_models, my_model, test_loader, device, eval_fn, encoded_docs, bm25_scores)
1373
+
1374
+ result_test = {"Best model": best_score, "Last model": last_score}
1375
+
1376
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
1377
+ json.dump(result_test, f, ensure_ascii=False, indent=2)
1378
+
1379
+ print('Test:', time.time() - start_time, 's --> Done!')
1380
+
1381
+ # %% [code]
1382
+ def dict_to_df(data):
1383
+ """
1384
+ data format:
1385
+ {
1386
+ model_name: {
1387
+ weight: {
1388
+ topk: {
1389
+ metric: value
1390
+ }
1391
+ }
1392
+ }
1393
+ }
1394
+ """
1395
+
1396
+ row_tuples = []
1397
+ row_values = []
1398
+
1399
+ # ===== lấy model đầu tiên =====
1400
+ first_model = next(iter(data.values()))
1401
+
1402
+ # ===== weight keys =====
1403
+ weight_keys = list(first_model.keys())
1404
+
1405
+ # ===== topk keys =====
1406
+ first_weight = next(iter(first_model.values()))
1407
+ topk_keys = list(first_weight.keys())
1408
+
1409
+ # ===== metric keys =====
1410
+ first_topk = next(iter(first_weight.values()))
1411
+ metrics = list(first_topk.keys())
1412
+
1413
+ for weight in weight_keys:
1414
+
1415
+ for topk in topk_keys:
1416
+
1417
+ # ===== multi index row =====
1418
+ row_tuples.append((weight, topk))
1419
+
1420
+ row = {}
1421
+
1422
+ for model_name, model_data in data.items():
1423
+
1424
+ for metric in metrics:
1425
+
1426
+ row[(model_name, metric)] = (
1427
+ model_data[weight][topk][metric]
1428
+ )
1429
+
1430
+ row_values.append(row)
1431
+
1432
+ # ===== dataframe =====
1433
+ df = pd.DataFrame(row_values)
1434
+
1435
+ # ===== multi columns =====
1436
+ df.columns = pd.MultiIndex.from_tuples(
1437
+ df.columns,
1438
+ names=["model", "metric"]
1439
+ )
1440
+
1441
+ # ===== multi index =====
1442
+ df.index = pd.MultiIndex.from_tuples(
1443
+ row_tuples,
1444
+ names=["weight", "topk"]
1445
+ )
1446
+
1447
+ return df
1448
+
1449
+ result_test_df = dict_to_df(result_test)
1450
+ result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
1451
+ result_test_df
1452
+
1453
+ # %% [code]
1454
+ key = ("Best model", "recall")
1455
+
1456
+ result_test_df_best = (
1457
+ result_test_df
1458
+ .sort_values(by=key, ascending=False)
1459
+ .groupby(level="weight")
1460
+ .head(1)
1461
+ )
1462
+ result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
1463
+ result_test_df_best
1464
+
1465
+ # %% [code]
1466
+ def get_avg_best_score(logs):
1467
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
1468
+
1469
+ def get_avg_log(logs, epochs):
1470
+ avg_log = {}
1471
+
1472
+ for epoch in range(1, epochs + 1):
1473
+ val_score = 0.0
1474
+ train_loss = 0.0
1475
+ n_eval = 0
1476
+
1477
+ for idx in range(len(logs)):
1478
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
1479
+ if log is None:
1480
+ continue
1481
+
1482
+ val_score += log.get('val_score', 0.0)
1483
+ train_loss += log.get('train_loss', 0.0)
1484
+ n_eval += 1
1485
+
1486
+ if n_eval == 0:
1487
+ continue
1488
+
1489
+ avg_log[epoch] = {
1490
+ 'train_loss': train_loss / n_eval,
1491
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
1492
+ }
1493
+
1494
+ return avg_log
1495
+
1496
+ def parse_label_key(label: str):
1497
+ try:
1498
+ first = float(label.split('_', 1)[0]) # số đầu: trước dấu _
1499
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
1500
+ return first, last
1501
+ except:
1502
+ return (0, 0)
1503
+
1504
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
1505
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
1506
+
1507
+ # ===== Plot Train Loss =====
1508
+ for name, log in logs_dict.items():
1509
+ epochs = sorted(log.keys())
1510
+ train_loss = [log[e]['train_loss'] for e in epochs]
1511
+ axes[0].plot(epochs, train_loss, label=name)
1512
+
1513
+ axes[0].set_xlabel('Epoch')
1514
+ axes[0].set_ylabel('Train Loss')
1515
+ axes[0].set_title('Training Loss')
1516
+ axes[0].grid(True)
1517
+
1518
+ # ===== Plot Validation Score =====
1519
+ for name, log in logs_dict.items():
1520
+ epochs = sorted(log.keys())
1521
+ val_score = [log[e]['val_score'] for e in epochs]
1522
+ axes[1].plot(epochs, val_score, label=name)
1523
+
1524
+ axes[1].set_xlabel('Epoch')
1525
+ axes[1].set_ylabel('Validation Score')
1526
+ axes[1].set_title('Validation Score')
1527
+ axes[1].grid(True)
1528
+
1529
+ # ===== Shared Legend =====
1530
+ handles, labels = axes[0].get_legend_handles_labels()
1531
+ pairs = list(zip(handles, labels))
1532
+ pairs_sorted = sorted(
1533
+ pairs,
1534
+ key=lambda x: parse_label_key(x[1])
1535
+ )
1536
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
1537
+
1538
+ axes[0].legend(
1539
+ handles_sorted,
1540
+ labels_sorted,
1541
+ loc='center left',
1542
+ bbox_to_anchor=(1.01, 0.5),
1543
+ borderaxespad=0.
1544
+ )
1545
+
1546
+ plt.tight_layout(rect=[0, 0, 1, 1])
1547
+
1548
+ if save_path is not None:
1549
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
1550
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
1551
+
1552
+ plt.show()
1553
+
1554
+ # %% [code]
1555
+ if not test_only:
1556
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*lr*.json"], ignore_patterns=[])
1557
+ get_ipython().system('rm -rf .cache .gitattributes')
1558
+
1559
+ # %% [code]
1560
+ if not test_only:
1561
+ experiments = {}
1562
+ for experiment in os.listdir(pretrained_dir):
1563
+ experiment_logs = []
1564
+ try:
1565
+ for seed in SEEDS:
1566
+ for fold_idx in range(nfolds):
1567
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
1568
+ experiment_log = json.load(f)
1569
+ experiment_logs.append(experiment_log)
1570
+ except:
1571
+ pass
1572
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
1573
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
1574
+
1575
+ # %% [code]
1576
+ if not test_only:
1577
+ score = get_avg_best_score(training_logs)
1578
+ state_dict_save_name, score
1579
+
1580
+ # %% [code]
1581
+ if not test_only:
1582
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
1583
+
15_40_negs_19/lasts/15_40_negs_19_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c74fc9b2aa45c5acd1f17b9a14c4fb9b2d902452c1f6693905ee384decc3aa6
3
+ size 544003321
15_40_negs_19/lasts/15_40_negs_19_s26092004_f1_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ebc914aed7a7c2f88c173bb3b44a7eda57ba9c699e0b9213a414c095c853030
3
+ size 544003321
15_40_negs_19/lasts/15_40_negs_19_s26092004_f2_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cb7995466bc630534a6dcef5eaba54ceb11ba9cb10a0ae9e730695239752438
3
+ size 544003321
15_40_negs_19/lasts/15_40_negs_19_s26092004_f3_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ed2add6d252216023688dd38832d498f53b39be06f067e5fdc7b621e7406b19
3
+ size 544003321
15_40_negs_19/lasts/15_40_negs_19_s26092004_f4_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09290c459964594ca570004f16ff4647fb31c078982f3d30d288f0a04e2b097a
3
+ size 544003321
15_40_negs_19/logs/15_40_negs_19_log_plot.jpg ADDED

Git LFS Details

  • SHA256: 9e7d301f915dabe4c5807eabcb4186eb58916690117f039017e23914873d0db2
  • Pointer size: 131 Bytes
  • Size of remote file: 550 kB
15_40_negs_19/logs/15_40_negs_19_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 3.1171696186065674, "total": 3.11716974060671, "contrastive_loss": 2.9843449927492682, "triplet_loss": 0.25982441471571904}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 2.9655439853668213, "total": 2.9655439191837374, "contrastive_loss": 2.8364519100125416, "triplet_loss": 0.25355351170568563}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 2.858736038208008, "total": 2.85873607010347, "contrastive_loss": 2.718403793896321, "triplet_loss": 0.2752926421404682}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 2.6668670177459717, "total": 2.6668670552231397, "contrastive_loss": 2.5340147496864547, "triplet_loss": 0.2612876254180602}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 2.6551263332366943, "total": 2.655126259079745, "contrastive_loss": 2.5232576606265678, "triplet_loss": 0.26066053511705684}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 2.5189194679260254, "total": 2.5189194695207986, "contrastive_loss": 2.3927779692072533, "triplet_loss": 0.25104515050167225}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 2.5252931118011475, "total": 2.5252930504023827, "contrastive_loss": 2.393953967652592, "triplet_loss": 0.26107859531772576}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 2.317253828048706, "total": 2.31725380173495, "contrastive_loss": 2.1961318816628346, "triplet_loss": 0.243938127090301, "val_score": 0.8216369047619048, "best_score": 0.8216369047619048, "new_best_model": true, "recall": 0.8216369047619048, "mAP": 0.30688184996220697, "mRP": 0.29426240079365074}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 2.2397305965423584, "total": 2.2397305797972407, "contrastive_loss": 2.1204050121498748, "triplet_loss": 0.23933946488294314, "val_score": 0.8193452380952381, "best_score": 0.8216369047619048, "new_best_model": false, "recall": 0.8193452380952381, "mAP": 0.30983159446649045, "mRP": 0.3007733134920635}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 2.094189405441284, "total": 2.0941894939511916, "contrastive_loss": 1.9823402226170568, "triplet_loss": 0.225752508361204, "val_score": 0.8246130952380952, "best_score": 0.8246130952380952, "new_best_model": true, "recall": 0.8246130952380952, "mAP": 0.3140297904462711, "mRP": 0.303297619047619}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 2.0670664310455322, "total": 2.0670664095160953, "contrastive_loss": 1.9554602581521738, "triplet_loss": 0.225752508361204, "val_score": 0.810595238095238, "best_score": 0.8246130952380952, "new_best_model": false, "recall": 0.810595238095238, "mAP": 0.30970591458805763, "mRP": 0.2991483134920635}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 1.944291591644287, "total": 1.9442916410822533, "contrastive_loss": 1.8392999451295986, "triplet_loss": 0.20934364548494983, "val_score": 0.817470238095238, "best_score": 0.8246130952380952, "new_best_model": false, "recall": 0.817470238095238, "mAP": 0.3121760706018517, "mRP": 0.29975049603174597}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 1.9428149461746216, "total": 1.9428149577367266, "contrastive_loss": 1.8373329392244984, "triplet_loss": 0.2095526755852843, "val_score": 0.8001785714285714, "best_score": 0.8246130952380952, "new_best_model": false, "recall": 0.8001785714285714, "mAP": 0.3020906442901235, "mRP": 0.2904057539682539}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 1.8371098041534424, "total": 1.837109824088106, "contrastive_loss": 1.7376839628187708, "triplet_loss": 0.19805602006688963, "val_score": 0.8003869047619048, "best_score": 0.8246130952380952, "new_best_model": false, "recall": 0.8003869047619048, "mAP": 0.30453775214947076, "mRP": 0.29076686507936506}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 1.8653604984283447, "total": 1.8653605279316472, "contrastive_loss": 1.7639458187447743, "triplet_loss": 0.20202759197324416, "val_score": 0.7851190476190476, "best_score": 0.8246130952380952, "new_best_model": false, "recall": 0.7851190476190476, "mAP": 0.2953934854497354, "mRP": 0.2837088293650792}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 1.7865545749664307, "total": 1.7865545470579014, "contrastive_loss": 1.6897003275893603, "triplet_loss": 0.19303929765886288, "val_score": 0.7839285714285714, "best_score": 0.8246130952380952, "new_best_model": false, "recall": 0.7839285714285714, "mAP": 0.2973699751196775, "mRP": 0.2851909722222221}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 1.8504287004470825, "total": 1.8504287566628344, "contrastive_loss": 1.74955111602477, "triplet_loss": 0.20129598662207357, "val_score": 0.7634523809523809, "best_score": 0.8246130952380952, "new_best_model": false, "recall": 0.7634523809523809, "mAP": 0.2847859487197654, "mRP": 0.27294791666666635}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 1.7933450937271118, "total": 1.793345167485368, "contrastive_loss": 1.6958566110668374, "triplet_loss": 0.19502508361204013, "val_score": 0.7585714285714286, "best_score": 0.8246130952380952, "new_best_model": false, "recall": 0.7585714285714286, "mAP": 0.2857298859126985, "mRP": 0.27471428571428563}}
15_40_negs_19/logs/15_40_negs_19_s26092004_f1_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 3.1211578845977783, "total": 3.1211578471206103, "contrastive_loss": 2.988146523568144, "triplet_loss": 0.2594063545150502}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 2.9242422580718994, "total": 2.924242306712479, "contrastive_loss": 2.7957255385791178, "triplet_loss": 0.25271739130434784}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 2.8210690021514893, "total": 2.8210690093279682, "contrastive_loss": 2.6815420297475963, "triplet_loss": 0.272366220735786}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 2.6036505699157715, "total": 2.6036505555628136, "contrastive_loss": 2.472719530596781, "triplet_loss": 0.25731605351170567}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 2.5862338542938232, "total": 2.5862338981500836, "contrastive_loss": 2.456502346689486, "triplet_loss": 0.2564799331103679}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 2.4340462684631348, "total": 2.434046308332462, "contrastive_loss": 2.3109618732363084, "triplet_loss": 0.24686454849498327}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 2.4292349815368652, "total": 2.429234941667538, "contrastive_loss": 2.301466313492893, "triplet_loss": 0.25585284280936454}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 2.218703508377075, "total": 2.2187034581417224, "contrastive_loss": 2.1015549471545776, "triplet_loss": 0.23683110367892976, "val_score": 0.8263782051282051, "best_score": 0.8263782051282051, "new_best_model": true, "recall": 0.8263782051282051, "mAP": 0.31191706486992954, "mRP": 0.2984997329059828}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 2.142184019088745, "total": 2.142183948121342, "contrastive_loss": 2.027104712648934, "triplet_loss": 0.23285953177257526, "val_score": 0.8201282051282052, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.8201282051282052, "mAP": 0.30896343207332777, "mRP": 0.29356543803418805}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 1.9996830224990845, "total": 1.9996829846232231, "contrastive_loss": 1.892217629729306, "triplet_loss": 0.2140468227424749, "val_score": 0.8176282051282051, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.8176282051282051, "mAP": 0.3109337243292973, "mRP": 0.29598851495726486}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 1.9773849248886108, "total": 1.9773849436271949, "contrastive_loss": 1.8699489835911371, "triplet_loss": 0.21467391304347827, "val_score": 0.801073717948718, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.801073717948718, "mAP": 0.30063983590676296, "mRP": 0.2850245726495727}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 1.8462313413619995, "total": 1.846231415917642, "contrastive_loss": 1.7463605491612668, "triplet_loss": 0.19857859531772576, "val_score": 0.800448717948718, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.800448717948718, "mAP": 0.3029001602564103, "mRP": 0.2857780448717949}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 1.853581190109253, "total": 1.853581151037312, "contrastive_loss": 1.7527921032347409, "triplet_loss": 0.200564381270903, "val_score": 0.7827403846153846, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.7827403846153846, "mAP": 0.2926743021638855, "mRP": 0.27536137820512824}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 1.7550243139266968, "total": 1.755024275253449, "contrastive_loss": 1.6599059854462792, "triplet_loss": 0.19000836120401338, "val_score": 0.7846153846153846, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.7846153846153846, "mAP": 0.2943544712810338, "mRP": 0.2773683226495727}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 1.7846249341964722, "total": 1.7846248971179974, "contrastive_loss": 1.6875818565139005, "triplet_loss": 0.19335284280936454, "val_score": 0.7644070512820513, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.7644070512820513, "mAP": 0.283070864176333, "mRP": 0.26721314102564103}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 1.7026790380477905, "total": 1.7026790759236519, "contrastive_loss": 1.6106029497739862, "triplet_loss": 0.18394648829431437, "val_score": 0.7657532051282051, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.7657532051282051, "mAP": 0.2848575808065391, "mRP": 0.27017494658119656}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 1.761484980583191, "total": 1.7614850200538252, "contrastive_loss": 1.6654880485407086, "triplet_loss": 0.19147157190635453, "val_score": 0.7438782051282051, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.7438782051282051, "mAP": 0.2740722707654999, "mRP": 0.2586159188034188}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 1.706042766571045, "total": 1.7060427458389946, "contrastive_loss": 1.613318095637803, "triplet_loss": 0.1862458193979933, "val_score": 0.7390865384615385, "best_score": 0.8263782051282051, "new_best_model": false, "recall": 0.7390865384615385, "mAP": 0.2732767434243996, "mRP": 0.25770913461538464}}
15_40_negs_19/logs/15_40_negs_19_s26092004_f2_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 3.1194257736206055, "total": 3.119425795947429, "contrastive_loss": 2.9865240907190636, "triplet_loss": 0.26024247491638797}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 2.9408247470855713, "total": 2.9408246809024874, "contrastive_loss": 2.8120554027748748, "triplet_loss": 0.2537625418060201}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 2.819176435470581, "total": 2.8191765112223037, "contrastive_loss": 2.679746902108591, "triplet_loss": 0.27278428093645485}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 2.6171159744262695, "total": 2.6171160541649248, "contrastive_loss": 2.485616320351693, "triplet_loss": 0.25961538461538464}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 2.586632013320923, "total": 2.5866319535169313, "contrastive_loss": 2.458088240097199, "triplet_loss": 0.2539715719063545}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 2.444108009338379, "total": 2.4441081273515888, "contrastive_loss": 2.3214860488738505, "triplet_loss": 0.24602842809364547}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 2.4610939025878906, "total": 2.461093864313336, "contrastive_loss": 2.331970827236622, "triplet_loss": 0.25794314381270905}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 2.227365732192993, "total": 2.2273657553172033, "contrastive_loss": 2.109634042184887, "triplet_loss": 0.23641304347826086, "val_score": 0.8265865384615385, "best_score": 0.8265865384615385, "new_best_model": true, "recall": 0.8265865384615385, "mAP": 0.30803628663003657, "mRP": 0.2956530448717947}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 2.1473047733306885, "total": 2.147304777317621, "contrastive_loss": 2.0319765020772365, "triplet_loss": 0.2326505016722408, "val_score": 0.831474358974359, "best_score": 0.831474358974359, "new_best_model": true, "recall": 0.831474358974359, "mAP": 0.30395951903998786, "mRP": 0.2885096153846151}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 2.003793478012085, "total": 2.0037935697115383, "contrastive_loss": 1.895869532556438, "triplet_loss": 0.21592809364548496, "val_score": 0.8357371794871795, "best_score": 0.8357371794871795, "new_best_model": true, "recall": 0.8357371794871795, "mAP": 0.31071954290886594, "mRP": 0.296099893162393}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 1.9819941520690918, "total": 1.9819942206443353, "contrastive_loss": 1.874325143054975, "triplet_loss": 0.21551003344481606, "val_score": 0.8188621794871794, "best_score": 0.8357371794871795, "new_best_model": false, "recall": 0.8188621794871794, "mAP": 0.3001329507105549, "mRP": 0.2848458867521369}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 1.8608283996582031, "total": 1.8608284124163879, "contrastive_loss": 1.7599699600883152, "triplet_loss": 0.20066889632107024, "val_score": 0.8188621794871794, "best_score": 0.8357371794871795, "new_best_model": false, "recall": 0.8188621794871794, "mAP": 0.3051606443325193, "mRP": 0.2899618055555556}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 1.865121841430664, "total": 1.8651218988424958, "contrastive_loss": 1.7634303880774456, "triplet_loss": 0.20234113712374582, "val_score": 0.7992948717948718, "best_score": 0.8357371794871795, "new_best_model": false, "recall": 0.7992948717948718, "mAP": 0.2921100790895061, "mRP": 0.27890625}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 1.7723990678787231, "total": 1.7723990858199206, "contrastive_loss": 1.676188491259929, "triplet_loss": 0.19262123745819398, "val_score": 0.7995032051282052, "best_score": 0.8357371794871795, "new_best_model": false, "recall": 0.7995032051282052, "mAP": 0.2982772392441664, "mRP": 0.28680181623931644}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 1.8008406162261963, "total": 1.8008406521085911, "contrastive_loss": 1.7024976035025605, "triplet_loss": 0.1953386287625418, "val_score": 0.7799198717948719, "best_score": 0.8357371794871795, "new_best_model": false, "recall": 0.7799198717948719, "mAP": 0.28581185208502935, "mRP": 0.27247542735042746}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 1.7201329469680786, "total": 1.7201329872360995, "contrastive_loss": 1.6267152103691993, "triplet_loss": 0.18635033444816054, "val_score": 0.7778365384615384, "best_score": 0.8357371794871795, "new_best_model": false, "recall": 0.7778365384615384, "mAP": 0.28914111276455023, "mRP": 0.27675854700854713}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 1.7785074710845947, "total": 1.7785075005878972, "contrastive_loss": 1.68122419784699, "triplet_loss": 0.1936663879598662, "val_score": 0.7549198717948719, "best_score": 0.8357371794871795, "new_best_model": false, "recall": 0.7549198717948719, "mAP": 0.2762329878069462, "mRP": 0.2648870192307691}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 1.7219746112823486, "total": 1.721974656733382, "contrastive_loss": 1.6282511937578386, "triplet_loss": 0.18739548494983277, "val_score": 0.7534615384615385, "best_score": 0.8357371794871795, "new_best_model": false, "recall": 0.7534615384615385, "mAP": 0.2773733988137634, "mRP": 0.26447889957264964}}
15_40_negs_19/logs/15_40_negs_19_s26092004_f3_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 3.1314284801483154, "total": 3.1314284921091136, "contrastive_loss": 2.998345722721572, "triplet_loss": 0.2591973244147157}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 2.977572202682495, "total": 2.9775723358460495, "contrastive_loss": 2.8482010755251883, "triplet_loss": 0.2543896321070234}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 2.877302885055542, "total": 2.877302801329954, "contrastive_loss": 2.7362909731657608, "triplet_loss": 0.2752926421404682}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 2.6792826652526855, "total": 2.6792827083115593, "contrastive_loss": 2.5463807989522365, "triplet_loss": 0.26024247491638797}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 2.6457440853118896, "total": 2.6457441961486206, "contrastive_loss": 2.5150901768917224, "triplet_loss": 0.25794314381270905}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 2.5233898162841797, "total": 2.5233899374869355, "contrastive_loss": 2.3977869346389005, "triplet_loss": 0.25104515050167225}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 2.5388379096984863, "total": 2.5388379559469065, "contrastive_loss": 2.4069662955293687, "triplet_loss": 0.2608695652173913}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 2.3229122161865234, "total": 2.32291231187291, "contrastive_loss": 2.2016325985707565, "triplet_loss": 0.2443561872909699, "val_score": 0.8211858974358974, "best_score": 0.8211858974358974, "new_best_model": true, "recall": 0.8211858974358974, "mAP": 0.3007168588259735, "mRP": 0.2889647435897436}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 2.246257781982422, "total": 2.246257871289716, "contrastive_loss": 2.126957207619147, "triplet_loss": 0.24101170568561872, "val_score": 0.8249198717948718, "best_score": 0.8249198717948718, "new_best_model": true, "recall": 0.8249198717948718, "mAP": 0.3058986231939356, "mRP": 0.2940205662393164}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 2.0984981060028076, "total": 2.098498086068144, "contrastive_loss": 1.9864808149561037, "triplet_loss": 0.225752508361204, "val_score": 0.8312820512820512, "best_score": 0.8312820512820512, "new_best_model": true, "recall": 0.8312820512820512, "mAP": 0.31269115961199306, "mRP": 0.3001479700854703}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 2.0690884590148926, "total": 2.0690885307796822, "contrastive_loss": 1.9577316233146949, "triplet_loss": 0.22408026755852842, "val_score": 0.8200320512820513, "best_score": 0.8312820512820512, "new_best_model": false, "recall": 0.8200320512820513, "mAP": 0.3064226916505563, "mRP": 0.29122168803418785}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 1.93962562084198, "total": 1.9396256156589673, "contrastive_loss": 1.8351281207540762, "triplet_loss": 0.2078804347826087, "val_score": 0.8247115384615384, "best_score": 0.8312820512820512, "new_best_model": false, "recall": 0.8247115384615384, "mAP": 0.3112411301468594, "mRP": 0.2959300213675214}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 1.9388900995254517, "total": 1.938890131819607, "contrastive_loss": 1.8337806523045568, "triplet_loss": 0.20923913043478262, "val_score": 0.8007692307692308, "best_score": 0.8312820512820512, "new_best_model": false, "recall": 0.8007692307692308, "mAP": 0.2988083025963233, "mRP": 0.2838952991452991}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 1.8317793607711792, "total": 1.8317793523986203, "contrastive_loss": 1.7327419523411371, "triplet_loss": 0.19816053511705686, "val_score": 0.8012820512820513, "best_score": 0.8312820512820512, "new_best_model": false, "recall": 0.8012820512820513, "mAP": 0.3030822240452446, "mRP": 0.28900988247863235}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 1.8575176000595093, "total": 1.857517612419001, "contrastive_loss": 1.7566124141016932, "triplet_loss": 0.20129598662207357, "val_score": 0.7846153846153846, "best_score": 0.8312820512820512, "new_best_model": false, "recall": 0.7846153846153846, "mAP": 0.29044165500780084, "mRP": 0.2746487713675213}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 1.77456533908844, "total": 1.774565323539402, "contrastive_loss": 1.6785964200329222, "triplet_loss": 0.19063545150501673, "val_score": 0.7802403846153846, "best_score": 0.8312820512820512, "new_best_model": false, "recall": 0.7802403846153846, "mAP": 0.29246244700515533, "mRP": 0.2771730769230769}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 1.8297371864318848, "total": 1.8297372263012124, "contrastive_loss": 1.730079676395276, "triplet_loss": 0.1992056856187291, "val_score": 0.7600320512820513, "best_score": 0.8312820512820512, "new_best_model": false, "recall": 0.7600320512820513, "mAP": 0.2791634140550807, "mRP": 0.2631722756410257}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 1.774262547492981, "total": 1.7742625973296404, "contrastive_loss": 1.6777860201322115, "triplet_loss": 0.19188963210702342, "val_score": 0.7587820512820513, "best_score": 0.8312820512820512, "new_best_model": false, "recall": 0.7587820512820513, "mAP": 0.2802019300722428, "mRP": 0.26546741452991457}}
15_40_negs_19/logs/15_40_negs_19_s26092004_f4_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 3.1516640186309814, "total": 3.151663993911998, "contrastive_loss": 3.018289521386392, "triplet_loss": 0.25982441471571904}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 3.0036609172821045, "total": 3.003660884589256, "contrastive_loss": 2.873781950577446, "triplet_loss": 0.25418060200668896}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 2.8833868503570557, "total": 2.883386924514005, "contrastive_loss": 2.742359582397053, "triplet_loss": 0.2752926421404682}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 2.6958534717559814, "total": 2.695853447036998, "contrastive_loss": 2.5620668341084865, "triplet_loss": 0.262123745819398}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 2.6569666862487793, "total": 2.6569667037912836, "contrastive_loss": 2.5254155289767977, "triplet_loss": 0.2585702341137124}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 2.5223934650421143, "total": 2.5223933701531145, "contrastive_loss": 2.3964721271425584, "triplet_loss": 0.25062709030100333}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 2.538316249847412, "total": 2.5383161972198995, "contrastive_loss": 2.40635737288357, "triplet_loss": 0.262123745819398}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 2.335059642791748, "total": 2.335059736883361, "contrastive_loss": 2.213106391421927, "triplet_loss": 0.24561036789297658, "val_score": 0.8190865384615384, "best_score": 0.8190865384615384, "new_best_model": true, "recall": 0.8190865384615384, "mAP": 0.2993911898190543, "mRP": 0.2877724358974359}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 2.2513844966888428, "total": 2.2513846202837584, "contrastive_loss": 2.1315795490175584, "triplet_loss": 0.24247491638795987, "val_score": 0.8138942307692307, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.8138942307692307, "mAP": 0.29852555263024005, "mRP": 0.28592868589743586}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 2.1072781085968018, "total": 2.10727816680602, "contrastive_loss": 1.9944927509014423, "triplet_loss": 0.22846989966555184, "val_score": 0.8170192307692308, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.8170192307692308, "mAP": 0.3041635255562339, "mRP": 0.29063060897435883}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 2.081670045852661, "total": 2.081670142336434, "contrastive_loss": 1.9690559923050794, "triplet_loss": 0.2270066889632107, "val_score": 0.805448717948718, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.805448717948718, "mAP": 0.29550365473646706, "mRP": 0.28156677350427356}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 1.950080156326294, "total": 1.9500801826400502, "contrastive_loss": 1.8448078066210285, "triplet_loss": 0.20986622073578595, "val_score": 0.8101442307692307, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.8101442307692307, "mAP": 0.3001658659145636, "mRP": 0.28571260683760663}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 1.9600468873977661, "total": 1.9600468766330477, "contrastive_loss": 1.8532151442307692, "triplet_loss": 0.21331521739130435, "val_score": 0.7872115384615385, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.7872115384615385, "mAP": 0.2894139951965473, "mRP": 0.276309561965812}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 1.8427588939666748, "total": 1.8427589442020276, "contrastive_loss": 1.7428629693378972, "triplet_loss": 0.19857859531772576, "val_score": 0.7877403846153846, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.7877403846153846, "mAP": 0.29399225787715383, "mRP": 0.2805809294871795}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 1.8841179609298706, "total": 1.8841179174723035, "contrastive_loss": 1.781144464295046, "triplet_loss": 0.20411789297658864, "val_score": 0.7705448717948719, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.7705448717948719, "mAP": 0.28313406190645773, "mRP": 0.270417735042735}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 1.7961944341659546, "total": 1.796194427388169, "contrastive_loss": 1.6986813752547554, "triplet_loss": 0.19471153846153846, "val_score": 0.7710737179487179, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.7710737179487179, "mAP": 0.2856894191553045, "mRP": 0.27385229700854696}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 1.865045189857483, "total": 1.8650451456025292, "contrastive_loss": 1.7631364394988502, "triplet_loss": 0.20275919732441472, "val_score": 0.7509615384615385, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.7509615384615385, "mAP": 0.2750604343669449, "mRP": 0.264383814102564}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 1.8052574396133423, "total": 1.8052574336329432, "contrastive_loss": 1.7068630460911371, "triplet_loss": 0.19607023411371238, "val_score": 0.7525160256410256, "best_score": 0.8190865384615384, "new_best_model": false, "recall": 0.7525160256410256, "mAP": 0.2762585199811762, "mRP": 0.2670267094017094}}
15_40_negs_19/r1s/15_40_negs_19_s26092004_f0_r1_vs0.82461_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea12b05970afacb0233a76c5a346df4e15bbb203d0f1c262183946f2889569d4
3
+ size 544005009
15_40_negs_19/r1s/15_40_negs_19_s26092004_f1_r1_vs0.82638_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6435761f1a9fe3b04869a4aca7656e9a2a0e4de1f7b711548134a7b5d8226d72
3
+ size 544005009
15_40_negs_19/r1s/15_40_negs_19_s26092004_f2_r1_vs0.83574_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c74fbcd7e29f769c3217ff0be5fe7667b5e07fb064b4c187526582f302967fe
3
+ size 544005009
15_40_negs_19/r1s/15_40_negs_19_s26092004_f3_r1_vs0.83128_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efde98e85c5887aa20c2636ae1626dcf07d475c69f232992b2785f8396ccc3c5
3
+ size 544005009
15_40_negs_19/r1s/15_40_negs_19_s26092004_f4_r1_vs0.81909_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f58a7ab07e38e4c577f1169bd01502871421051d2f3746662dccf6340e51de1
3
+ size 544005009
15_40_negs_19/results/15_40_negs_19_test.json ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Best model": {
3
+ "0": {
4
+ "5": {
5
+ "recall": 0.7311746987951807,
6
+ "mAP": 0.29634768239625325,
7
+ "mRP": 0.2996046686747005
8
+ },
9
+ "10": {
10
+ "recall": 0.8226656626506024,
11
+ "mAP": 0.3103050869151218,
12
+ "mRP": 0.2996046686747005
13
+ },
14
+ "15": {
15
+ "recall": 0.8685993975903614,
16
+ "mAP": 0.31666905573579285,
17
+ "mRP": 0.2996046686747005
18
+ }
19
+ },
20
+ "0.25": {
21
+ "5": {
22
+ "recall": 0.7801204819277109,
23
+ "mAP": 0.31824976991298703,
24
+ "mRP": 0.3236383032128533
25
+ },
26
+ "10": {
27
+ "recall": 0.8576807228915663,
28
+ "mAP": 0.33356394078297436,
29
+ "mRP": 0.3236383032128533
30
+ },
31
+ "15": {
32
+ "recall": 0.8919427710843374,
33
+ "mAP": 0.34020153438498807,
34
+ "mRP": 0.3236383032128533
35
+ }
36
+ },
37
+ "0.5": {
38
+ "5": {
39
+ "recall": 0.7575301204819277,
40
+ "mAP": 0.30015227576974685,
41
+ "mRP": 0.3053903112449816
42
+ },
43
+ "10": {
44
+ "recall": 0.8320783132530121,
45
+ "mAP": 0.31415802346290034,
46
+ "mRP": 0.3053903112449816
47
+ },
48
+ "15": {
49
+ "recall": 0.8716114457831325,
50
+ "mAP": 0.3204404409491958,
51
+ "mRP": 0.3053903112449816
52
+ }
53
+ },
54
+ "0.75": {
55
+ "5": {
56
+ "recall": 0.7375753012048193,
57
+ "mAP": 0.2889613662985287,
58
+ "mRP": 0.2943398594377524
59
+ },
60
+ "10": {
61
+ "recall": 0.8189006024096386,
62
+ "mAP": 0.3031146560448151,
63
+ "mRP": 0.2943398594377524
64
+ },
65
+ "15": {
66
+ "recall": 0.8561746987951807,
67
+ "mAP": 0.30878041452495586,
68
+ "mRP": 0.2943398594377524
69
+ }
70
+ },
71
+ "1": {
72
+ "5": {
73
+ "recall": 0.7289156626506024,
74
+ "mAP": 0.2821293507362792,
75
+ "mRP": 0.28512173694779225
76
+ },
77
+ "10": {
78
+ "recall": 0.8109939759036144,
79
+ "mAP": 0.2956318652586542,
80
+ "mRP": 0.28512173694779225
81
+ },
82
+ "15": {
83
+ "recall": 0.848644578313253,
84
+ "mAP": 0.30152445524480603,
85
+ "mRP": 0.28512173694779225
86
+ }
87
+ }
88
+ },
89
+ "Last model": {
90
+ "0": {
91
+ "5": {
92
+ "recall": 0.6908885542168675,
93
+ "mAP": 0.2774120440093722,
94
+ "mRP": 0.27663780120482045
95
+ },
96
+ "10": {
97
+ "recall": 0.7906626506024096,
98
+ "mAP": 0.29171843375486195,
99
+ "mRP": 0.27663780120482045
100
+ },
101
+ "15": {
102
+ "recall": 0.8369728915662651,
103
+ "mAP": 0.2969115673657291,
104
+ "mRP": 0.27663780120482045
105
+ }
106
+ },
107
+ "0.25": {
108
+ "5": {
109
+ "recall": 0.7740963855421686,
110
+ "mAP": 0.31292189591700287,
111
+ "mRP": 0.31780245983935934
112
+ },
113
+ "10": {
114
+ "recall": 0.8509036144578314,
115
+ "mAP": 0.32840924461735954,
116
+ "mRP": 0.31780245983935934
117
+ },
118
+ "15": {
119
+ "recall": 0.8829066265060241,
120
+ "mAP": 0.33434082861652636,
121
+ "mRP": 0.31780245983935934
122
+ }
123
+ },
124
+ "0.5": {
125
+ "5": {
126
+ "recall": 0.7466114457831325,
127
+ "mAP": 0.2967767946787162,
128
+ "mRP": 0.3007718373493991
129
+ },
130
+ "10": {
131
+ "recall": 0.8279367469879518,
132
+ "mAP": 0.3109618359350114,
133
+ "mRP": 0.3007718373493991
134
+ },
135
+ "15": {
136
+ "recall": 0.8610692771084337,
137
+ "mAP": 0.31643619841991616,
138
+ "mRP": 0.3007718373493991
139
+ }
140
+ },
141
+ "0.75": {
142
+ "5": {
143
+ "recall": 0.733433734939759,
144
+ "mAP": 0.28489949380856855,
145
+ "mRP": 0.2896209839357442
146
+ },
147
+ "10": {
148
+ "recall": 0.8162650602409639,
149
+ "mAP": 0.29845161897590466,
150
+ "mRP": 0.2896209839357442
151
+ },
152
+ "15": {
153
+ "recall": 0.8497740963855421,
154
+ "mAP": 0.3044438175612884,
155
+ "mRP": 0.2896209839357442
156
+ }
157
+ },
158
+ "1": {
159
+ "5": {
160
+ "recall": 0.7213855421686747,
161
+ "mAP": 0.2767009705488627,
162
+ "mRP": 0.28113077309237045
163
+ },
164
+ "10": {
165
+ "recall": 0.8064759036144579,
166
+ "mAP": 0.29008658144482796,
167
+ "mRP": 0.28113077309237045
168
+ },
169
+ "15": {
170
+ "recall": 0.8460090361445783,
171
+ "mAP": 0.2962527849284259,
172
+ "mRP": 0.28113077309237045
173
+ }
174
+ }
175
+ }
176
+ }
15_40_negs_19/results/15_40_negs_19_test_df.xlsx ADDED
Binary file (6.26 kB). View file
 
15_40_negs_19/results/15_40_negs_19_test_df_best.xlsx ADDED
Binary file (5.57 kB). View file