| |
# Notebook-only setup: shells out to pip to install extra packages.
# `get_ipython()` is only defined when running under IPython/Jupyter.
get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch] pytorch-crf')
|
|
| |
| import warnings |
| warnings.filterwarnings('ignore') |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, TensorDataset, DataLoader |
| import torch.nn.functional as F |
| import albumentations as albu |
| from transformers import AutoTokenizer, AutoModel |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from positional_encodings.torch_encodings import PositionalEncoding1D |
| from torchcrf import CRF |
|
|
| from sklearn.metrics import f1_score |
| from sklearn.preprocessing import MinMaxScaler, StandardScaler |
| from scipy.spatial.transform import Rotation as R |
| from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold |
| from sklearn.metrics import precision_recall_fscore_support |
| from timm.utils import ModelEmaV3 |
| import timm |
|
|
| import os |
| import gc |
| import json |
| from pathlib import Path |
| import pickle |
| from tqdm.auto import tqdm |
| import copy |
| import numpy as np |
| import pandas as pd |
| import polars as pl |
| from PIL import Image |
| import time |
| from tqdm import tqdm |
| from matplotlib import pyplot as plt |
| import seaborn as sns |
| from multiprocessing import Manager as MemoryManager |
| from functools import lru_cache |
| import shutil |
| import glob |
| import cv2 |
| import random |
| import re |
| import joblib |
| import math |
| from huggingface_hub import HfApi, snapshot_download |
| import evaluate |
| from underthesea import word_tokenize as vi_tokenize_tool |
| import spacy |
| en_tokenize_tool = spacy.load("en_core_web_sm") |
| from collections import defaultdict, Counter |
|
|
| |
| |
# ---- Run control -----------------------------------------------------------
SEEDS = [26092004]
topk = 1
nfolds = 5
only_fold_idx = 0
test_only = 0
debug_only = 0

# ---- Data locations --------------------------------------------------------
dataset = 'kltn/only_entities'
root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}'
# Train and test data are read from the same directory.
train_dir = root_dir
test_dir = root_dir

# ---- Training basics -------------------------------------------------------
# Debug runs are shortened to two epochs.
epochs = 2 if debug_only else 18
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Checkpoint naming / output --------------------------------------------
repo_name = 'SS3M/kltn-experiments'
state_dict_save_name = "1_pointer_base_entities_4"
# Checkpoints go into a directory named after the run.
checkpoints_dir = state_dict_save_name
pretrained_dir = "/kaggle/working"
os.makedirs(checkpoints_dir, exist_ok=True)
|
|
# Datasets that are English; every other dataset is handled as Vietnamese.
_ENGLISH_DATASETS = ("conll003", "ontonotes")

backbone_model_name = "bert-base-uncased" if dataset in _ENGLISH_DATASETS else "vinai/phobert-base"


def word_tokenize(text):
    """Split `text` into a list of word strings.

    Uses the spaCy pipeline (`en_tokenize_tool`) for the English datasets and
    the underthesea tokenizer (`vi_tokenize_tool`) otherwise. The dataset name
    is read from the module-level `dataset` at call time.
    """
    if dataset in _ENGLISH_DATASETS:
        return [token.text for token in en_tokenize_tool(text)]
    return vi_tokenize_tool(text)


# Maximum sequence length (in sub-word tokens) used for each dataset.
max_len_dict = {
    'kltn/raw': 256,
    'kltn/only_entities': 68,
    'conll003': 46,
    'ontonotes': 61,
    'phoner': 68,
    'vietbio': 125,
    'vietmed': 36,
    'vimed': 100,
}
# Per-dataset "zero entities rate" setting; its consumer is not visible in
# this part of the file.
zero_entities_rate_dict = {
    'kltn/raw': 1000,
    'kltn/only_entities': 0.2,
    'conll003': 1000,
    'ontonotes': 1000,
    'phoner': 1000,
    'vietbio': 1000,
    'vietmed': 1000,
    'vimed': 1000,
}


max_len = max_len_dict[dataset]
max_n_parts = 1
max_span_len = 10
zero_entities_rate = zero_entities_rate_dict[dataset]
|
|
| |
# Keyword arguments forwarded to `Trainer(...)`.
trainer_params = dict(
    training_time="00:11:30:00",  # wall-clock budget, "DD:HH:MM:SS"
    eval_mode="max",
    topk=topk,
    save_name=state_dict_save_name,
    save_best=True,
    save_last=True,
    device=device,
    logging=True,
    logging_file=True,
    checkpoints_dir=checkpoints_dir,
    early_stopping=30,
    eval_from_ratio=0.4,
    eval_every=1,
    schedule_in_step=False,
    use_ema=True,
    ema_from_ratio=0.3,
    ema_decay=0.9995,
    max_grad_norm=200.0,
    return_best=True,
    return_last=True,
)


# Train and validation use the same length settings, kept as two separate
# dict objects.
train_memory_params = dict(max_len=max_len, max_n_parts=max_n_parts)
val_memory_params = dict(max_len=max_len, max_n_parts=max_n_parts)
|
|
| |
def seed_worker(worker_id):
    """Seed NumPy and `random` from torch's current initial seed.

    Intended as a DataLoader `worker_init_fn`; `worker_id` is accepted for
    that signature but not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
| |
# DataLoader keyword arguments for the training split.
train_loader_params = dict(
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
    drop_last=False,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)
# Validation loader: same settings, but no shuffling and a single worker.
val_loader_params = {**train_loader_params, 'shuffle': False, 'num_workers': 1}


# Keyword arguments for the model.
model_params = dict(backbone_model_name=backbone_model_name)


# Keyword arguments for the loss and evaluation functions.
loss_func_params = dict(lambda_ce=1.0)
eval_func_params = dict()


# Optimizer / scheduler configs: `name` selects the class, the remaining keys
# are its constructor arguments.
optim_params = dict(name='AdamW', lr=1e-4, weight_decay=1e-4)
scheduler_params = dict(name='CosineAnnealingLR', T_max=20, eta_min=1e-6)
|
|
| |
def set_seed(seed=42):
    """Seed Python, NumPy and torch (CPU + CUDA) RNGs and pin cuDNN settings.

    Also writes `seed` into the `PYTHONHASHSEED` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Strict deterministic algorithms are explicitly left off; only cuDNN is
    # put in deterministic, non-benchmark mode.
    torch.use_deterministic_algorithms(False)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
| |
class CustomLoss(nn.Module):
    """Sum of token-level cross-entropy losses for start and end labels.

    Label positions equal to -100 are ignored by the cross-entropy.

    Args:
        lambda_ce: stored on the instance as `self.lambda_ce`. NOTE(review):
            it is not applied to the returned losses — confirm whether it
            should scale them.
    """

    def __init__(self, lambda_ce=1.0):
        super().__init__()
        self.lambda_ce = lambda_ce
        self.ce = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(
        self,
        start_logits, start_labels,
        end_logits, end_labels,
    ):
        """Compute the start/end losses.

        Args:
            start_logits, end_logits: tensors of shape (B, L, C).
            start_labels, end_labels: integer tensors with B * L elements.

        Returns:
            dict with tensors under "total" (start + end), "start_loss" and
            "end_loss".
        """
        batch, seq_len, num_classes = start_logits.shape
        n_tokens = batch * seq_len

        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted instead of raising.
        start_loss = self.ce(
            start_logits.reshape(n_tokens, num_classes),
            start_labels.reshape(-1),
        )
        end_loss = self.ce(
            end_logits.reshape(n_tokens, num_classes),
            end_labels.reshape(-1),
        )

        return {
            "total": start_loss + end_loss,
            "start_loss": start_loss,
            "end_loss": end_loss,
        }
|
|
| |
| |
|
|
| |
class CustomEvalFn(nn.Module):
    """Set-based precision / recall / F1 between predicted and gold items."""

    def __init__(self):
        super().__init__()

    def compute_f1(self, tp, fp, fn):
        """Return (precision, recall, f1); an epsilon keeps denominators non-zero."""
        eps = 1e-8
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        return precision, recall, f1

    def forward(self, pred, gold):
        """Score `pred` against `gold`; both are iterables of hashable items.

        Duplicates are collapsed before counting.
        """
        predicted, expected = set(pred), set(gold)
        tp = len(predicted & expected)
        fp = len(predicted) - tp
        fn = len(expected) - tp

        precision, recall, f1 = self.compute_f1(tp, fp, fn)
        return {"precision": precision, "recall": recall, "f1": f1}
|
|
class SpanErrorAnalyzer:
    """Categorise span prediction errors (exact, partial, spurious, missed)."""

    def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    def _to_set(self, data):
        """Group spans by sample index.

        data: list of (b, tuple(ids))
        -> dict[b] = set(tuple(ids))

        Padding ids are removed from every span; spans that end up empty are
        dropped.
        """
        res = defaultdict(set)
        for b, ids in data:
            ids = tuple(i for i in ids if i != self.pad_token_id)
            if ids:
                res[b].add(ids)
        return res

    def _iou(self, a, b):
        """Intersection-over-union of the id *sets* of two spans.

        a, b: tuple(ids)
        """
        set_a, set_b = set(a), set(b)
        union = len(set_a | set_b)
        if union == 0:
            return 0.0
        return len(set_a & set_b) / union

    def _boundary_error(self, pred, gold):
        """Measure boundary drift via the lengths of the matching prefix and suffix."""
        limit = min(len(pred), len(gold))

        left = 0
        for i in range(limit):
            if pred[i] != gold[i]:
                break
            left += 1

        right = 0
        for i in range(1, limit + 1):
            if pred[-i] != gold[-i]:
                break
            right += 1

        return {
            "left_match": left,
            "right_match": right,
            "pred_len": len(pred),
            "gold_len": len(gold),
        }

    @staticmethod
    def _rate(count, total):
        """Return (count, count / total), using 0.0 as the rate when total is 0."""
        return (count, count / total if total else 0.0)

    def analyze(self, preds, golds):
        """Compare predicted and gold spans sample by sample.

        Args:
            preds, golds: lists of (sample_index, tuple(ids)).

        Returns:
            dict with
              "summary": category -> (count, rate). Prediction-side rates are
                  relative to len(preds), miss rates to len(golds); the rate
                  is 0.0 when that list is empty.
              "details": list of per-error dicts.
        """
        pred_map = self._to_set(preds)
        gold_map = self._to_set(golds)
        all_batches = set(pred_map.keys()) | set(gold_map.keys())

        stats = Counter()
        detailed_errors = []

        for b in all_batches:
            pset = pred_map.get(b, set())
            gset = gold_map.get(b, set())
            matched_gold = set()

            # Pass 1: classify every predicted span.
            for p in pset:
                if p in gset:
                    stats["exact_match"] += 1
                    matched_gold.add(p)
                    continue

                # Closest gold span by IoU (first one wins on ties).
                best_iou = 0
                best_g = None
                for g in gset:
                    iou = self._iou(p, g)
                    if iou > best_iou:
                        best_iou = iou
                        best_g = g

                if best_iou > 0:
                    stats["partial_match"] += 1
                    detailed_errors.append({
                        "type": "boundary_error",
                        "batch": b,
                        "pred": p,
                        "gold": best_g,
                        "iou": best_iou,
                        **self._boundary_error(p, best_g),
                    })
                    continue

                # No overlap with any gold span: separate samples that have no
                # gold spans at all from plain wrong predictions.
                err_type = "no_event_sample" if b not in gold_map else "completely_wrong"
                stats[err_type] += 1
                detailed_errors.append({
                    "type": err_type,
                    "batch": b,
                    "pred": p,
                })

            # Pass 2: gold spans that were not matched exactly.
            for g in gset:
                if g in matched_gold:
                    continue
                overlap = any(self._iou(p, g) > 0 for p in pset)
                stats["miss_with_overlap" if overlap else "miss"] += 1
                detailed_errors.append({
                    "type": "miss",
                    "batch": b,
                    "gold": g,
                })

        n_pred, n_gold = len(preds), len(golds)
        return {
            "summary": {
                "exact_match": self._rate(stats["exact_match"], n_pred),
                "partial_match": self._rate(stats["partial_match"], n_pred),
                "no_event_sample": self._rate(stats["no_event_sample"], n_pred),
                "completely_wrong": self._rate(stats["completely_wrong"], n_pred),
                "miss": self._rate(stats["miss"], n_gold),
                "miss_with_overlap": self._rate(stats["miss_with_overlap"], n_gold),
            },
            "details": detailed_errors,
        }
|
|
| |
| |
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        stack = [
            nn.Linear(in_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, out_size),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the stack to the last dimension of `x`."""
        projected = self.mlp(x)
        return projected
|
|
class IEModel(nn.Module):
    """Pretrained encoder followed by two per-token classifiers (start / end)."""

    def __init__(self, backbone_model_name, num_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone_model_name)
        width = self.encoder.config.hidden_size

        self.start_classifier = MLP(width, width, num_labels)
        self.end_classifier = MLP(width, width, num_labels)

    def encode(self, input_ids, attention_mask):
        """Encode every part independently and concatenate along the sequence.

        Args:
            input_ids, attention_mask: tensors of shape (B, n_parts, L).

        Returns:
            Hidden states of shape (B, n_parts * L, hidden).
        """
        batch, n_parts, part_len = input_ids.shape
        # Fold the parts into the batch dimension so the encoder sees 2-D input.
        flat_ids = input_ids.view(-1, part_len)
        flat_mask = attention_mask.view(-1, part_len)

        encoded = self.encoder(input_ids=flat_ids, attention_mask=flat_mask)
        states = encoded.last_hidden_state

        per_part = states.view(batch, n_parts, part_len, -1)
        return per_part.reshape(batch, n_parts * part_len, -1)

    def get_logits(self, hidden_states):
        """Return (start_logits, end_logits) computed from the hidden states."""
        return self.start_classifier(hidden_states), self.end_classifier(hidden_states)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return (start_logits, end_logits); `labels` is accepted but unused."""
        encoded = self.encode(input_ids, attention_mask)
        return self.get_logits(encoded)
|
|
def test():
    """Smoke-test IEModel: build it, print the parameter count, run one
    random batch through it and print the output shapes."""
    model = nn.DataParallel(IEModel(backbone_model_name, 7)).to(device)
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")

    vocab_size = model.module.encoder.config.vocab_size

    # Random batch shaped (batch, n_parts, part_len).
    bz = 32
    input_ids = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
    attention_mask = torch.ones(bz, 5, 10).to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)

    if isinstance(outputs, tuple):
        print([
            item.shape if isinstance(item, torch.Tensor) else len(item)
            for item in outputs
        ])
    else:
        print(outputs.shape)


test()
|
|
| |
def configure_optimizers(network, optim_params, scheduler_params):
    """Build an optimizer and an optional LR scheduler from config dicts.

    Each config dict holds a 'name' key selecting the class; the remaining
    keys are passed to the constructor. Names are looked up in this module's
    globals first, then in `torch.optim` / `torch.optim.lr_scheduler`.
    The input dicts are not modified.

    Args:
        network: module whose parameters are optimized.
        optim_params: optimizer config; 'name' is required.
        scheduler_params: scheduler config, or None / empty for no scheduler.

    Returns:
        (optimizer, scheduler); scheduler is None when none was requested or
        when the scheduler name cannot be resolved.

    Raises:
        ValueError: 'name' is missing from a non-empty config, or the
            optimizer name cannot be resolved.
    """
    optim_params = dict(optim_params)
    scheduler_params = dict(scheduler_params) if scheduler_params else {}

    try:
        optim_name = optim_params.pop('name')
        # An empty scheduler config means "no scheduler"; a non-empty one must
        # still say which class to build.
        scheduler_name = scheduler_params.pop('name') if scheduler_params else None
    except KeyError as e:
        raise ValueError(f"Missing {e} in config!!") from e

    optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
    if optimizer_cls is None:
        raise ValueError(f"Optimizer '{optim_name}' is not available!")
    optimizer = optimizer_cls(network.parameters(), **optim_params)

    scheduler_cls = None
    if scheduler_name is not None:
        scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)

    # Decide on the resolved class only: a scheduler configured with nothing
    # but its name (all defaults) must still be created.
    scheduler = scheduler_cls(optimizer, **scheduler_params) if scheduler_cls else None

    return optimizer, scheduler
| |
def freeze(self, model):
    """Put `model` in eval mode and turn off gradients for all its parameters.

    `self` is part of the signature but is not used.
    """
    model.eval()
    model.requires_grad_(False)
|
|
def unfreeze(self, model):
    """Put `model` in train mode and turn gradients back on for all its parameters.

    `self` is part of the signature but is not used.
    """
    model.train()
    model.requires_grad_(True)
|
|
def reduce_batch_size(loader, ratio=0.5):
    """Return a new DataLoader over the same dataset with a smaller batch size.

    The new batch size is `int(loader.batch_size * ratio)`, but never below 1.
    All other settings are copied from `loader`. A shuffling loader (one whose
    sampler is a `RandomSampler`) gets a fresh random sampler; any other
    sampler object is reused as is.

    Args:
        loader: the DataLoader to copy.
        ratio: multiplier applied to the current batch size.

    Returns:
        The newly built DataLoader; `loader` itself is left untouched.
    """
    # Imported here because the file's top-level torch.utils.data import does
    # not include RandomSampler.
    from torch.utils.data import RandomSampler

    new_bs = max(1, int(loader.batch_size * ratio))
    shuffle = isinstance(loader.sampler, RandomSampler)

    return DataLoader(
        dataset=loader.dataset,
        batch_size=new_bs,
        shuffle=shuffle,
        # `shuffle=True` and an explicit sampler are mutually exclusive.
        sampler=None if shuffle else loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
        generator=loader.generator,
        # prefetch_factor is only passed through when workers are used.
        prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
        persistent_workers=loader.persistent_workers,
        pin_memory_device=loader.pin_memory_device,
    )
|
|
def list_to_tuple(x):
    """Recursively turn lists (and tuples) into tuples; other values pass through."""
    if not isinstance(x, (list, tuple)):
        return x
    return tuple(map(list_to_tuple, x))
|
|
def fmt(x):
    """Round floats to 5 decimals, recursing into dict values and list items.

    Anything else (ints, strings, tuples, ...) is returned unchanged.
    """
    if isinstance(x, dict):
        return {key: fmt(value) for key, value in x.items()}
    if isinstance(x, list):
        return list(map(fmt, x))
    return round(x, 5) if isinstance(x, float) else x
| |
class ModelEmaV3Proxy(ModelEmaV3):
    """ModelEmaV3 that forwards unknown attribute lookups to the wrapped `module`."""

    def __getattr__(self, name):
        # Regular lookup first; only on failure fall back to the wrapped model.
        try:
            found = super().__getattr__(name)
        except AttributeError:
            found = getattr(self.module, name)
        return found
| |
| class DataParallelProxy(nn.DataParallel): |
| def __getattr__(self, name): |
| try: |
| return super().__getattr__(name) |
| except AttributeError: |
| attr = getattr(self.module, name) |
|
|
| if callable(attr): |
| def wrapper(*args, **kwargs): |
| return self._parallel_apply_method(name, *args, **kwargs) |
| return wrapper |
|
|
| return attr |
|
|
| def _parallel_apply_method(self, method_name, *inputs, **kwargs): |
| if not self.device_ids: |
| return getattr(self.module, method_name)(*inputs, **kwargs) |
|
|
| inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids) |
|
|
| replicas = self.replicate(self.module, self.device_ids) |
|
|
| outputs = self.parallel_apply( |
| [getattr(replica, method_name) for replica in replicas], |
| inputs_scattered, |
| kwargs_scattered |
| ) |
|
|
| return self.gather(outputs, self.output_device) |
|
|
def fix_bio(tags):
    """Repair a BIO tag sequence so that no entity starts with an `I-` tag.

    An `I-X` tag becomes `B-X` when it is the first tag, follows `O`, or
    follows a tag of a different entity type. Returns a new list.
    """
    repaired = []
    for tag in tags:
        if tag.startswith('I-'):
            entity = tag[2:]
            starts_entity = (
                not repaired
                or repaired[-1] == 'O'
                or repaired[-1][2:] != entity
            )
            if starts_entity:
                tag = 'B-' + entity
        repaired.append(tag)
    return repaired
|
|
def extract_entities(input_ids, start_logits, end_logits, id2label):
    """Decode entity spans from per-token start/end label logits.

    Each position whose predicted start label is non-zero is paired with the
    nearest position at or after it whose predicted end label is the same and
    which has not already closed another span in that sample. Starts without
    such an end are dropped. Label id 0 is treated as "no entity".

    Args:
        input_ids: Tensor (B, L)
        start_logits: Tensor (B, L, C)
        end_logits: Tensor (B, L, C)
        id2label: dict {label_id: label_name}

    Returns:
        List[(bidx, (input_ids[bidx, s:e+1] as a list, id2label[label_id]))]
    """
    start_labels = start_logits.argmax(dim=-1)
    end_labels = end_logits.argmax(dim=-1)
    B, L = start_labels.shape

    # Convert to Python lists once, instead of calling .item() per position.
    start_rows = start_labels.tolist()
    end_rows = end_labels.tolist()
    token_rows = input_ids.tolist()

    results = []
    for bidx in range(B):
        starts, ends, tokens = start_rows[bidx], end_rows[bidx], token_rows[bidx]
        used_end = set()

        for s in range(L):
            s_label = starts[s]
            if s_label == 0:
                continue

            nearest_e = next(
                (e for e in range(s, L) if e not in used_end and ends[e] == s_label),
                None,
            )
            if nearest_e is None:
                continue

            used_end.add(nearest_e)
            results.append((bidx, (tokens[s:nearest_e + 1], id2label[s_label])))

    return results
| |
class Trainer:
    """Epoch-based training loop with mixed precision, optional EMA weights,
    top-k checkpoint tracking, early stopping and a wall-clock budget.

    Args:
        training_time: wall-clock budget as "DD:HH:MM:SS".
        eval_mode: "max" if a higher validation score is better, otherwise
            lower is better.
        topk: number of best checkpoints to keep (at least one is tracked).
        save_name: base name of checkpoint and log files.
        save_best / save_last: which checkpoints `_save_state_dict` writes.
        max_grad_norm: gradient-norm clipping threshold.
        logging: print the log every `logging` epochs (falsy: print only the
            epoch counter is skipped too).
        logging_file: also write the training log as JSON.
        checkpoints_dir: output directory for checkpoints and logs.
        early_stopping: stop after this many epochs without a new best score
            (falsy disables it).
        eval_from_ratio: fraction of the epochs after which evaluation starts.
        eval_every: evaluate every this many epochs once evaluation started.
        device: device string for the model and the batches.
        schedule_in_step: step the scheduler per batch instead of per epoch.
        use_ema / ema_from_ratio / ema_decay: EMA switch, start point (as a
            fraction of the epochs) and decay.
        return_best / return_last: which state dicts `fit` returns.
    """

    def __init__(
        self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
        logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
        schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
    ):
        self.ema_net = None

        self.training_time = self._time_str_to_seconds(training_time)
        self.mode = eval_mode
        self.topk = topk
        self.device = device
        # NOTE: `epochs` here is the module-level setting, not an argument.
        self.logging = logging if logging < epochs else 1
        self.logging_file = logging_file
        self.checkpoints_dir = checkpoints_dir
        self.early_stopping = early_stopping
        self.eval_from_ratio = eval_from_ratio
        self.eval_every = eval_every
        self.save_name = save_name
        self.save_best = save_best
        self.save_last = save_last
        self.return_best = return_best
        self.return_last = return_last
        self.max_grad_norm = max_grad_norm
        self.schedule_in_step = schedule_in_step
        self.use_ema = use_ema
        self.ema_from_ratio = ema_from_ratio
        self.ema_decay = ema_decay

        # Entries are [score, epoch, state_dict]; starts with a sentinel that
        # any real score replaces.
        self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
        # GradScaler wants the device *type*, so "cuda:0" is reduced to "cuda".
        self.grad_scaler = torch.amp.GradScaler(torch.device(self.device).type, init_scale=1024.0)

    def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
        """Train `network` and return `(training_log, best_state, last_state)`.

        `training_log` maps epoch -> dict of logged values. `best_state` and
        `last_state` are CPU state dicts, or None when the matching
        `return_*` flag is off. Training stops early on early stopping or
        when the wall-clock budget is used up. On a CUDA out-of-memory error
        the epoch is abandoned and both loaders are rebuilt with half the
        batch size.
        """
        if eval_fn is None:
            if self.mode == "max":
                eval_fn = lambda *x: -loss_fn(*x)
            else:
                eval_fn = lambda *x: loss_fn(*x)

        if torch.cuda.device_count() > 1:
            network = DataParallelProxy(network)
        network = network.to(self.device)

        if not start_training_time:
            start_training_time = time.time()

        start_ema = int(epochs * self.ema_from_ratio)
        start_eval = int(epochs * self.eval_from_ratio)

        if val_loader is None:
            print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
        else:
            model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
            start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
            print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)

        training_log = {}
        for epoch in range(start_epoch, epochs + start_epoch):
            elapsed_epochs = epoch - start_epoch

            # Create the EMA copy lazily, once the configured start epoch is reached.
            if self.use_ema and self.ema_net is None and elapsed_epochs >= start_ema:
                self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)

            try:
                train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn)
                logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
                logging_dict.update(train_loss_epoch_dict)

                if val_loader is not None and elapsed_epochs >= start_eval and (elapsed_epochs - start_eval) % self.eval_every == 0:
                    eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network

                    val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
                    update = self._update_best_network(eval_net, val_score, epoch)
                    logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
                    logging_dict.update(val_score_dict)
                if not self.schedule_in_step and scheduler:
                    scheduler.step()

            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
                    torch.cuda.empty_cache()
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                    print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")

                    train_loader = reduce_batch_size(train_loader, ratio=0.5)
                    if val_loader is not None:
                        val_loader = reduce_batch_size(val_loader, ratio=0.5)

                    logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
                else:
                    raise

            training_log[epoch] = logging_dict
            if self.is_early_stopping(epoch):
                print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
                break
            if self.logging:
                if epoch % self.logging == 0:
                    print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
                else:
                    print(f'{epoch}...', end=' ')

            if self._at_time_limit(start_training_time):
                print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
                break

        if self.logging_file:
            os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
            with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
                # The file is opened in append mode, so end every record with
                # a newline; repeated runs then stay parseable line by line.
                f.write(json.dumps(training_log) + "\n")

        if self.use_ema and self.ema_net is not None:
            self._save_state_dict(self.ema_net.module)
        else:
            self._save_state_dict(network)
        print(f'[Trainer CallBack] 📢 Kết thúc training.\n')

        best_model, last_model = None, None
        eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
        if self.return_best:
            best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
            best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
        if self.return_last:
            last_model = eval_net.state_dict()
            last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}

        del network
        torch.cuda.empty_cache()
        gc.collect()
        return training_log, best_model, last_model

    def _time_str_to_seconds(self, time_str):
        """Convert "DD:HH:MM:SS" to a number of seconds."""
        days, hours, minutes, seconds = map(int, time_str.split(":"))
        return days * 86400 + hours * 3600 + minutes * 60 + seconds

    def _update_best_network(self, network, val_score, epoch):
        """Record `val_score` in the top-k list.

        Returns True only when this evaluation actually entered the list. On
        a tie the earlier entry is kept. The weights are copied to CPU only
        when the entry is kept.
        """
        topk = max(1, self.topk)
        entry = [val_score, epoch, None]
        # sorted() is stable and the new entry is last, so ties favour older entries.
        ranked = sorted(self.best_stage + [entry], reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
        kept = any(item is entry for item in ranked)
        if kept:
            entry[2] = {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}
        self.best_stage = ranked
        return kept

    def is_early_stopping(self, epoch):
        """True when the best score is at least `early_stopping` epochs old."""
        if self.best_stage[0][1] is None:
            return False
        if not self.early_stopping:
            return False
        return epoch - self.best_stage[0][1] >= self.early_stopping

    def _at_time_limit(self, start_training_time):
        """True once the wall-clock budget has been used up."""
        return time.time() - start_training_time >= self.training_time

    def _save_state_dict(self, network):
        """Write the tracked best checkpoints and/or the last weights to disk."""
        if self.topk <= 0:
            return

        ema_tag = "ema" if self.ema_net is not None else ""

        if self.save_best:
            for r in range(self.topk):
                os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)

            for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
                if state_dict is None:
                    # Sentinel entry: nothing was evaluated for this rank.
                    continue
                state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
                torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{ema_tag}.pth')
        if self.save_last:
            os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
            state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
            torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{ema_tag}.pth')

    def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn):
        """Run one training epoch; return (mean total loss, dict of mean losses)."""
        network.train()
        # autocast wants the device *type*, so "cuda:0" is reduced to "cuda".
        device_type = torch.device(self.device).type
        total_loss = 0
        total_loss_dict = {}
        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            with torch.autocast(device_type=device_type, dtype=torch.float16):
                loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn)

            # Accumulate detached values only, so the running sums do not
            # carry autograd history from every batch of the epoch.
            for k, v in loss_dict.items():
                total_loss_dict[k] = total_loss_dict.get(k, 0) + v.detach()

            self.grad_scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to real gradients.
            self.grad_scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)

            self.grad_scaler.step(optimizer)
            self.grad_scaler.update()
            if self.schedule_in_step and scheduler:
                scheduler.step()
            if self.use_ema and self.ema_net is not None:
                self.ema_net.update(network)
            total_loss += loss.detach()
        return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}

    def _eval_epoch(self, network, val_loader, eval_fn, id2label):
        """Evaluate on `val_loader`.

        Returns (mean score, dict of mean scores, list of concatenated extra
        output arrays). Scores are averaged over batches.
        """
        network.eval()
        total_score = 0.0
        total_score_dict = {}
        object_lists = None

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
                total_score += score

                for k, v in score_dict.items():
                    total_score_dict[k] = total_score_dict.get(k, 0) + v

                if objects:
                    if object_lists is None:
                        object_lists = [[] for _ in range(len(objects))]

                    for i, obj in enumerate(objects):
                        object_lists[i].append(obj.detach())

        if object_lists is not None:
            object_arrays = [
                torch.concat(obj_list, dim=0).cpu().numpy()
                for obj_list in object_lists
            ]
        else:
            object_arrays = []

        return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays

    def _cal_loss(self, network, batch, batch_idx, loss_fn):
        """Forward one training batch; return (total loss, full loss dict)."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        start_labels = batch['start_labels'].to(self.device)
        end_labels = batch['end_labels'].to(self.device)

        start_logits, end_logits = network(input_ids, attention_mask)

        loss_dict = loss_fn(
            start_logits, start_labels,
            end_logits, end_labels,
        )
        return loss_dict['total'], loss_dict

    def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
        """Score one validation batch; return (f1, score dict, extra outputs)."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        gold_entities = batch['gold_entities']

        B, _, _ = input_ids.shape

        start_logits, end_logits = network(input_ids, attention_mask)

        # Flatten the parts so token positions line up with the logits.
        pred_ids = extract_entities(input_ids.reshape(B, -1), start_logits, end_logits, id2label)
        # Tuples are hashable, which the set-based eval function needs.
        pred_ids = list_to_tuple(pred_ids)
        gold_ids = list_to_tuple(gold_entities)

        score_dict = eval_fn(pred_ids, gold_ids)
        return score_dict['f1'], score_dict, []
|
|
| |
| class PhoBERTSpanAligner: |
| def __init__(self, tokenizer, max_len): |
| self.tokenizer = tokenizer |
| self.max_len = max_len |
|
|
| |
| def extract_spans(self, sample): |
| entity_spans = [] |
| |
| for event in sample["entities"]: |
| entity_type = event["label"] |
| spans = [tuple(event["offset"])] |
| entity_spans.append({ |
| "spans": spans, |
| "label": entity_type |
| }) |
| |
| return entity_spans |
|
|
| |
| def build_word_offsets(self, text, words): |
| offsets = [] |
| pointer = 0 |
|
|
| for word in words: |
| start = text.find(word, pointer) |
| end = start + len(word) |
| offsets.append((start, end)) |
| pointer = end |
|
|
| return offsets |
|
|
| |
| def char_span_to_word_span(self, word_offsets, start, end): |
| start_word = None |
| end_word = None |
|
|
| for i, (w_start, w_end) in enumerate(word_offsets): |
| if w_start <= start < w_end: |
| start_word = i |
| if w_start < end <= w_end: |
| end_word = i |
|
|
| return start_word, end_word |
|
|
| |
| def word_to_subword_map(self, words): |
| mapping = [] |
| subword_index = 1 |
|
|
| for word in words: |
| sub_tokens = self.tokenizer.tokenize(word) |
| start = subword_index |
| end = subword_index + len(sub_tokens) - 1 |
| mapping.append((start, end)) |
| subword_index += len(sub_tokens) |
|
|
| return mapping |
|
|
| |
| def span_to_subword(self, word_offsets, word_subword_map, spans): |
| sub_spans = [] |
|
|
| for span_start, span_end in spans: |
| w_start, w_end = self.char_span_to_word_span( |
| word_offsets, span_start, span_end |
| ) |
| if w_start is None or w_end is None: |
| continue |
|
|
| sub_start = word_subword_map[w_start][0] |
| sub_end = word_subword_map[w_end][1] |
| sub_spans.append((sub_start, sub_end)) |
|
|
| return sub_spans |
|
|
| def extract_valid_spans(self, sub_spans): |
| valid_spans = [] |
| for s, e in sub_spans: |
| if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e: |
| continue |
| valid_spans.append((s, e)) |
| return valid_spans |
|
|
| def encode(self, sample): |
| text = sample["text"] |
| entities = self.extract_spans(sample) |
| |
| |
| words = word_tokenize(text) |
| sentence = " ".join(words) |
| |
| |
| word_offsets = self.build_word_offsets(text, words) |
| word_subword_map = self.word_to_subword_map(words) |
| |
| |
| encoding = self.tokenizer( |
| sentence, |
| max_length=self.max_len, |
| truncation=True, |
| padding="max_length", |
| return_tensors="pt" |
| ) |
| input_ids = encoding["input_ids"][0] |
| attention_mask = encoding["attention_mask"][0] |
| |
| |
| entities_gold_spans = [] |
| |
| for ent in entities: |
| label = ent["label"] |
| |
| sub_spans = self.span_to_subword( |
| word_offsets, |
| word_subword_map, |
| ent["spans"] |
| ) |
| valid_spans = self.extract_valid_spans(sub_spans) |
| if len(valid_spans) == 0: |
| continue |
| entities_gold_spans.append((tuple(valid_spans), label)) |
| |
| return { |
| "input_ids": input_ids, |
| "attention_mask": attention_mask, |
| "entities_gold_spans": entities_gold_spans, |
| } |
|
|
def generate_candidate_spans(seq_len, max_span_len):
    """List all (start, end) spans over positions 1..seq_len.

    Both ends are inclusive and a span covers at most `max_span_len`
    positions. Spans are ordered by start, then by end.
    """
    last = seq_len + 1
    return [
        (start, stop)
        for start in range(1, last)
        for stop in range(start, min(start + max_span_len, last))
    ]
| |
class KLTNDataset(Dataset):
    """Dataset producing token ids plus per-token start/end label targets.

    Args:
        all_data: indexable collection of raw samples.
        using_idxes: indices into `all_data` that belong to this split.
        label2id: maps an entity label to its integer id.
        tokenizer: tokenizer handed to the span aligner.
        max_len: token length of one part.
        max_n_parts: number of parts per sample; the aligner works on
            `max_len * max_n_parts` tokens.
    """

    def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
        super().__init__()
        self.all_data = all_data
        self.using_idxes = using_idxes
        self.label2id = label2id
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_n_parts = max_n_parts
        self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)

    def __len__(self):
        return len(self.using_idxes)

    def __getitem__(self, idx):
        """Encode one sample.

        Returns a dict with
          "input_ids", "attention_mask": (n_valid_parts, max_len),
          "start_labels", "end_labels": (n_valid_parts * max_len,) holding the
              label id at entity start / end positions, -100 on padding and
              0 elsewhere,
          "gold_entities": list of (tuple of token ids, label).
        """
        sample = self.all_data[self.using_idxes[idx]]
        encoded = self.aligner.encode(sample)

        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # Padding positions get -100, every real token starts at 0.
        padding = 1 - attention_mask
        start_labels = torch.ones_like(input_ids) * padding * (-100)
        end_labels = start_labels.clone()

        gold_entities = []
        for spans, label in encoded["entities_gold_spans"]:
            # Only the first span of an entity is used.
            first, last = spans[0]
            label_id = self.label2id[f'{label}']

            start_labels[first] = label_id
            end_labels[last] = label_id
            gold_entities.append((tuple(input_ids[first:last+1].tolist()), label))

        input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
        attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)

        # Drop trailing parts that contain no real tokens.
        n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
        kept_tokens = n_valid_parts*self.max_len

        return {
            "input_ids": input_ids[:n_valid_parts],
            "attention_mask": attention_mask[:n_valid_parts],
            "start_labels": start_labels[:kept_tokens],
            "end_labels": end_labels[:kept_tokens],
            "gold_entities": gold_entities,
        }
|
|
| def _pad_batch(tensor_list, pad_value=0): |
| """ |
| tensor_list: list of tensors |
| mỗi tensor shape: (Nk, n_parts_i, max_len_i) |
| |
| return: |
| padded tensor shape: (B, max_Nk, max_n_parts, max_len) |
| """ |
|
|
| |
| max_Nk = max(t.size(0) for t in tensor_list) |
| max_n_parts = max(t.size(1) for t in tensor_list) |
| max_len = max(t.size(2) for t in tensor_list) |
|
|
| padded = [] |
|
|
| for t in tensor_list: |
| Nk, n_parts_i, max_len_i = t.shape |
|
|
| |
| if n_parts_i < max_n_parts or max_len_i < max_len: |
| new_t = t.new_full( |
| (Nk, max_n_parts, max_len), |
| pad_value |
| ) |
| new_t[:, :n_parts_i, :max_len_i] = t |
| t = new_t |
|
|
| |
| if Nk < max_Nk: |
| pad_tensor = t.new_full( |
| (max_Nk - Nk, max_n_parts, max_len), |
| pad_value |
| ) |
| t = torch.cat([t, pad_tensor], dim=0) |
|
|
| padded.append(t) |
|
|
| return torch.stack(padded) |
|
|
def collate_fn(batch):
    """Collate dataset items into one padded batch.

    ``input_ids`` / ``attention_mask`` are padded with 0 and the label tensors
    with -100. Gold entities are flattened into ``[sample_index, entity]``
    pairs, where ``sample_index`` is the position inside this batch.

    Assumes 2-D ``input_ids`` / ``attention_mask`` and 1-D label tensors as
    produced by ``KLTNDataset`` (the padding helper requires 3-D inputs, so
    trailing singleton axes are added and removed again).

    Args:
        batch: list of dicts with keys ``input_ids``, ``attention_mask``,
            ``start_labels``, ``end_labels`` and ``gold_entities``.

    Returns:
        dict with the same keys, tensors stacked along a new batch axis.
    """
    gold_entities = [
        [sample_idx, entity]
        for sample_idx, sample in enumerate(batch)
        for entity in sample['gold_entities']
    ]

    def _collate(key, extra_axes, pad_value):
        # Lift every tensor to 3-D by appending singleton axes, pad + stack,
        # then drop the helper axes again.
        lifted = []
        for sample in batch:
            tensor = sample[key]
            for _ in range(extra_axes):
                tensor = tensor.unsqueeze(-1)
            lifted.append(tensor)
        stacked = _pad_batch(lifted, pad_value=pad_value)
        for _ in range(extra_axes):
            stacked = stacked.squeeze(-1)
        return stacked

    return {
        "input_ids": _collate("input_ids", 1, 0),
        "attention_mask": _collate("attention_mask", 1, 0),
        "start_labels": _collate("start_labels", 2, -100),
        "end_labels": _collate("end_labels", 2, -100),
        "gold_entities": gold_entities,
    }
|
|
| |
def shift_bidx(spans, batch_idx, samples_per_batch=None):
    """Turn batch-local sample indices into dataset-wide indices.

    Args:
        spans: iterable of ``(bidx, entity)`` pairs, where ``bidx`` is the
            position of the sample inside its batch.
        batch_idx: zero-based index of the batch the pairs come from.
        samples_per_batch: number of samples in every preceding batch. When
            omitted, the module-level ``batch_size`` is used (the original
            behaviour); pass it explicitly when the loader was built with a
            different batch size.

    Returns:
        List of ``(bidx + batch_idx * samples_per_batch, entity)`` tuples.
    """
    if samples_per_batch is None:
        # Backward-compatible default: fall back to the global setting.
        samples_per_batch = batch_size
    offset = batch_idx * samples_per_batch
    return [(bidx + offset, ent) for bidx, ent in spans]
|
|
def refactor_entities(entities, save_dict):
    """De-duplicate entities and append them to the evaluation buckets.

    Two views of every entity are collected, each keeping first-seen order:
      * ``'Ent-I'``: ``(bidx, ids)``        -- span only
      * ``'Ent-C'``: ``(bidx, (ids, lb))``  -- span together with its label

    De-duplication is done within this call only; items already present in
    ``save_dict`` are not checked.

    Args:
        entities: iterable of ``(bidx, (ids, lb))``.
        save_dict: dict with list values under ``'Ent-I'`` and ``'Ent-C'``;
            extended in place.
    """
    span_only = []
    span_with_label = []

    for bidx, (ids, lb) in entities:
        identification = (bidx, ids)
        if identification not in span_only:
            span_only.append(identification)

        classification = (bidx, (ids, lb))
        if classification not in span_with_label:
            span_with_label.append(classification)

    save_dict['Ent-I'].extend(span_only)
    save_dict['Ent-C'].extend(span_with_label)
|
|
def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
    """Evaluate an ensemble of checkpoints on ``test_loader``.

    For every batch each state dict is run through ``network`` and the
    start/end logits are averaged over checkpoints before entities are
    extracted with ``extract_entities``.

    Args:
        network: model instance the checkpoints are loaded into.
        state_dicts: non-empty sequence of state dicts to ensemble.
        test_loader: loader yielding batches with ``input_ids``,
            ``attention_mask`` and ``gold_entities``.
        eval_fn: callable ``(pred, gold) -> score`` used per evaluation type.
        analyzer: object exposing ``analyze(pred, gold)``.
        device: device the model and inputs are moved to.
        id2label: mapping passed through to ``extract_entities``.
        tokenizer: tokenizer used to decode token ids back to text.

    Returns:
        ``(final_score, analyze_result, predictions)`` where ``final_score``
        maps ``'Ent-I'`` / ``'Ent-C'`` to the ``eval_fn`` result and
        ``predictions[i]`` is ``[decoded_text, (decoded_entity, label), ...]``
        for the i-th sample seen.

    Raises:
        ValueError: if ``state_dicts`` is empty.
    """
    if not state_dicts:
        raise ValueError("test() needs at least one state dict to evaluate")

    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        network = DataParallelProxy(network)
    network = network.to(device)
    network.eval()

    # With the multi-GPU wrapper the weights live on the wrapped module.
    load_target = network.module if multi_gpu else network

    # A single checkpoint only has to be loaded once, not once per batch.
    reload_per_batch = len(state_dicts) > 1
    if not reload_per_batch:
        load_target.load_state_dict(state_dicts[0])

    eval_types = ['Ent-I', 'Ent-C']

    all_pred = {eval_type: [] for eval_type in eval_types}
    all_gold = {eval_type: [] for eval_type in eval_types}

    list_input_ids = []

    # Running count of samples already processed. Used to turn batch-local
    # sample indices into global ones so they always line up with
    # ``list_input_ids``, whatever batch size the loader was built with.
    sample_offset = 0

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gold_entities = batch['gold_entities']

            B, _, _ = input_ids.shape
            flat_input_ids = input_ids.reshape(B, -1)
            list_input_ids.extend(flat_input_ids.tolist())

            list_start_logits = []
            list_end_logits = []
            for sd in state_dicts:
                if reload_per_batch:
                    load_target.load_state_dict(sd)

                start_logits, end_logits = network(input_ids, attention_mask)
                list_start_logits.append(start_logits)
                list_end_logits.append(end_logits)

            ensemble_start_logits = torch.stack(list_start_logits, dim=0).mean(dim=0)
            ensemble_end_logits = torch.stack(list_end_logits, dim=0).mean(dim=0)

            pred_entities = extract_entities(flat_input_ids, ensemble_start_logits, ensemble_end_logits, id2label)
            pred_entities = [(bidx + sample_offset, ent) for bidx, ent in pred_entities]
            refactor_entities(pred_entities, all_pred)

            gold_entities = [(bidx + sample_offset, ent) for bidx, ent in gold_entities]
            refactor_entities(gold_entities, all_gold)

            sample_offset += B

    final_score = {}
    for eval_type in eval_types:
        score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))
        final_score[eval_type] = score

    analyze_result = analyzer.analyze(list_to_tuple(all_pred['Ent-I']), list_to_tuple(all_gold['Ent-I']))

    # One entry per sample: decoded text first, then its predicted entities.
    predictions = []
    for ids_row in list_input_ids:
        predictions.append([tokenizer.decode(ids_row, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
    for bidx, (ids, lb) in all_pred['Ent-C']:
        predictions[bidx].append((tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True), lb))

    return final_score, analyze_result, predictions
|
|
| |
# Load the train and test splits from JSON.
with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
    data_train = json.load(f)

with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
    data_test = json.load(f)

print('Train:', len(data_train))
print('Test:', len(data_test))

# Label vocabulary: 'O' comes first, followed by every entity label found in
# either split, in sorted order.
entity_types = ['O'] + sorted(
    {e['label'] for d in data_train + data_test for e in d['entities']}
)

label2id = {label: index for index, label in enumerate(entity_types)}
id2label = {index: label for label, index in label2id.items()}
|
|
| |
# Down-sample training samples that contain no entities: keep at most
# ``zero_entities_rate`` times the number of samples that do have entities.
zero_entities_idxes = [
    idx for idx, d in enumerate(data_train) if len(d['entities']) == 0
]

n_zero_entities_samples = len(zero_entities_idxes)
n_has_entities_samples = len(data_train) - n_zero_entities_samples

# Fixed seed so the same subset is selected on every run.
random.seed(42)
k = min(int(n_has_entities_samples * zero_entities_rate), len(zero_entities_idxes))
sampled_zero_entities_idxes = random.sample(zero_entities_idxes, k)

# Set lookup keeps the filter linear instead of scanning the sampled list
# once per training sample.
kept_zero_entities_idxes = set(sampled_zero_entities_idxes)
new_data_train = [
    d
    for idx, d in enumerate(data_train)
    if len(d['entities']) != 0 or idx in kept_zero_entities_idxes
]
data_train = new_data_train

print('Train:', len(data_train))
|
|
| |
# Debug mode: shrink both splits to their first 10 samples.
if debug_only:
    data_train, data_test = data_train[:10], data_test[:10]

print('Train:', len(data_train))
print('Test:', len(data_test))

# Tokenizer matching the backbone encoder.
tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)

print('Experiment name:', state_dict_save_name)
|
|
| |
# Cross-validated training: one model per (seed, fold), optionally restricted
# to a single fold via ``only_fold_idx``.
if not test_only:
    full_idxes = np.arange(len(data_train))
    training_logs, best_models, last_models = [], [], []
    start_training_time = time.time()

    for seed in SEEDS:
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
        for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
            # A non-negative only_fold_idx selects exactly one fold to train.
            if (
                only_fold_idx is not None
                and only_fold_idx >= 0
                and only_fold_idx != fold_idx
            ):
                continue
            set_seed(seed)

            train_idxes = full_idxes[tr_idx]
            val_idxes = full_idxes[va_idx]

            trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
            valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)

            # Seeded generator shared by both loaders.
            generator = torch.Generator()
            generator.manual_seed(seed)
            train_loader = DataLoader(
                trainset, generator=generator, collate_fn=collate_fn, **train_loader_params
            )
            val_loader = DataLoader(
                valset, generator=generator, collate_fn=collate_fn, **val_loader_params
            )

            my_model = IEModel(num_labels=len(label2id), **model_params)
            total_params = sum(p.numel() for p in my_model.parameters())
            print(f"Total params: {total_params:,}")

            # Two parameter groups: the encoder uses lr=2e-5, every other
            # parameter falls back to the optimizer default lr=5e-4.
            encoder_params = {id(p) for p in my_model.encoder.parameters()}
            other_params = [
                p for p in my_model.parameters() if id(p) not in encoder_params
            ]
            param_groups = [
                {"params": my_model.encoder.parameters(), "lr": 2e-5},
                {"params": other_params},
            ]
            optimizer = optim.AdamW(param_groups, lr=5e-4)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)

            loss_fn = CustomLoss(**loss_func_params)
            eval_fn = CustomEvalFn(**eval_func_params)
            trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
            trainer = Trainer(**trainer_params)

            print(f'Start Training Fold {fold_idx}...')
            training_log, best_model, last_model = trainer.fit(
                my_model,
                optimizer,
                scheduler,
                loss_fn,
                epochs,
                train_loader,
                val_loader,
                eval_fn,
                start_epoch=1,
                start_training_time=start_training_time,
                id2label=id2label,
            )

            training_logs.append(training_log)
            best_models.append(best_model)
            last_models.append(last_model)
|
|
| |
def load_all_state_dicts(folder):
    """Load every checkpoint found in ``folder``, ordered by fold index.

    Only ``.pt`` / ``.pth`` files whose name contains a fold marker are
    loaded. The fold is read from the ``_s<seed>_f<fold>`` suffix used when
    checkpoints are saved; if that pattern is absent, the first ``f<digits>``
    in the name is used instead. Files are ordered by ``(fold, filename)``,
    so the result does not depend on directory listing order.

    Args:
        folder: directory containing the checkpoint files.

    Returns:
        List of objects returned by ``torch.load`` (mapped to CPU), one per
        matching file.
    """
    # Prefer the exact save-name pattern so that an "f<digits>" occurring
    # inside the experiment name is not mistaken for the fold index.
    seed_fold_pattern = re.compile(r"_s\d+_f(\d+)")
    loose_fold_pattern = re.compile(r"f(\d+)")

    files = []
    for file in os.listdir(folder):
        if not file.endswith((".pt", ".pth")):
            continue
        m = seed_fold_pattern.search(file) or loose_fold_pattern.search(file)
        if m:
            files.append((int(m.group(1)), file))

    # Tuple sort: fold first, filename as a deterministic tie-breaker.
    files.sort()

    state_dicts = []
    for fold, file in files:
        path = os.path.join(folder, file)
        print(f"Loading fold {fold}: {file}")
        # NOTE(review): torch.load deserializes pickled data; only point this
        # at checkpoints from a trusted source.
        state_dict = torch.load(path, map_location="cpu")
        state_dicts.append(state_dict)

    return state_dicts
|
|
# Test-only mode: fetch this experiment's checkpoints from the hub instead of
# using the models produced by the training loop.
if test_only:
    snapshot_download(
        repo_id=repo_name,
        local_dir="",
        repo_type="model",
        allow_patterns=[f"{state_dict_save_name}/**"],
    )
    get_ipython().system('rm -rf .cache .gitattributes')

    best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
    last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")

# Test data pipeline, scoring helpers and a fresh model to load weights into.
os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
testset = KLTNDataset(
    data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params
)
generator = torch.Generator()
test_loader = DataLoader(
    testset, generator=generator, collate_fn=collate_fn, **val_loader_params
)
eval_fn = CustomEvalFn(**eval_func_params)
analyzer = SpanErrorAnalyzer()
my_model = IEModel(num_labels=len(label2id), **model_params)
total_params = sum(p.numel() for p in my_model.parameters())
print(f"Total params: {total_params:,}")
|
|
| |
# Evaluate the best and the last checkpoints on the test set and persist
# scores, error analysis and decoded predictions as JSON.
start_time = time.time()
result_test = None
analyze_result = None

best_score, best_analyze_result, best_pred_test = test(
    my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer
)
last_score, last_analyze_result, last_pred_test = test(
    my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer
)

result_test = {"Best model": best_score, "Last model": last_score}
analyze_result = {
    "Best model": best_analyze_result,
    "Last model": last_analyze_result,
}
analyze_result_sumary = {
    "Best model": best_analyze_result['summary'],
    "Last model": last_analyze_result['summary'],
}
pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}

with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
    json.dump(result_test, f, ensure_ascii=False, indent=2)

with open(f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", "w", encoding="utf-8") as f:
    json.dump(analyze_result, f, ensure_ascii=False, indent=2)

with open(f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", "w", encoding="utf-8") as f:
    json.dump(pred_test, f, ensure_ascii=False, indent=2)

print('Test:', time.time() - start_time, 's --> Done!')
print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))

# Notebook display: first predictions of the best checkpoints.
best_pred_test[:10]

# Notebook display: first predictions of the last checkpoints.
last_pred_test[:10]
|
|
| |
def dict_to_df(data):
    """Convert nested evaluation results into a comparison table.

    Args:
        data: ``{model_name: {eval_key: {"precision", "recall", "f1"}}}``.
            The evaluation keys of the first model define the rows.

    Returns:
        DataFrame indexed by evaluation key (index name ``"evaluation"``) with
        ``(model_name, metric)`` MultiIndex columns. Rows are sorted by the
        f1 of ``"Best model"`` and then ``"Last model"`` (descending) when
        those columns are present.
    """
    metrics = ["precision", "recall", "f1"]

    # Rows follow the evaluation keys of the first model in the mapping.
    reference_scores = next(iter(data.values()))
    eval_keys = list(reference_scores.keys())

    # One flat record per evaluation key, keyed by (model, metric).
    records = [
        {
            (model_name, metric): model_scores[eval_key][metric]
            for model_name, model_scores in data.items()
            for metric in metrics
        }
        for eval_key in eval_keys
    ]

    table = pd.DataFrame(records)
    table.columns = pd.MultiIndex.from_tuples(table.columns)
    table.index = pd.Index(eval_keys, name="evaluation")

    # Sort on whichever of the two well-known f1 columns exist.
    f1_columns = [
        (model_name, "f1")
        for model_name in ("Best model", "Last model")
        if (model_name, "f1") in table.columns
    ]
    if f1_columns:
        table = table.sort_values(by=f1_columns, ascending=False)

    return table
|
|
# Full comparison table of the test scores, exported to Excel.
result_test_df = dict_to_df(result_test)
result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
result_test_df

# Per evaluation key, keep the row with the highest "Best model" f1.
key = ("Best model", "f1")
result_test_df_best = (
    result_test_df
    .sort_values(by=key, ascending=False)
    .groupby(level="evaluation")
    .head(1)
)
result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
result_test_df_best
|
|
| |
def get_avg_best_score(logs):
    """Average the ``best_score`` recorded in the last entry of each log.

    Args:
        logs: list of per-run logs, each a dict of entries; the last entry
            (in insertion order) of every log must hold ``'best_score'``.

    Returns:
        Mean of those values as a Python float.
    """
    final_scores = []
    for run_log in logs:
        last_entry = list(run_log.values())[-1]
        final_scores.append(last_entry['best_score'])
    return float(np.mean(final_scores))
| |
def get_avg_log(logs, epochs):
    """Average per-epoch training logs over several runs.

    Args:
        logs: list of run logs, each mapping an epoch (int, or its string
            form as produced by a JSON round-trip) to a dict that may hold
            ``'train_loss'`` and ``'val_score'``.
        epochs: highest epoch number to look for (inclusive, starting at 1).

    Returns:
        ``{epoch: {'train_loss': mean, 'val_score': mean}}`` for every epoch
        present in at least one run. A missing metric counts as 0.0 in the
        mean; when the summed ``val_score`` is exactly 0 it is reported as
        ``inf`` instead.
    """
    averaged = {}

    for epoch in range(1, epochs + 1):
        # Collect this epoch's entry from every run that has one.
        entries = []
        for run_log in logs:
            entry = run_log.get(epoch, run_log.get(str(epoch)))
            if entry is not None:
                entries.append(entry)

        if not entries:
            continue

        n_runs = len(entries)
        total_score = 0.0
        total_loss = 0.0
        for entry in entries:
            total_score += entry.get('val_score', 0.0)
            total_loss += entry.get('train_loss', 0.0)

        if total_score != 0:
            mean_score = total_score / n_runs
        else:
            mean_score = float('inf')

        averaged[epoch] = {
            'train_loss': total_loss / n_runs,
            'val_score': mean_score,
        }

    return averaged
|
|
def parse_label_key(label: str):
    """Build a numeric sort key from a label shaped like ``"<num>_..._<num>"``.

    Args:
        label: legend label; the part before the first underscore and the
            number after the last underscore are parsed as floats.

    Returns:
        ``(first, last)`` as floats, or ``(0, 0)`` when the label does not
        follow that shape (including non-string input).
    """
    try:
        first = float(label.split('_', 1)[0])
        last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
    except (AttributeError, IndexError, TypeError, ValueError):
        # Not a "<num>_..._<num>" label: fall back to a neutral key. Only
        # the parsing failures are handled, so interrupts still propagate.
        return (0, 0)
    return first, last
|
|
def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
    """Plot training loss and validation score curves for several experiments.

    Args:
        logs_dict: ``{experiment_name: {epoch: {'train_loss', 'val_score'}}}``.
        save_path: optional output image path; missing parent directories
            are created.
        figsize: figure size passed to ``plt.subplots``.

    The legend is attached to the loss panel only, placed to its right, with
    entries ordered by ``parse_label_key``. It is skipped when nothing was
    plotted.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # (axis, metric key, y-axis label, title) for the two panels.
    panels = (
        (axes[0], 'train_loss', 'Train Loss', 'Training Loss'),
        (axes[1], 'val_score', 'Validation Score', 'Validation Score'),
    )
    for ax, metric, ylabel, title in panels:
        for name, log in logs_dict.items():
            epochs = sorted(log.keys())
            ax.plot(epochs, [log[e][metric] for e in epochs], label=name)

        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True)

    handles, labels = axes[0].get_legend_handles_labels()
    # Guard: with no plotted lines, zip(*...) below would have nothing to
    # unpack and raise.
    if handles:
        pairs_sorted = sorted(
            zip(handles, labels),
            key=lambda pair: parse_label_key(pair[1]),
        )
        handles_sorted, labels_sorted = zip(*pairs_sorted)

        axes[0].legend(
            handles_sorted,
            labels_sorted,
            loc='center left',
            bbox_to_anchor=(1.01, 0.5),
            borderaxespad=0.
        )

    plt.tight_layout(rect=[0, 0, 1, 1])

    if save_path is not None:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
|
|
| |
# Download the JSON training logs of earlier experiments for comparison.
if not test_only:
    snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*.json"])
    get_ipython().system('rm -rf .cache .gitattributes')

# Average the per-fold logs of every experiment found under pretrained_dir,
# then add the logs of the run that was just trained.
if not test_only:
    experiments = {}
    for experiment in os.listdir(pretrained_dir):
        if '.virtual_documents' in experiment:
            continue
        experiment_logs = []
        try:
            for seed in SEEDS:
                for fold_idx in range(nfolds):
                    with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
                        experiment_log = json.load(f)
                    experiment_logs.append(experiment_log)
        except (OSError, ValueError):
            # Best effort: a missing or unreadable log file (OSError) or
            # malformed JSON (a ValueError subclass) just ends collection for
            # this experiment; logs read so far are still averaged.
            pass
        experiments[experiment] = get_avg_log(experiment_logs, 1000)
    experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
|
|
| |
# Mean of the best validation scores over all trained folds.
if not test_only:
    score = get_avg_best_score(training_logs)
    state_dict_save_name, score

# Compare the training curves of all collected experiments.
if not test_only:
    plot_training_logs(
        experiments,
        save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg',
        figsize=(18, 7.5),
    )
|
|
|
|