# kltn-experiments / 1_pointer_base_entities_4 / 1_pointer_base_entities_4.py
# SS3M's picture
# Upload 1_pointer_base_entities_4's state dict
# fabb055 verified
# %% [code]
# Install the extra third-party packages this notebook needs (IPython/Kaggle only: relies on get_ipython()).
get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch] pytorch-crf')
# %% [code]
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader
import torch.nn.functional as F
import albumentations as albu
from transformers import AutoTokenizer, AutoModel
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from positional_encodings.torch_encodings import PositionalEncoding1D
from torchcrf import CRF
from sklearn.metrics import f1_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
from sklearn.metrics import precision_recall_fscore_support
from timm.utils import ModelEmaV3
import timm
import os
import gc
import json
from pathlib import Path
import pickle
from tqdm.auto import tqdm
import copy
import numpy as np
import pandas as pd
import polars as pl
from PIL import Image
import time
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
from multiprocessing import Manager as MemoryManager
from functools import lru_cache
import shutil
import glob
import cv2
import random
import re
import joblib
import math
from huggingface_hub import HfApi, snapshot_download
import evaluate
from underthesea import word_tokenize as vi_tokenize_tool
import spacy
en_tokenize_tool = spacy.load("en_core_web_sm")
from collections import defaultdict, Counter
# %% [code]
# Global config
SEEDS = [26092004]
topk = 1
nfolds = 5
only_fold_idx = 0
test_only = 0
debug_only = 0
# Directory config
dataset = 'kltn/only_entities' # conll003, ontonotes, phoner, vietbio, vietmed, vimed, kltn/only_entities, kltn/raw
root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Directory holding the train, val and test files
train_dir = f'{root_dir}'
# val_dir = f'{root_dir}/val'
test_dir = f'{root_dir}'
# Checkpoint config
# Training config
epochs = 18 if not debug_only else 2
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
# Add any extra global variables here
repo_name = 'SS3M/kltn-experiments'
state_dict_save_name = "1_pointer_base_entities_4"
checkpoints_dir = state_dict_save_name
pretrained_dir = "/kaggle/working"
os.makedirs(f'{checkpoints_dir}', exist_ok=True)
# English datasets use uncased BERT + spaCy word splitting; all others use PhoBERT + underthesea.
ENGLISH_DATASETS = ("conll003", "ontonotes")
backbone_model_name = "bert-base-uncased" if dataset in ENGLISH_DATASETS else "vinai/phobert-base"


def word_tokenize(text):
    """Split *text* into words with the word tokenizer matching the current ``dataset``.

    ``dataset`` is read at call time, exactly like the original lambda did.
    """
    if dataset in ENGLISH_DATASETS:
        return [token.text for token in en_tokenize_tool(text)]
    return vi_tokenize_tool(text)


# Sequence-length limit per dataset.
max_len_dict = {
    'kltn/raw': 256,
    'kltn/only_entities': 68,
    'conll003': 46,
    'ontonotes': 61,
    'phoner': 68,
    'vietbio': 125,
    'vietmed': 36,
    'vimed': 100,
}
zero_entities_rate_dict = {
    'kltn/raw': 1000,
    'kltn/only_entities': 0.2,
    'conll003': 1000, # means: keep all samples that have no entities
    'ontonotes': 1000,
    'phoner': 1000,
    'vietbio': 1000,
    'vietmed': 1000,
    'vimed': 1000,
}
max_len = max_len_dict[dataset]
max_n_parts = 1
max_span_len = 10
zero_entities_rate = zero_entities_rate_dict[dataset]
# Trainer: keyword arguments for Trainer(...)
trainer_params = dict(
    training_time="00:11:30:00",
    eval_mode="max",
    topk=topk,
    save_name=state_dict_save_name,
    save_best=True,
    save_last=True,
    device=device,
    logging=True,
    logging_file=True,
    checkpoints_dir=checkpoints_dir,
    early_stopping=30,
    eval_from_ratio=0.4,
    eval_every=1,
    schedule_in_step=False,
    use_ema=True,
    ema_from_ratio=0.3,
    ema_decay=0.9995,
    max_grad_norm=200.0,
    return_best=True,
    return_last=True,
)
# Memory: train and val use the same length limits (separate dict objects).
train_memory_params = dict(max_len=max_len, max_n_parts=max_n_parts)
val_memory_params = dict(max_len=max_len, max_n_parts=max_n_parts)
# Data Loader
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed numpy and ``random`` from torch's per-worker seed."""
    derived_seed = torch.initial_seed() % 2**32
    for seeder in (np.random.seed, random.seed):
        seeder(derived_seed)
train_loader_params = dict(
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
    drop_last=False,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)
val_loader_params = dict(
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=1,
    drop_last=False,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)
# Model
model_params = dict(backbone_model_name=backbone_model_name)
# Loss function
loss_func_params = dict(lambda_ce=1.0)
eval_func_params = dict()
# Optimizer / LR scheduler
optim_params = dict(name='AdamW', lr=1e-4, weight_decay=1e-4)
scheduler_params = dict(
    name='CosineAnnealingLR',
    T_max=20,  # number of epochs to complete one LR-decay cycle
    eta_min=1e-6,  # smallest learning rate reached in the cycle
)
# %% [code]
def set_seed(seed=42):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) and force deterministic cuDNN.

    NOTE: assigning PYTHONHASHSEED here only affects child processes; the current
    interpreter's hash randomization was fixed at start-up.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # multi-GPU
    )
    for seeder in seeders:
        seeder(seed)
    torch.use_deterministic_algorithms(False)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
# %% [code]
class CustomLoss(nn.Module):
    """Cross-entropy over start- and end-pointer logits.

    ``start_logits``/``end_logits`` are (B, L, C) class scores per token; the
    matching label tensors hold one class id per token, with ``-100`` marking
    positions that are ignored.
    """

    def __init__(self, lambda_ce=1.0):
        super().__init__()
        self.lambda_ce = lambda_ce  # weight applied to the summed CE term in "total"
        self.ce = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(
        self,
        start_logits, start_labels,
        end_logits, end_labels,
    ):
        """Return {"total", "start_loss", "end_loss"}; every value is a scalar tensor."""
        num_classes = start_logits.shape[-1]
        # reshape (not view): also accepts non-contiguous logits/labels.
        start_loss = self.ce(start_logits.reshape(-1, num_classes), start_labels.reshape(-1))
        end_loss = self.ce(end_logits.reshape(-1, num_classes), end_labels.reshape(-1))
        return {
            "total": self.lambda_ce * (start_loss + end_loss),
            "start_loss": start_loss,
            "end_loss": end_loss,
        }
# %% [code]
## Evaluation functions go here
# Put every eval_fn (and any weights it needs) here
class CustomEvalFn(nn.Module):
    """Micro precision / recall / F1 between predicted and gold items compared as sets."""

    def __init__(self):
        super().__init__()

    def compute_f1(self, tp, fp, fn):
        """Return (precision, recall, f1); the 1e-8 terms guard against division by zero."""
        eps = 1e-8
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        return precision, recall, 2 * precision * recall / (precision + recall + eps)

    def forward(self, pred, gold):
        """Score *pred* against *gold*; items must be hashable and duplicates collapse."""
        predicted, expected = set(pred), set(gold)
        n_hits = len(predicted & expected)
        precision, recall, f1 = self.compute_f1(
            n_hits,
            len(predicted) - n_hits,
            len(expected) - n_hits,
        )
        return {"precision": precision, "recall": recall, "f1": f1}
class SpanErrorAnalyzer:
    """Break span predictions down into exact / partial / spurious / missed buckets.

    ``preds`` and ``golds`` are lists of ``(batch_index, ids)`` pairs where ``ids``
    is a sequence; elements equal to ``pad_token_id`` are stripped and spans that
    become empty are ignored.
    """

    def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    # ===== helpers =====
    def _to_set(self, data):
        """
        data: list of (b, tuple(ids))
        -> dict[b] = set(tuple(ids)) with padding ids removed
        """
        res = defaultdict(set)
        for b, ids in data:
            ids = tuple([i for i in ids if i != self.pad_token_id])
            if len(ids) > 0:
                res[b].add(ids)
        return res

    def _iou(self, a, b):
        """Jaccard overlap of the element sets of spans *a* and *b* (0.0 if both empty)."""
        set_a, set_b = set(a), set(b)
        inter = len(set_a & set_b)
        union = len(set_a | set_b)
        if union == 0:
            return 0.0
        return inter / union

    def _boundary_error(self, pred, gold):
        """Measure boundary drift via the lengths of the common prefix and common suffix."""
        # left match: length of the shared prefix
        left = 0
        for i in range(min(len(pred), len(gold))):
            if pred[i] == gold[i]:
                left += 1
            else:
                break
        # right match: length of the shared suffix
        right = 0
        for i in range(1, min(len(pred), len(gold)) + 1):
            if pred[-i] == gold[-i]:
                right += 1
            else:
                break
        return {
            "left_match": left,
            "right_match": right,
            "pred_len": len(pred),
            "gold_len": len(gold),
        }

    @staticmethod
    def _ratio(count, total):
        """Return count / total, or 0.0 when *total* is zero (empty preds or golds)."""
        return count / total if total else 0.0

    # ===== main =====
    def analyze(self, preds, golds):
        """Classify every prediction and every unmatched gold span.

        Returns ``{"summary": {bucket: (count, ratio)}, "details": [...]}``.
        Prediction-bucket ratios are relative to ``len(preds)`` and miss ratios to
        ``len(golds)`` (raw list lengths, before de-duplication); they are 0.0 when
        the corresponding list is empty instead of raising ZeroDivisionError.
        """
        pred_map = self._to_set(preds)
        gold_map = self._to_set(golds)
        all_batches = set(pred_map.keys()) | set(gold_map.keys())
        stats = Counter()
        detailed_errors = []
        for b in all_batches:
            pset = pred_map.get(b, set())
            gset = gold_map.get(b, set())
            matched_gold = set()
            # ===== check predictions =====
            for p in pset:
                if p in gset:
                    stats["exact_match"] += 1
                    matched_gold.add(p)
                else:
                    # find the closest gold span by IoU
                    best_iou = 0
                    best_g = None
                    for g in gset:
                        iou = self._iou(p, g)
                        if iou > best_iou:
                            best_iou = iou
                            best_g = g
                    if best_iou > 0:
                        stats["partial_match"] += 1
                        boundary = self._boundary_error(p, best_g)
                        detailed_errors.append({
                            "type": "boundary_error",
                            "batch": b,
                            "pred": p,
                            "gold": best_g,
                            "iou": best_iou,
                            **boundary
                        })
                    else:
                        # A prediction in a sample that has no gold span at all vs. one
                        # that overlaps nothing in a sample that does have gold spans.
                        if b not in gold_map:
                            stats["no_event_sample"] += 1
                            err_type = "no_event_sample"
                        else:
                            stats["completely_wrong"] += 1
                            err_type = "completely_wrong"
                        detailed_errors.append({
                            "type": err_type,
                            "batch": b,
                            "pred": p
                        })
            # ===== check missing =====
            for g in gset:
                if g not in matched_gold:
                    # check if any pred overlaps
                    overlap = any(self._iou(p, g) > 0 for p in pset)
                    if overlap:
                        stats["miss_with_overlap"] += 1
                    else:
                        stats["miss"] += 1
                        detailed_errors.append({
                            "type": "miss",
                            "batch": b,
                            "gold": g
                        })
        n_preds, n_golds = len(preds), len(golds)
        return {
            "summary": {
                "exact_match": (stats["exact_match"], self._ratio(stats["exact_match"], n_preds)),
                "partial_match": (stats["partial_match"], self._ratio(stats["partial_match"], n_preds)),
                "no_event_sample": (stats["no_event_sample"], self._ratio(stats["no_event_sample"], n_preds)),
                "completely_wrong": (stats["completely_wrong"], self._ratio(stats["completely_wrong"], n_preds)),
                "miss": (stats["miss"], self._ratio(stats["miss"], n_golds)),
                "miss_with_overlap": (stats["miss_with_overlap"], self._ratio(stats["miss_with_overlap"], n_golds)),
            },
            "details": detailed_errors
        }
# %% [code]
## Model architecture goes here
class MLP(nn.Module):
    """Two-layer feed-forward head: Linear -> ReLU -> Linear over the last dimension."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        layers = [
            nn.Linear(in_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, out_size),
        ]
        # Stored as ``mlp`` so checkpoint keys remain mlp.0.* / mlp.2.*
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Project *x* (..., in_size) to (..., out_size)."""
        projected = self.mlp(x)
        return projected
class IEModel(nn.Module):
    """Pretrained encoder followed by per-token start/end pointer classifiers."""

    def __init__(self, backbone_model_name, num_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone_model_name)
        width = self.encoder.config.hidden_size
        self.start_classifier = MLP(width, width, num_labels)
        self.end_classifier = MLP(width, width, num_labels)

    def encode(self, input_ids, attention_mask):
        """Encode (B, n_parts, L) chunks and concatenate them into (B, n_parts * L, H)."""
        batch, parts, seq_len = input_ids.shape
        encoded = self.encoder(
            input_ids=input_ids.view(-1, seq_len),
            attention_mask=attention_mask.view(-1, seq_len),
        )
        token_states = encoded.last_hidden_state  # (B * n_parts, L, H)
        return token_states.view(batch, parts, seq_len, -1).reshape(batch, parts * seq_len, -1)

    def get_logits(self, hidden_states):
        """Return (start_logits, end_logits), each (B, N, num_labels)."""
        return self.start_classifier(hidden_states), self.end_classifier(hidden_states)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode and score every token; ``labels`` is accepted for API compatibility but unused."""
        return self.get_logits(self.encode(input_ids, attention_mask))
def test():
    """Smoke-test IEModel: build it, report its size and print output shapes on random input."""
    model = nn.DataParallel(IEModel(backbone_model_name, 7)).to(device)
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")
    vocab_size = model.module.encoder.config.vocab_size
    bz = 32
    # Random (batch, n_parts, seq_len) token ids with an all-ones attention mask.
    i = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
    a = torch.ones(bz, 5, 10).to(device)
    with torch.no_grad():
        r = model(i, a)
    if isinstance(r, tuple):
        print([x.shape if isinstance(x, torch.Tensor) else len(x) for x in r])
    else:
        print(r.shape)
test()
# %% [code]
def configure_optimizers(network, optim_params, scheduler_params):
    """Build ``(optimizer, scheduler)`` from config dicts.

    Each dict carries a ``name`` that is looked up first in this module's
    globals, then in ``torch.optim`` / ``torch.optim.lr_scheduler``; the other
    keys become constructor kwargs. The input dicts are copied, not mutated.
    ``scheduler`` is None when no kwargs remain or the class cannot be resolved.
    Raises ValueError for a missing key or an unknown optimizer.
    """
    try:
        optim_kwargs = copy.copy(optim_params)
        sched_kwargs = copy.copy(scheduler_params)
        optim_name = optim_kwargs.pop('name')
        sched_name = sched_kwargs.pop('name')
        optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
        scheduler_cls = globals().get(sched_name) or getattr(optim.lr_scheduler, sched_name, None)
        if optimizer_cls is None:
            raise ValueError(f"Optimizer '{optim_name}' is not available!")
        optimizer = optimizer_cls(network.parameters(), **optim_kwargs)
        build_scheduler = bool(sched_kwargs and scheduler_cls)
        scheduler = scheduler_cls(optimizer, **sched_kwargs) if build_scheduler else None
        return optimizer, scheduler
    except KeyError as e:
        raise ValueError(f"Missing {e} in config!!")
def freeze(self, model):
    """Switch *model* to eval mode and disable gradients on all of its parameters.

    ``self`` is unused; it is kept so existing ``freeze(obj, model)`` calls still work.
    """
    model.eval()
    for weight in model.parameters():
        weight.requires_grad_(False)
def unfreeze(self, model):
    """Switch *model* to train mode and re-enable gradients on all of its parameters.

    ``self`` is unused; it is kept so existing ``unfreeze(obj, model)`` calls still work.
    """
    model.train()
    for weight in model.parameters():
        weight.requires_grad_(True)
def reduce_batch_size(loader, ratio=0.5):
    """Rebuild *loader* with ``batch_size`` scaled by *ratio* (floored, at least 1).

    Every other DataLoader setting is copied over. Shuffling is preserved by
    checking whether the original loader used a ``RandomSampler``; otherwise its
    sampler is reused as-is.
    """
    # RandomSampler is not imported at file level; without this import the
    # CUDA-OOM recovery path in Trainer.fit crashes with a NameError.
    from torch.utils.data import RandomSampler

    new_bs = max(1, int(loader.batch_size * ratio))
    shuffle = isinstance(loader.sampler, RandomSampler)
    return DataLoader(
        dataset=loader.dataset,
        batch_size=new_bs,
        shuffle=shuffle,
        sampler=None if shuffle else loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
        generator=loader.generator,
        # prefetch_factor must be None when loading in the main process.
        prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
        persistent_workers=loader.persistent_workers,
        pin_memory_device=loader.pin_memory_device
    )
def list_to_tuple(x):
    """Recursively turn nested lists/tuples into nested tuples (a hashable form)."""
    if not isinstance(x, (list, tuple)):
        return x
    return tuple(map(list_to_tuple, x))
def fmt(x):
    """Round floats to 5 decimals, recursing into dicts and lists (for compact logs)."""
    if isinstance(x, dict):
        return {key: fmt(val) for key, val in x.items()}
    if isinstance(x, list):
        return list(map(fmt, x))
    return round(x, 5) if isinstance(x, float) else x
class ModelEmaV3Proxy(ModelEmaV3):
    """ModelEmaV3 that falls back to the wrapped EMA ``module`` for unknown attributes."""

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            pass
        # Not a parameter/buffer/submodule of the EMA wrapper: ask the EMA copy.
        return getattr(self.module, name)
class DataParallelProxy(nn.DataParallel):
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
attr = getattr(self.module, name)
if callable(attr):
def wrapper(*args, **kwargs):
return self._parallel_apply_method(name, *args, **kwargs)
return wrapper
return attr
def _parallel_apply_method(self, method_name, *inputs, **kwargs):
if not self.device_ids:
return getattr(self.module, method_name)(*inputs, **kwargs)
inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids)
replicas = self.replicate(self.module, self.device_ids)
outputs = self.parallel_apply(
[getattr(replica, method_name) for replica in replicas],
inputs_scattered,
kwargs_scattered
)
return self.gather(outputs, self.output_device)
def fix_bio(tags):
    """Repair an invalid BIO tag sequence by promoting stray ``I-X`` tags to ``B-X``.

    An ``I-X`` becomes ``B-X`` when it is first, follows ``O``, or follows a tag of
    a different entity type; every other tag is kept as-is.
    """
    fixed = []
    for tag in tags:
        if tag.startswith('I-'):
            entity_type = tag[2:]
            prev = fixed[-1] if fixed else None
            if prev is None or prev == 'O' or prev[2:] != entity_type:
                tag = 'B-' + entity_type
        fixed.append(tag)
    return fixed
def extract_entities(input_ids, start_logits, end_logits, id2label):
    """Decode start/end pointer logits into entity spans.

    Every position whose argmax start label is not 0 (label id 0 is "O") is paired
    with the nearest position at or after it whose argmax end label is the same
    and that no earlier start has claimed. Starts without such an end are dropped.

    Args:
        input_ids: Tensor (B, L)
        start_logits: Tensor (B, L, C)
        end_logits: Tensor (B, L, C)
        id2label: dict {label_id: label_name}
    Returns:
        List[(bidx, (token_id_list, label_name))] where token_id_list is
        input_ids[bidx, s:e+1] as a Python list.
    """
    # Move everything to host Python lists once: a per-position .item() on a CUDA
    # tensor forces a device sync each time and dominates decode time.
    start_labels = start_logits.argmax(dim=-1).tolist()  # B x L
    end_labels = end_logits.argmax(dim=-1).tolist()  # B x L
    token_ids = input_ids.tolist()
    results = []
    for bidx, (row_starts, row_ends) in enumerate(zip(start_labels, end_labels)):
        used_end = set()
        seq_len = len(row_starts)
        for s, s_label in enumerate(row_starts):
            if s_label == 0:  # "O"
                continue
            # nearest unused end with the same label
            nearest_e = next(
                (e for e in range(s, seq_len) if e not in used_end and row_ends[e] == s_label),
                None,
            )
            if nearest_e is None:
                continue
            used_end.add(nearest_e)
            results.append((bidx, (token_ids[bidx][s:nearest_e + 1], id2label[s_label])))
    return results
class Trainer:
    """Training loop with AMP, gradient clipping, optional EMA weights, top-k
    best-checkpoint tracking, early stopping and a wall-clock budget.

    ``_cal_loss`` / ``_cal_val_score`` are specialised for the start/end pointer
    model: batches are dicts holding ``input_ids``, ``attention_mask``,
    ``start_labels``, ``end_labels`` and ``gold_entities``.
    """
    def __init__(
        self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
        logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
        schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
    ):
        """Store the configuration; ``training_time`` is formatted "DD:HH:MM:SS"."""
        self.ema_net = None
        self.training_time = self._time_str_to_seconds(training_time)
        self.mode = eval_mode
        self.topk = topk
        self.device = device
        # NOTE(review): compares against the module-level ``epochs`` global, not the
        # ``epochs`` later passed to fit(); values >= epochs fall back to every epoch.
        self.logging = logging if logging < epochs else 1
        self.logging_file = logging_file
        self.checkpoints_dir = checkpoints_dir
        self.early_stopping = early_stopping
        self.eval_from_ratio = eval_from_ratio
        self.eval_every = eval_every
        self.save_name = save_name
        self.save_best = save_best
        self.save_last = save_last
        self.return_best = return_best
        self.return_last = return_last
        self.max_grad_norm = max_grad_norm
        self.schedule_in_step = schedule_in_step
        self.use_ema = use_ema
        self.ema_from_ratio = ema_from_ratio
        self.ema_decay = ema_decay
        # [score, epoch, cpu state_dict] entries, best first, at most topk of them.
        self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
        self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)

    def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
        """Train for ``epochs`` epochs starting at ``start_epoch``.

        Evaluation (on the EMA copy once it exists, when ``use_ema``) starts
        ``int(epochs * eval_from_ratio)`` epochs in. On a CUDA OOM the epoch is
        skipped and both loaders are rebuilt with half the batch size.

        Returns:
            (training_log, best_state_dict, last_state_dict); state dicts are CPU
            copies with "module." stripped from keys, or None when not requested.
        """
        if eval_fn is None:
            # NOTE(review): loss_fn returns a dict here, so negating it fails, and
            # _cal_val_score expects a dict with "f1" - always pass eval_fn explicitly.
            if self.mode == "max":
                eval_fn = lambda *x: -loss_fn(*x)
            else:
                eval_fn = lambda *x: loss_fn(*x)
        if torch.cuda.device_count() > 1:
            network = DataParallelProxy(network)
        network = network.to(self.device)
        if not start_training_time:
            start_training_time = time.time()
        # Offsets (relative to start_epoch) where EMA tracking / evaluation begin.
        start_ema = int(epochs * self.ema_from_ratio)
        start_eval = int(epochs * self.eval_from_ratio)
        if val_loader is None:
            print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
        else:
            model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
            start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
            print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)
        training_log = {}
        for epoch in range(start_epoch, epochs+start_epoch):
            if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
                self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
            try:
                train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn)
                logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
                logging_dict.update(train_loss_epoch_dict)
                if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
                    eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
                    val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
                    update = self._update_best_network(eval_net, val_score, epoch)
                    logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
                    logging_dict.update(val_score_dict)
                if not self.schedule_in_step and scheduler:
                    scheduler.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
                    torch.cuda.empty_cache()
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                    print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
                    train_loader = reduce_batch_size(train_loader, ratio=0.5)
                    if val_loader is not None:
                        val_loader = reduce_batch_size(val_loader, ratio=0.5)
                    logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
                else:
                    raise
            training_log[epoch] = logging_dict
            if self.is_early_stopping(epoch):
                print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
                break
            if self.logging:
                if epoch % self.logging == 0:
                    print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
                else:
                    print(f'{epoch}...', end=' ')
            if self._at_time_limit(start_training_time):
                print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
                break
        if self.logging_file:
            os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
            with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
                f.write(json.dumps(training_log))
        if self.use_ema and self.ema_net is not None:
            self._save_state_dict(self.ema_net.module)
        else:
            self._save_state_dict(network)
        print(f'[Trainer CallBack] 📢 Kết thúc training.\n')
        best_model, last_model = None, None
        eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
        if self.return_best :
            # Fall back to the current weights when no evaluation ever ran.
            best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
            best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
        if self.return_last:
            last_model = eval_net.state_dict()
            last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
        del network
        torch.cuda.empty_cache()
        gc.collect()
        return training_log, best_model, last_model

    def _time_str_to_seconds(self, time_str):
        """Convert "DD:HH:MM:SS" to a number of seconds."""
        days, hours, minutes, seconds = map(int, time_str.split(":"))
        return days * 86400 + hours * 3600 + minutes * 60 + seconds

    def _update_best_network(self, network, val_score, epoch):
        """Record a CPU snapshot of *network*, keep the top-k, and report whether it made the cut."""
        topk = max(1, self.topk)
        self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
        self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
        if val_score in [x[0] for x in self.best_stage]:
            return True
        return False

    def is_early_stopping(self, epoch):
        """True once ``early_stopping`` epochs have passed since the best recorded epoch."""
        if self.best_stage[0][1] is None:
            return False
        if not self.early_stopping:
            return False
        return epoch - self.best_stage[0][1] >= self.early_stopping

    def _at_time_limit(self, start_training_time):
        """True when the wall-clock training budget is used up."""
        return time.time() - start_training_time >= self.training_time

    def _save_state_dict(self, network):
        """Write the top-k snapshots and/or the last weights under ``checkpoints_dir``."""
        if self.topk <= 0:
            return
        if self.save_best:
            for r in range(self.topk):
                os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
            for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
                if state_dict is None:
                    continue
                state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
                torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
        if self.save_last:
            os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
            state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
            torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')

    def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn):
        """Run one AMP training epoch; return (mean loss, {loss_name: mean value})."""
        network.train()
        total_loss = 0
        total_loss_dict = {}
        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            with torch.autocast(device_type=self.device, dtype=torch.float16):
                loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn)
            # Accumulate detached values: summing the live loss tensors would chain
            # every batch's autograd history onto the running total for the whole epoch.
            for k, v in loss_dict.items():
                t = total_loss_dict.get(k, 0)
                total_loss_dict[k] = t + v.detach()
            self.grad_scaler.scale(loss).backward()
            # Unscale before clipping so max_grad_norm applies to the real gradients.
            self.grad_scaler.unscale_(optimizer)
            grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
            # print(grad_norm) # Uncomment to see typical norms when choosing max_grad_norm...
            self.grad_scaler.step(optimizer)
            self.grad_scaler.update()
            if self.schedule_in_step and scheduler:
                scheduler.step()
            if self.use_ema and self.ema_net is not None:
                self.ema_net.update(network)
            total_loss += loss.detach()
        return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}

    def _eval_epoch(self, network, val_loader, eval_fn, id2label):
        """Evaluate without gradients; scores are averaged over batches, not pooled."""
        network.eval()
        total_score = 0.0
        total_score_dict = {}
        object_lists = None # initialised lazily from the first non-empty batch
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
                total_score += score
                for k, v in score_dict.items():
                    t = total_score_dict.get(k, 0)
                    total_score_dict[k] = t + v
                if objects:
                    if object_lists is None:
                        object_lists = [[] for _ in range(len(objects))]
                    for i, obj in enumerate(objects):
                        object_lists[i].append(obj.detach())
        if object_lists is not None:
            object_arrays = [
                torch.concat(obj_list, dim=0).cpu().numpy()
                for obj_list in object_lists
            ]
        else:
            object_arrays = []
        return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays

    def _cal_loss(self, network, batch, batch_idx, loss_fn):
        """Forward one batch and return (total loss, loss dict)."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        start_labels = batch['start_labels'].to(self.device)
        end_labels = batch['end_labels'].to(self.device)
        start_logits, end_logits = network(input_ids, attention_mask)
        loss_dict = loss_fn(
            start_logits, start_labels,
            end_logits, end_labels,
        )
        return loss_dict['total'], loss_dict

    def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
        """Decode one batch and score it; returns (f1, score dict, no extra objects)."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        gold_entities = batch['gold_entities']
        B, _, _ = input_ids.shape
        start_logits, end_logits = network(input_ids, attention_mask)
        pred_ids = extract_entities(input_ids.reshape(B, -1), start_logits, end_logits, id2label)
        pred_ids = list_to_tuple(pred_ids)
        gold_ids = list_to_tuple(gold_entities)
        score_dict = eval_fn(pred_ids, gold_ids)
        return score_dict['f1'], score_dict, []
# %% [code]
class PhoBERTSpanAligner:
    """Tokenise a sample and map its character-offset entities to subword spans.

    Subword positions assume the tokenizer prepends exactly one special token
    (index 0) and splits each word the same way whether it is tokenised alone or
    inside the space-joined sentence - TODO confirm for non-PhoBERT tokenizers.
    """

    def __init__(self, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len  # padded/truncated total sequence length

    # ===== 1. Extract entity spans =====
    def extract_spans(self, sample):
        """Return [{"spans": [(start, end)], "label": ...}] built from ``sample["entities"]``.

        Each entity contributes exactly one character span, taken from its ``offset``.
        """
        entity_spans = []
        for event in sample["entities"]:
            entity_type = event["label"]
            spans = [tuple(event["offset"])]
            entity_spans.append({
                "spans": spans,
                "label": entity_type
            })
        return entity_spans

    # ===== 2. Word offsets =====
    def build_word_offsets(self, text, words):
        """Return the (start, end) character offsets of each word, scanning *text* left to right.

        A word that cannot be found verbatim from the current position (e.g. the
        word tokenizer normalised it) gets the sentinel (-1, -1), which never
        matches a character span, and the scan position is left unchanged so the
        following words still align.
        """
        offsets = []
        pointer = 0
        for word in words:
            start = text.find(word, pointer)
            if start == -1:
                offsets.append((-1, -1))
                continue
            end = start + len(word)
            offsets.append((start, end))
            pointer = end
        return offsets

    # ===== 3. Char -> word =====
    def char_span_to_word_span(self, word_offsets, start, end):
        """Return (first_word, last_word) covering chars [start, end); None where nothing matches.

        If several words match, the last one wins.
        """
        start_word = None
        end_word = None
        for i, (w_start, w_end) in enumerate(word_offsets):
            if w_start <= start < w_end:
                start_word = i
            if w_start < end <= w_end:
                end_word = i
        return start_word, end_word

    # ===== 4. Word -> subword =====
    def word_to_subword_map(self, words):
        """Return the inclusive (first, last) subword position of every word."""
        mapping = []
        subword_index = 1 # position 0 holds the leading special token (<s>)
        for word in words:
            sub_tokens = self.tokenizer.tokenize(word)
            start = subword_index
            end = subword_index + len(sub_tokens) - 1
            mapping.append((start, end))
            subword_index += len(sub_tokens)
        return mapping

    # ===== 5. Span -> subword =====
    def span_to_subword(self, word_offsets, word_subword_map, spans):
        """Convert character spans to inclusive subword spans, skipping ones that do not align."""
        sub_spans = []
        for span_start, span_end in spans:
            w_start, w_end = self.char_span_to_word_span(
                word_offsets, span_start, span_end
            )
            if w_start is None or w_end is None:
                continue
            sub_start = word_subword_map[w_start][0]
            sub_end = word_subword_map[w_end][1]
            sub_spans.append((sub_start, sub_end))
        return sub_spans

    def extract_valid_spans(self, sub_spans):
        """Keep spans with both ends inside [0, max_len) and start <= end."""
        valid_spans = []
        for s, e in sub_spans:
            if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
                continue
            valid_spans.append((s, e))
        return valid_spans

    def encode(self, sample):
        """Tokenise ``sample["text"]`` and align its entities.

        Returns a dict with 1-row ``input_ids``/``attention_mask`` rows (length
        ``max_len``) and ``entities_gold_spans`` = [(tuple_of_subword_spans, label)].
        """
        text = sample["text"]
        entities = self.extract_spans(sample)
        # ===== 1. Word tokenize =====
        words = word_tokenize(text)
        sentence = " ".join(words)
        # ===== 2. Mapping =====
        word_offsets = self.build_word_offsets(text, words)
        word_subword_map = self.word_to_subword_map(words)
        # ===== 3. Tokenize the full sentence =====
        encoding = self.tokenizer(
            sentence,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"][0]
        attention_mask = encoding["attention_mask"][0]
        # ===== 4. Convert spans =====
        entities_gold_spans = []
        for ent in entities:
            label = ent["label"]
            sub_spans = self.span_to_subword(
                word_offsets,
                word_subword_map,
                ent["spans"]
            )
            valid_spans = self.extract_valid_spans(sub_spans)
            if len(valid_spans) == 0:
                continue
            entities_gold_spans.append((tuple(valid_spans), label))
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "entities_gold_spans": entities_gold_spans,
        }
def generate_candidate_spans(seq_len, max_span_len):
    """Enumerate inclusive (start, end) spans over positions 1..seq_len.

    Spans are at most ``max_span_len`` tokens long, ordered by start then end.
    """
    return [
        (start, end)
        for start in range(1, seq_len + 1)
        for end in range(start, min(start + max_span_len, seq_len + 1))
    ]
class KLTNDataset(Dataset):
    """Map-style dataset producing tokenised samples with start/end pointer labels.

    Items are dicts with ``input_ids``/``attention_mask`` of shape
    (n_valid_parts, max_len), flat ``start_labels``/``end_labels`` of length
    n_valid_parts * max_len (0 on real tokens unless an entity boundary, -100 on
    padding), and ``gold_entities`` = [(token_id_tuple, label)].
    """

    def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
        super().__init__()
        self.tokenizer = tokenizer
        # The aligner works on the full flattened length (all parts concatenated).
        self.aligner = PhoBERTSpanAligner(tokenizer, max_len * max_n_parts)
        self.all_data, self.using_idxes = all_data, using_idxes
        self.label2id = label2id
        self.max_len, self.max_n_parts = max_len, max_n_parts

    def __len__(self):
        return len(self.using_idxes)

    def __getitem__(self, idx):
        sample = self.all_data[self.using_idxes[idx]]
        encoded = self.aligner.encode(sample)
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # Real tokens start as label 0; padding positions are ignored (-100).
        start_labels = torch.ones_like(input_ids) * (1 - attention_mask) * (-100)
        end_labels = start_labels.clone()
        gold_entities = []
        for spans, label in encoded["entities_gold_spans"]:
            first_start, first_end = spans[0]  # only the first span of an entity is labelled
            label_id = self.label2id[f'{label}']
            start_labels[first_start] = label_id
            end_labels[first_end] = label_id
            gold_entities.append((tuple(input_ids[first_start:first_end + 1].tolist()), label))

        n_parts, part_len = self.max_n_parts, self.max_len
        input_ids = input_ids.reshape(n_parts, part_len)
        attention_mask = attention_mask.reshape(n_parts, part_len)
        # Drop trailing parts that contain only padding.
        n_valid_parts = math.ceil(attention_mask.sum().item() / part_len)
        keep = n_valid_parts * part_len
        return {
            "input_ids": input_ids[:n_valid_parts],
            "attention_mask": attention_mask[:n_valid_parts],
            "start_labels": start_labels[:keep],
            "end_labels": end_labels[:keep],
            "gold_entities": gold_entities,
        }
def _pad_batch(tensor_list, pad_value=0):
"""
tensor_list: list of tensors
mỗi tensor shape: (Nk, n_parts_i, max_len_i)
return:
padded tensor shape: (B, max_Nk, max_n_parts, max_len)
"""
# lấy max toàn batch
max_Nk = max(t.size(0) for t in tensor_list)
max_n_parts = max(t.size(1) for t in tensor_list)
max_len = max(t.size(2) for t in tensor_list)
padded = []
for t in tensor_list:
Nk, n_parts_i, max_len_i = t.shape
# pad chiều n_parts và max_len trước
if n_parts_i < max_n_parts or max_len_i < max_len:
new_t = t.new_full(
(Nk, max_n_parts, max_len),
pad_value
)
new_t[:, :n_parts_i, :max_len_i] = t
t = new_t
# pad chiều Nk
if Nk < max_Nk:
pad_tensor = t.new_full(
(max_Nk - Nk, max_n_parts, max_len),
pad_value
)
t = torch.cat([t, pad_tensor], dim=0)
padded.append(t)
return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
def collate_fn(batch):
    """Collate KLTNDataset items into one padded batch dict.

    ``gold_entities`` is flattened into ``[batch_index, entity]`` pairs. Tensors
    are padded with ``_pad_batch`` (0 for ids/mask, -100 for labels); trailing
    singleton dims are added so each tensor is 3-D for ``_pad_batch`` and then
    squeezed away again.
    """
    gold_entities = [
        [bidx, entity]
        for bidx, item in enumerate(batch)
        for entity in item['gold_entities']
    ]

    def _stack(key, pad_value, extra_dims):
        tensors = [item[key] for item in batch]
        for _ in range(extra_dims):
            tensors = [t.unsqueeze(-1) for t in tensors]
        stacked = _pad_batch(tensors, pad_value=pad_value)
        for _ in range(extra_dims):
            stacked = stacked.squeeze(-1)
        return stacked

    return {
        "input_ids": _stack("input_ids", 0, 1),
        "attention_mask": _stack("attention_mask", 0, 1),
        "start_labels": _stack("start_labels", -100, 2),
        "end_labels": _stack("end_labels", -100, 2),
        "gold_entities": gold_entities,
    }
# %% [code]
def shift_bidx(spans, batch_idx, batch_size=None):
    """
    Convert batch-local sample indices into dataset-global indices.

    Args:
        spans: iterable of ``(bidx, entity)`` pairs, where ``bidx`` is the
            index of the sample inside its batch.
        batch_idx: position of the batch in the loader (0-based).
        batch_size: samples per full batch. When omitted it falls back to the
            notebook-level ``batch_size`` global, as before. Pass the loader's
            real batch size when it differs from that global.

    Returns:
        List of ``(bidx + batch_idx * batch_size, entity)`` tuples.
    """
    if batch_size is None:
        # Backward-compatible fallback to the module/notebook global.
        batch_size = globals()["batch_size"]
    offset = batch_idx * batch_size
    return [(bidx + offset, ent) for bidx, ent in spans]
def refactor_entities(entities, save_dict):
    """
    Deduplicate one batch of entities and append them to the eval buckets.

    ``save_dict['Ent-I']`` is extended with the unique ``(bidx, ids)`` pairs
    (span identity only). ``save_dict['Ent-C']`` is extended with the unique
    ``(bidx, (ids, lb))`` pairs (span plus label). Both keep first-seen order.

    Deduplication covers only ``entities``; what is already stored in
    ``save_dict`` is not consulted. Membership uses equality on lists, so
    unhashable ``ids`` (e.g. Python lists) are supported.
    """
    def _unique(items):
        # Order-preserving dedup by equality (works for unhashable items).
        kept = []
        for item in items:
            if item not in kept:
                kept.append(item)
        return kept

    labelled = [(bidx, (ids, lb)) for bidx, (ids, lb) in entities]
    save_dict['Ent-I'].extend(_unique((bidx, span) for bidx, (span, _) in labelled))
    save_dict['Ent-C'].extend(_unique(labelled))
def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
    """
    Evaluate an ensemble of checkpoints on ``test_loader``.

    For every batch, each state dict in ``state_dicts`` is loaded into
    ``network`` in turn and its start/end logits are collected. The logits
    are averaged across checkpoints before ``extract_entities`` decodes
    entities. Predicted and gold entities are accumulated for two
    evaluations: 'Ent-I' (span only) and 'Ent-C' (span + label).

    Sample indices are made global with a running count of the samples
    already seen. This keeps them correct whatever batch size the loader
    uses, including a smaller final batch.

    Returns:
        final_score: ``{eval_type: eval_fn(pred, gold)}`` for 'Ent-I' and 'Ent-C'.
        analyze_result: ``analyzer.analyze`` output for the 'Ent-I' entities.
        predictions: one list per test sample:
            ``[decoded_text, (entity_text, label), ...]``.
    """
    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        network = DataParallelProxy(network)
    network = network.to(device)
    network.eval()
    eval_types = ['Ent-I', 'Ent-C']
    all_pred = {eval_type: [] for eval_type in eval_types}
    all_gold = {eval_type: [] for eval_type in eval_types}
    list_input_ids = []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gold_entities = batch['gold_entities']
            B, _, _ = input_ids.shape
            # Global index of this batch's first sample.
            offset = len(list_input_ids)
            flat_input_ids = input_ids.reshape(B, -1)
            list_input_ids.extend(flat_input_ids.tolist())
            list_start_logits = []
            list_end_logits = []
            for sd in state_dicts:
                if multi_gpu:
                    network.module.load_state_dict(sd)
                else:
                    network.load_state_dict(sd)
                start_logits, end_logits = network(input_ids, attention_mask)
                list_start_logits.append(start_logits)
                list_end_logits.append(end_logits)
            # Ensemble: average the raw logits over all checkpoints.
            ensemble_start_logits = torch.stack(list_start_logits, dim=0).mean(dim=0)
            ensemble_end_logits = torch.stack(list_end_logits, dim=0).mean(dim=0)
            pred_entities = extract_entities(flat_input_ids, ensemble_start_logits, ensemble_end_logits, id2label)
            refactor_entities([(bidx + offset, ent) for bidx, ent in pred_entities], all_pred)
            refactor_entities([(bidx + offset, ent) for bidx, ent in gold_entities], all_gold)
    # ===== GLOBAL EVAL =====
    final_score = {}
    for eval_type in eval_types:
        score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))
        final_score[eval_type] = score
    analyze_result = analyzer.analyze(list_to_tuple(all_pred['Ent-I']), list_to_tuple(all_gold['Ent-I']))
    # ===== PREDICT =====
    # First element of each entry is the decoded sample text (including any
    # padding ids that the tokenizer does not treat as special tokens).
    predictions = []
    for sample_ids in list_input_ids:
        predictions.append([tokenizer.decode(sample_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
    for bidx, (ids, lb) in all_pred['Ent-C']:
        predictions[bidx].append((tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True), lb))
    return final_score, analyze_result, predictions
# %% [code]
def _read_json(path):
    """Read and return the UTF-8 encoded JSON document stored at ``path``."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Train/test splits: lists of samples, each carrying an 'entities' list.
data_train = _read_json(f'{train_dir}/train.json')
data_test = _read_json(f'{test_dir}/test.json')
print('Train:', len(data_train))
print('Test:', len(data_test))
# %% [code]
# Label vocabulary: 'O' first (id 0), then every entity label that occurs in
# either split, in sorted order. label2id / id2label map between the two.
entity_types = ['O'] + sorted({
    entity['label']
    for sample in data_train + data_test
    for entity in sample['entities']
})
label2id = dict(zip(entity_types, range(len(entity_types))))
id2label = {idx: name for name, idx in label2id.items()}
# %% [code]
# Down-sample training samples that contain no entities. At most
# int(#samples-with-entities * zero_entities_rate) of them are kept, chosen
# with a fixed seed (42); the original sample order is preserved.
zero_entities_idxes = [idx for idx, d in enumerate(data_train) if len(d['entities']) == 0]
n_zero_entities_samples = len(zero_entities_idxes)
n_has_entities_samples = len(data_train) - n_zero_entities_samples
random.seed(42)
k = min(int(n_has_entities_samples * zero_entities_rate), len(zero_entities_idxes))
sampled_zero_entities_idxes = random.sample(zero_entities_idxes, k)
# Use a set for O(1) lookups. Testing membership against the sampled list
# inside the filter made it O(n * k).
kept_zero_entities_idxes = set(sampled_zero_entities_idxes)
new_data_train = [
    d for idx, d in enumerate(data_train)
    if len(d['entities']) != 0 or idx in kept_zero_entities_idxes
]
data_train = new_data_train
print('Train:', len(data_train))
# %% [code]
# Debug mode: keep only the first 10 samples of each split for a quick run.
if debug_only:
    data_train, data_test = data_train[:10], data_test[:10]
    print('Train:', len(data_train))
    print('Test:', len(data_test))
# %% [code]
# Tokenizer matching the configured backbone.
tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
# %% [code]
print('Experiment name:', state_dict_save_name)
# %% [code]
# Cross-validated training. For every seed in SEEDS, the (filtered) training
# set is split with a shuffled, seeded KFold. When only_fold_idx is a
# non-negative int, all other folds are skipped. Each trained fold appends
# the Trainer's log and its best and last models to the lists below.
if not test_only:
    full_idxes = np.array(range(len(data_train)))
    training_logs, best_models, last_models = [], [], []
    # Single wall-clock reference passed to every Trainer.fit call.
    start_training_time = time.time()
    for seed in SEEDS:
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
        for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
            # Train only the selected fold when only_fold_idx is set (>= 0).
            if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
                continue
            # set_seed (defined elsewhere) runs before this fold's datasets and
            # model are built.
            set_seed(seed)
            train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
            trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
            valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)
            # NOTE(review): one seeded generator is shared by both loaders;
            # presumably only the train loader shuffles. Confirm in
            # train_loader_params / val_loader_params.
            generator = torch.Generator()
            generator.manual_seed(seed)
            train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
            val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
            my_model = IEModel(
                num_labels=len(label2id),
                **model_params
            )
            total_params = sum(p.numel() for p in my_model.parameters())
            print(f"Total params: {total_params:,}")
            # Disabled alternative: build optimizer/scheduler from the config dicts.
            # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
            # Two AdamW parameter groups: encoder parameters at lr 2e-5, and
            # every other parameter (matched by object id) at the default 5e-4.
            encoder_params = set(map(id, my_model.encoder.parameters()))
            other_params = [
                p for p in my_model.parameters()
                if id(p) not in encoder_params
            ]
            optimizer = optim.AdamW([
                {"params": my_model.encoder.parameters(), "lr": 2e-5},
                {"params": other_params}
            ], lr=5e-4)
            # NOTE(review): T_max is hard-coded to 20 instead of being tied to
            # `epochs`. If they differ, the cosine schedule will not span the
            # run as expected. Confirm this is intended.
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
            loss_fn = CustomLoss(
                **loss_func_params
            )
            eval_fn = CustomEvalFn(**eval_func_params)
            # Per-seed/per-fold name handed to the Trainer; presumably used for
            # checkpoint/log file names (see load_all_state_dicts' "f<fold>" parsing).
            trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
            trainer = Trainer(**trainer_params)
            print(f'Start Training Fold {fold_idx}...')
            training_log, best_model, last_model = trainer.fit(
                my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
                start_epoch=1, start_training_time=start_training_time, id2label=id2label
            )
            training_logs.append(training_log)
            best_models.append(best_model)
            last_models.append(last_model)
# %% [code]
def load_all_state_dicts(folder):
    """
    Load every ``.pt`` / ``.pth`` checkpoint in ``folder`` that has a fold tag.

    The fold number comes from the last ``_f<digits>`` tag in the file name
    (training names runs ``{name}_s{seed}_f{fold}``). Using the last tag
    avoids mistaking digits inside the experiment name (e.g. ``bf16``) for
    the fold. Files without an ``_f<digits>`` tag fall back to the first
    ``f<digits>`` match anywhere in the name. Files with neither are skipped.

    Checkpoints are sorted by ``(fold, file name)``, so the order does not
    depend on ``os.listdir``.

    Returns:
        List of the objects returned by ``torch.load(path, map_location="cpu")``.
    """
    fold_tag = re.compile(r"_f(\d+)")
    legacy_tag = re.compile(r"f(\d+)")  # looser pattern kept for older names
    files = []
    for file in os.listdir(folder):
        if not file.endswith((".pt", ".pth")):
            continue
        tags = fold_tag.findall(file)
        if tags:
            fold = int(tags[-1])
        else:
            m = legacy_tag.search(file)
            if m is None:
                continue
            fold = int(m.group(1))
        files.append((fold, file))
    # Sort by fold, then by file name for a deterministic order.
    files.sort()
    state_dicts = []
    for fold, file in files:
        path = os.path.join(folder, file)
        print(f"Loading fold {fold}: {file}")
        state_dict = torch.load(path, map_location="cpu")
        state_dicts.append(state_dict)
    return state_dicts
# Test-only mode: fetch this experiment's checkpoints from the Hub, remove
# download artifacts, then load the best ("r1s") and last ("lasts") state
# dicts of every fold.
if test_only:
    snapshot_download(
        repo_id=repo_name,
        local_dir="",
        repo_type="model",
        allow_patterns=[f"{state_dict_save_name}/**"],
    )
    get_ipython().system('rm -rf .cache .gitattributes')
    best_models, last_models = (
        load_all_state_dicts(checkpoint_dir)
        for checkpoint_dir in (f"{state_dict_save_name}/r1s", f"{state_dict_save_name}/lasts")
    )
# %% [code]
# Test-time objects: results folder, test loader, evaluator, error analyzer
# and a fresh model instance whose weights `test` overwrites from the saved
# state dicts.
os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
generator = torch.Generator()
test_loader = DataLoader(
    testset,
    generator=generator,
    collate_fn=collate_fn,
    **val_loader_params,
)
eval_fn = CustomEvalFn(**eval_func_params)
analyzer = SpanErrorAnalyzer()
my_model = IEModel(num_labels=len(label2id), **model_params)
total_params = sum(param.numel() for param in my_model.parameters())
print(f"Total params: {total_params:,}")
# %% [code]
# Evaluate the ensemble of best checkpoints and the ensemble of last
# checkpoints, then write scores, error analysis and predictions to JSON.
start_time = time.time()
result_test = None
analyze_result = None
best_score, best_analyze_result, best_pred_test = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
last_score, last_analyze_result, last_pred_test = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
result_test = dict([("Best model", best_score), ("Last model", last_score)])
analyze_result = dict([("Best model", best_analyze_result), ("Last model", last_analyze_result)])
analyze_result_sumary = {
    "Best model": best_analyze_result['summary'],
    "Last model": last_analyze_result['summary'],
}
pred_test = dict([("Best model", best_pred_test), ("Last model", last_pred_test)])
for out_path, payload in (
    (f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", result_test),
    (f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", analyze_result),
    (f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", pred_test),
):
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
print('Test:', time.time() - start_time, 's --> Done!')
print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))
# %% [code]
best_pred_test[:10]
# %% [code]
last_pred_test[:10]
# %% [code]
def dict_to_df(data):
    """
    Turn ``{model: {eval_key: {"precision", "recall", "f1"}}}`` into a table.

    Rows are the eval keys of the first model, in that model's key order, on
    an index named "evaluation". Columns form a ``(model, metric)``
    MultiIndex. When present, rows are sorted descending by
    ``("Best model", "f1")`` and then ``("Last model", "f1")``.
    """
    metric_names = ("precision", "recall", "f1")
    eval_keys = list(next(iter(data.values())).keys())
    records = [
        {
            (model_name, metric): scores[eval_key][metric]
            for model_name, scores in data.items()
            for metric in metric_names
        }
        for eval_key in eval_keys
    ]
    frame = pd.DataFrame(records)
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)
    frame.index = pd.Index(eval_keys, name="evaluation")
    sort_columns = [
        column
        for column in (("Best model", "f1"), ("Last model", "f1"))
        if column in frame.columns
    ]
    if sort_columns:
        frame = frame.sort_values(by=sort_columns, ascending=False)
    return frame
# Full metrics table (rows: evaluation type, columns: model x metric).
result_test_df = dict_to_df(result_test)
result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
result_test_df
# %% [code]
# Top row per evaluation by best-model F1.
key = ("Best model", "f1")
result_test_df_best = (
    result_test_df
    .sort_values(by=key, ascending=False)
    .groupby(level="evaluation")
    .head(1)
)
result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
result_test_df_best
# %% [code]
def get_avg_best_score(logs):
    """
    Average, across runs, the 'best_score' stored in each run's final entry.

    Args:
        logs: list of per-run dicts. Each dict's last value (insertion order)
            must contain a 'best_score' entry.

    Returns:
        The mean as a Python float; NaN when ``logs`` is empty.
    """
    final_best_scores = []
    for run_log in logs:
        last_entry = [*run_log.values()][-1]
        final_best_scores.append(last_entry['best_score'])
    return float(np.mean(final_best_scores))
def get_avg_log(logs, epochs):
    """
    Average per-epoch training curves over several runs (e.g. CV folds).

    Args:
        logs: list of per-run dicts keyed by epoch. Keys may be ``int`` or
            ``str``; logs reloaded from JSON have string keys. Each value is a
            dict that may contain 'train_loss' and 'val_score'.
        epochs: highest epoch number to look up (1-based, inclusive).

    Returns:
        ``{epoch: {'train_loss': ..., 'val_score': ...}}`` for every epoch
        that at least one run logged.
        - 'train_loss' is averaged over those runs; a missing value counts
          as 0.0.
        - 'val_score' is averaged over the runs that actually reported it,
          and is ``inf`` only when none did. A genuine average of 0.0 (e.g.
          zero F1 early in training) is returned as 0.0, not ``inf``.
    """
    avg_log = {}
    for epoch in range(1, epochs + 1):
        train_loss = 0.0
        val_score = 0.0
        n_eval = 0  # runs that logged this epoch
        n_val = 0   # of those, runs that reported a val_score
        for run_log in logs:
            log = run_log.get(epoch, run_log.get(str(epoch)))
            if log is None:
                continue
            n_eval += 1
            train_loss += log.get('train_loss', 0.0)
            if 'val_score' in log:
                val_score += log['val_score']
                n_val += 1
        if n_eval == 0:
            continue
        avg_log[epoch] = {
            'train_loss': train_loss / n_eval,
            'val_score': val_score / n_val if n_val else float('inf'),
        }
    return avg_log
def parse_label_key(label: str):
    """
    Build a legend sort key from an experiment label.

    For a label like ``"1_pointer_base_entities_4"`` this returns
    ``(first, last)``: the number before the first underscore and the number
    (int or decimal) after the final underscore, both as floats. Labels that
    do not follow that pattern, or are not strings, get ``(0, 0)``.
    """
    try:
        first = float(label.split('_', 1)[0])
        last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
    except (AttributeError, IndexError, TypeError, ValueError):
        # Non-string label, non-numeric prefix, or no trailing number.
        # Narrowed from a bare `except:` so interrupts still propagate.
        return (0, 0)
    return first, last
def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
    """
    Plot training loss (left) and validation score (right) for several runs.

    Args:
        logs_dict: ``{name: {epoch: {'train_loss': ..., 'val_score': ...}}}``,
            e.g. one ``get_avg_log`` result per experiment.
        save_path: when given, the figure is also saved there at dpi 300.
            Missing parent directories are created.
        figsize: matplotlib figure size in inches.

    A single legend, ordered with ``parse_label_key``, is attached to the
    right of the loss axes. It is skipped when there is nothing to label
    (empty ``logs_dict``) instead of failing.
    """
    fig, (loss_ax, score_ax) = plt.subplots(1, 2, figsize=figsize)
    # ===== Plot Train Loss =====
    for name, log in logs_dict.items():
        log_epochs = sorted(log.keys())
        loss_ax.plot(log_epochs, [log[e]['train_loss'] for e in log_epochs], label=name)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Train Loss')
    loss_ax.set_title('Training Loss')
    loss_ax.grid(True)
    # ===== Plot Validation Score =====
    for name, log in logs_dict.items():
        log_epochs = sorted(log.keys())
        score_ax.plot(log_epochs, [log[e]['val_score'] for e in log_epochs], label=name)
    score_ax.set_xlabel('Epoch')
    score_ax.set_ylabel('Validation Score')
    score_ax.set_title('Validation Score')
    score_ax.grid(True)
    # ===== Shared Legend =====
    handles, labels = loss_ax.get_legend_handles_labels()
    if handles:
        pairs_sorted = sorted(zip(handles, labels), key=lambda pair: parse_label_key(pair[1]))
        handles_sorted, labels_sorted = zip(*pairs_sorted)
        loss_ax.legend(
            handles_sorted,
            labels_sorted,
            loc='center left',
            bbox_to_anchor=(1.01, 0.5),
            borderaxespad=0.
        )
    plt.tight_layout(rect=[0, 0, 1, 1])
    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
# %% [code]
# Download the JSON training logs of every experiment stored in the Hub repo.
if not test_only:
    snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*.json"])
    get_ipython().system('rm -rf .cache .gitattributes')
# %% [code]
# Average the per-fold logs of every experiment under `pretrained_dir` per
# epoch. The current run uses its in-memory training logs.
if not test_only:
    experiments = {}
    for experiment in os.listdir(pretrained_dir):
        if '.virtual_documents' in experiment:
            continue
        experiment_logs = []
        try:
            for seed in SEEDS:
                for fold_idx in range(nfolds):
                    with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
                        experiment_logs.append(json.load(f))
        except (OSError, ValueError):
            # Best effort: stop at the first missing or unreadable log (e.g.
            # only some folds were trained, or the entry is not a directory)
            # and keep the logs read so far. ValueError covers malformed JSON
            # (json.JSONDecodeError) and decoding errors. Other exceptions,
            # including interrupts, now propagate instead of being swallowed.
            pass
        experiments[experiment] = get_avg_log(experiment_logs, 1000)
    experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
# %% [code]
# Mean best validation score over this run's trained folds.
if not test_only:
    score = get_avg_best_score(training_logs)
    state_dict_save_name, score
# %% [code]
# Compare this run's averaged curves with the other downloaded experiments.
if not test_only:
    plot_training_logs(
        experiments,
        save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg',
        figsize=(18, 7.5),
    )