| |
| get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch]') |
|
|
| |
| import warnings |
| warnings.filterwarnings('ignore') |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, TensorDataset, DataLoader |
| import torch.nn.functional as F |
| import albumentations as albu |
| from transformers import AutoTokenizer, AutoModel |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from positional_encodings.torch_encodings import PositionalEncoding1D |
|
|
| from sklearn.metrics import f1_score |
| from sklearn.preprocessing import MinMaxScaler, StandardScaler |
| from scipy.spatial.transform import Rotation as R |
| from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold |
| from sklearn.metrics import precision_recall_fscore_support |
| from timm.utils import ModelEmaV3 |
| import timm |
|
|
| import os |
| import gc |
| import json |
| from pathlib import Path |
| import pickle |
| from tqdm.auto import tqdm |
| import copy |
| import numpy as np |
| import pandas as pd |
| import polars as pl |
| from PIL import Image |
| import time |
| from tqdm import tqdm |
| from matplotlib import pyplot as plt |
| import seaborn as sns |
| from multiprocessing import Manager as MemoryManager |
| from functools import lru_cache |
| import shutil |
| import glob |
| import cv2 |
| import random |
| import re |
| import joblib |
| import math |
| from huggingface_hub import HfApi, snapshot_download |
| import evaluate |
| from underthesea import word_tokenize as vi_tokenize_tool |
| import spacy |
| en_tokenize_tool = spacy.load("en_core_web_sm") |
| from collections import defaultdict, Counter |
|
|
| |
| |
# ---------------------------------------------------------------------------
# Run switches
# ---------------------------------------------------------------------------
SEEDS = [26092004]
topk = 1            # forwarded to the trainer as "topk"
nfolds = 5
only_fold_idx = 0
test_only = 0
debug_only = 0      # truthy -> shortened run (see `epochs` below)

# ---------------------------------------------------------------------------
# Data location (train and test currently point at the same directory)
# ---------------------------------------------------------------------------
dataset = 'kltn/only_actions'
root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}'
train_dir = f'{root_dir}'
test_dir = f'{root_dir}'

# ---------------------------------------------------------------------------
# Basic training setup
# ---------------------------------------------------------------------------
epochs = 18 if not debug_only else 2
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

repo_name = 'SS3M/kltn-experiments'
state_dict_save_name = "1_pointer_base_actions_4"
checkpoints_dir = state_dict_save_name
pretrained_dir = "/kaggle/working"
os.makedirs(f'{checkpoints_dir}', exist_ok=True)

# ---------------------------------------------------------------------------
# Dataset-dependent choices
# ---------------------------------------------------------------------------
backbone_model_name = "bert-base-uncased" if dataset == "casie" else "vinai/phobert-base"


def word_tokenize(text):
    """Split ``text`` into word tokens with the tokenizer matching ``dataset``.

    The spaCy pipeline ``en_tokenize_tool`` is used when ``dataset`` is
    ``"casie"``; every other dataset goes through ``vi_tokenize_tool``
    (underthesea). ``dataset`` is read at call time.
    """
    if dataset == "casie":
        return [token.text for token in en_tokenize_tool(text)]
    return vi_tokenize_tool(text)


max_len_dict = {
    'kltn/only_issues': 52,
    'kltn/only_actions': 69,
    'vhe': 51,
    'bkee': 62,
    'casie': 40,
}
zero_events_rate_dict = {
    'kltn/only_issues': 0,
    'kltn/only_actions': 0.2,
    'vhe': 1000,
    'bkee': 1000,
    'casie': 1000,
}

max_len = max_len_dict[dataset]
max_n_parts = 1
max_span_len = 14
zero_events_rate = zero_events_rate_dict[dataset]

# ---------------------------------------------------------------------------
# Keyword arguments for Trainer(...)
# ---------------------------------------------------------------------------
trainer_params = {
    "training_time": "00:11:30:00",   # "DD:HH:MM:SS" wall-clock budget
    "eval_mode": "max",
    "topk": topk,
    "save_name": state_dict_save_name,
    "save_best": True,
    "save_last": True,
    "device": device,
    "logging": True,
    "logging_file": True,
    "checkpoints_dir": checkpoints_dir,
    "early_stopping": 30,
    "eval_from_ratio": 0.4,
    "eval_every": 1,
    "schedule_in_step": False,
    "use_ema": True,
    "ema_from_ratio": 0.3,
    "ema_decay": 0.9995,
    "max_grad_norm": 200.0,
    "return_best": True,
    "return_last": True,
}

# ---------------------------------------------------------------------------
# Dataset ("memory") construction parameters
# ---------------------------------------------------------------------------
train_memory_params = {
    'max_len': max_len,
    'max_n_parts': max_n_parts,
}
val_memory_params = {
    'max_len': max_len,
    'max_n_parts': max_n_parts,
}
|
|
| |
def seed_worker(worker_id):
    """Seed ``random`` and NumPy inside a DataLoader worker process.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32.
    ``worker_id`` is accepted (DataLoader passes it) but not used.
    """
    seed = torch.initial_seed() % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
| |
# Keyword arguments for the training DataLoader.
train_loader_params = dict(
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
    drop_last=False,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)
# Keyword arguments for the validation DataLoader (no shuffling, one worker).
val_loader_params = dict(
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=1,
    drop_last=False,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)

# Keyword arguments for the model constructor.
model_params = dict(
    backbone_model_name=backbone_model_name,
)

# Keyword arguments for the loss and the evaluation function.
loss_func_params = dict(
    lambda_trg_ce=1.0,
    lambda_arg_ce=1.0,
)
eval_func_params = dict()

# Optimizer / LR-scheduler configuration; "name" selects the class and the
# remaining keys are passed to its constructor (see configure_optimizers).
optim_params = dict(
    name='AdamW',
    lr=1e-4,
    weight_decay=1e-4,
)
scheduler_params = dict(
    name='CosineAnnealingLR',
    T_max=20,
    eta_min=1e-6,
)
|
|
| |
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN.

    Also exports ``PYTHONHASHSEED``. Deterministic-algorithm enforcement is
    explicitly switched off, while cuDNN is set to deterministic mode with
    benchmarking disabled.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
| |
class CustomLoss(nn.Module):
    """Weighted sum of four cross-entropy terms (trigger/argument x start/end).

    Labels equal to ``-100`` are ignored. A term whose labels are *all*
    ignored contributes exactly zero instead of the NaN that a mean-reduced
    cross-entropy would produce.
    """

    def __init__(
        self,
        lambda_trg_ce=1.0,
        lambda_arg_ce=1.0,
    ):
        super().__init__()
        self.lambda_trg_ce = lambda_trg_ce
        self.lambda_arg_ce = lambda_arg_ce
        self.ce = nn.CrossEntropyLoss(ignore_index=-100)

    def _masked_ce(self, logits, labels):
        """Cross-entropy over all leading dims; zero when nothing is labelled."""
        # reshape (not view) so non-contiguous logits are accepted as well.
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_labels = labels.reshape(-1)

        if not (flat_labels != self.ce.ignore_index).any():
            # Keep the zero attached to the graph so backward() still works.
            return flat_logits.sum() * 0.0
        return self.ce(flat_logits, flat_labels)

    def forward(
        self,
        trg_start_logits, trg_start_labels,
        trg_end_logits, trg_end_labels,
        arg_start_logits, pred_arg_start_labels,
        arg_end_logits, pred_arg_end_labels,
    ):
        """Compute the individual and total losses.

        Each ``*_logits`` tensor has the class dimension last and its
        matching labels tensor has the same leading dimensions.

        Returns:
            dict with keys ``"total"``, ``"trg_start_loss"``,
            ``"trg_end_loss"``, ``"arg_start_loss"``, ``"arg_end_loss"``.
        """
        trg_start_loss = self._masked_ce(trg_start_logits, trg_start_labels)
        trg_end_loss = self._masked_ce(trg_end_logits, trg_end_labels)
        arg_start_loss = self._masked_ce(arg_start_logits, pred_arg_start_labels)
        arg_end_loss = self._masked_ce(arg_end_logits, pred_arg_end_labels)

        total_loss = (
            self.lambda_trg_ce * (trg_start_loss + trg_end_loss) +
            self.lambda_arg_ce * (arg_start_loss + arg_end_loss)
        )

        return {
            "total": total_loss,
            "trg_start_loss": trg_start_loss,
            "trg_end_loss": trg_end_loss,
            "arg_start_loss": arg_start_loss,
            "arg_end_loss": arg_end_loss,
        }
|
|
| |
| |
|
|
| |
class CustomEvalFn(nn.Module):
    """Set-based precision / recall / F1 between predicted and gold items."""

    def __init__(self):
        super().__init__()

    def compute_f1(self, tp, fp, fn):
        """Return ``(precision, recall, f1)``; 1e-8 guards each denominator."""
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        return precision, recall, f1

    def forward(self, pred, gold):
        """Score ``pred`` against ``gold``.

        Both arguments are iterables of hashable items; duplicates are
        collapsed because they are compared as sets.

        Returns:
            dict with keys ``"precision"``, ``"recall"`` and ``"f1"``.
        """
        pred_set = set(pred)
        gold_set = set(gold)

        tp = len(pred_set & gold_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)

        precision, recall, f1 = self.compute_f1(tp, fp, fn)

        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }
|
|
class SpanErrorAnalyzer:
    """Categorise predicted spans against gold spans, per sample.

    Spans are tuples of token ids; entries equal to ``pad_token_id`` are
    stripped before any comparison.
    """

    def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    def _to_set(self, data):
        """Group spans by sample.

        data: list of (b, tuple(ids))
        -> dict[b] = set(tuple(ids)), padding removed, empty spans dropped
        """
        grouped = defaultdict(set)
        for b, ids in data:
            ids = tuple(i for i in ids if i != self.pad_token_id)
            if ids:
                grouped[b].add(ids)
        return grouped

    def _iou(self, a, b):
        """Intersection-over-union of the token-id sets of ``a`` and ``b``."""
        set_a, set_b = set(a), set(b)
        union = len(set_a | set_b)
        if union == 0:
            return 0.0
        return len(set_a & set_b) / union

    def _boundary_error(self, pred, gold):
        """Measure boundary deviation via the matching prefix/suffix lengths."""
        left = 0
        for p_tok, g_tok in zip(pred, gold):
            if p_tok != g_tok:
                break
            left += 1

        right = 0
        for p_tok, g_tok in zip(reversed(pred), reversed(gold)):
            if p_tok != g_tok:
                break
            right += 1

        return {
            "left_match": left,
            "right_match": right,
            "pred_len": len(pred),
            "gold_len": len(gold),
        }

    @staticmethod
    def _rate(count, total):
        """``count / total``, or 0.0 when ``total`` is zero."""
        return count / total if total else 0.0

    def analyze(self, preds, golds):
        """Compare predictions with gold spans.

        Args:
            preds: list of ``(sample_id, ids)`` predicted spans.
            golds: list of ``(sample_id, ids)`` gold spans.

        Returns:
            dict with ``"summary"`` (category -> ``(count, rate)``; prediction
            categories are divided by ``len(preds)``, miss categories by
            ``len(golds)``, and an empty list yields a rate of 0.0) and
            ``"details"`` (one record per non-exact prediction / missed gold).
        """
        pred_map = self._to_set(preds)
        gold_map = self._to_set(golds)

        stats = Counter()
        detailed_errors = []

        for b in set(pred_map) | set(gold_map):
            pset = pred_map.get(b, set())
            gset = gold_map.get(b, set())
            matched_gold = set()

            # Pass 1: classify every predicted span.
            for p in pset:
                if p in gset:
                    stats["exact_match"] += 1
                    matched_gold.add(p)
                    continue

                best_iou, best_g = 0, None
                for g in gset:
                    iou = self._iou(p, g)
                    if iou > best_iou:
                        best_iou, best_g = iou, g

                if best_iou > 0:
                    stats["partial_match"] += 1
                    detailed_errors.append({
                        "type": "boundary_error",
                        "batch": b,
                        "pred": p,
                        "gold": best_g,
                        "iou": best_iou,
                        **self._boundary_error(p, best_g),
                    })
                    continue

                # No overlap at all: distinguish "sample has no gold span"
                # from "sample has gold spans but this one hits none".
                err_type = "no_event_sample" if b not in gold_map else "completely_wrong"
                stats[err_type] += 1
                detailed_errors.append({
                    "type": err_type,
                    "batch": b,
                    "pred": p,
                })

            # Pass 2: gold spans that were not matched exactly.
            for g in gset:
                if g in matched_gold:
                    continue
                overlaps = any(self._iou(p, g) > 0 for p in pset)
                stats["miss_with_overlap" if overlaps else "miss"] += 1
                detailed_errors.append({
                    "type": "miss",
                    "batch": b,
                    "gold": g,
                })

        n_pred, n_gold = len(preds), len(golds)
        summary = {
            key: (stats[key], self._rate(stats[key], n_pred))
            for key in ("exact_match", "partial_match", "no_event_sample", "completely_wrong")
        }
        summary.update({
            key: (stats[key], self._rate(stats[key], n_gold))
            for key in ("miss", "miss_with_overlap")
        })

        return {
            "summary": summary,
            "details": detailed_errors,
        }
|
|
| |
| |
def fix_bio_ids_batch(label_ids):
    """Repair orphan "inside" tags in a batch of BIO label ids.

    Encoding used by the code: ``0`` is outside, odd ids begin a span, even
    non-zero ids continue one, and ``(tag - 1) // 2`` is the span type. An
    inside tag that follows an outside tag, starts the sequence, or follows a
    tag of another type is turned into the begin tag of its own type
    (``tag - 1``).

    label_ids: (B, L)
    return: (B, L) fixed copy; the input is not modified
    """
    if label_ids.dim() != 2:
        raise ValueError(f"label_ids must be 2-D, got shape {tuple(label_ids.shape)}")

    def span_type(tags):
        return torch.div(tags - 1, 2, rounding_mode="floor")

    # Tag to the left of every position (0 = "outside" for the first column).
    prev = torch.zeros_like(label_ids)
    prev[:, 1:] = label_ids[:, :-1]

    # A fix never changes a tag's type or whether it is zero, so comparing
    # against the unfixed left neighbour equals the sequential left-to-right
    # pass and the whole repair can be done in one vectorised step.
    is_inside = (label_ids != 0) & (label_ids % 2 == 0)
    orphan = is_inside & ((prev == 0) | (span_type(prev) != span_type(label_ids)))

    return label_ids - orphan.to(label_ids.dtype)
|
|
def extract_trigger_spans_batch_tensor(label_ids):
    """Collect (start, end) spans from BIO label ids.

    A span opens at an odd (begin) tag and extends over the following even
    tags of the same type ``(tag - 1) // 2``. Even tags without a preceding
    begin tag are skipped.

    label_ids: (B, L)
    return:
        spans_tensor: (B, N, 2)  # inclusive (s, e), padded with (0, 0);
                                 # N is the largest span count in the batch
    """
    num_rows, seq_len = label_ids.shape
    rows = [[label_ids[r, c].item() for c in range(seq_len)] for r in range(num_rows)]

    spans_per_row = []
    for tags in rows:
        found = []
        pos = 0
        while pos < seq_len:
            tag = tags[pos]
            if tag == 0 or tag % 2 != 1:
                pos += 1
                continue

            span_type = (tag - 1) // 2
            begin = pos
            pos += 1
            while pos < seq_len:
                nxt = tags[pos]
                if nxt == 0:
                    break
                if not (nxt % 2 == 0 and (nxt - 1) // 2 == span_type):
                    break
                pos += 1
            # pos now sits on the first token that does not continue the span.
            found.append((begin, pos - 1))
        spans_per_row.append(found)

    width = max((len(found) for found in spans_per_row), default=0)
    spans_tensor = torch.zeros((num_rows, width, 2), dtype=torch.long, device=label_ids.device)
    for row, found in enumerate(spans_per_row):
        for slot, (s, e) in enumerate(found):
            spans_tensor[row, slot, 0] = s
            spans_tensor[row, slot, 1] = e

    return spans_tensor
| |
def get_span_repr(hidden, spans):
    """Build span representations from boundary token states.

    hidden: (B, L, H) token states
    spans:  (B, K, 2) start/end token indices
    return: (B, K, 4H) = [h_start, h_end, h_start - h_end, h_start * h_end]
    """
    _, _, width = hidden.size()

    def pick(positions):
        # positions: (B, K) -> hidden states at those positions, (B, K, H)
        index = positions.unsqueeze(-1).expand(-1, -1, width)
        return torch.gather(hidden, 1, index)

    first = pick(spans[:, :, 0])
    last = pick(spans[:, :, 1])

    return torch.cat((first, last, first - last, first * last), dim=-1)
| |
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        layers = [
            nn.Linear(in_size, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, out_size),
        ]
        # Attribute name "model" is part of the state-dict key layout.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to the last dimension of ``x``."""
        out = self.model(x)
        return out
|
|
class IEModel(nn.Module):
    """Pointer-style event extractor on top of a pretrained encoder.

    Trigger start/end labels are classified per token. Every trigger span is
    then represented by its boundary states and, conditioned on that
    representation, argument start/end labels are classified per token.
    """

    def __init__(self, backbone_model_name, num_trg_labels, num_arg_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone_model_name)
        hidden_size = self.encoder.config.hidden_size

        self.trg_start_classifier = MLP(hidden_size, hidden_size, num_trg_labels)
        self.trg_end_classifier = MLP(hidden_size, hidden_size, num_trg_labels)

        # get_span_repr concatenates four H-sized pieces, hence 4 * H.
        self.trg_repr_proj = MLP(hidden_size*4, hidden_size, hidden_size)
        # Token state concatenated with the projected trigger repr -> 2 * H.
        self.arg_start_classifier = MLP(hidden_size*2, hidden_size, num_arg_labels)
        self.arg_end_classifier = MLP(hidden_size*2, hidden_size, num_arg_labels)

    def encode(self, input_ids, attention_mask):
        """Encode every part separately and concatenate along the sequence.

        input_ids / attention_mask: (B, n_parts, L)
        return: (B, n_parts * L, H)
        """
        B, n_parts, L = input_ids.shape
        input_ids = input_ids.view(-1, L)
        attention_mask = attention_mask.view(-1, L)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1)
        return hidden_states

    def get_trg_logits(self, hidden_states):
        """Per-token trigger start and end logits, each (B, L, num_trg_labels)."""
        trg_start_logits = self.trg_start_classifier(hidden_states)
        trg_end_logits = self.trg_end_classifier(hidden_states)
        return trg_start_logits, trg_end_logits

    def get_arg_logits(self, hidden_states, trg_repr):
        """Argument start/end logits for every (trigger, token) pair.

        hidden_states: (B, L, H); trg_repr: (B, N, H)
        return: two tensors of shape (B, N, L, num_arg_labels)
        """
        B, L, H = hidden_states.shape
        _, N, _ = trg_repr.shape

        hidden_expand = hidden_states.unsqueeze(1).expand(-1, N, -1, -1)
        trg_expand = trg_repr.unsqueeze(2).expand(-1, -1, L, -1)

        hidden_trg_repr = torch.cat([hidden_expand, trg_expand], dim=-1)
        arg_start_logits = self.arg_start_classifier(hidden_trg_repr)
        arg_end_logits = self.arg_end_classifier(hidden_trg_repr)

        return arg_start_logits, arg_end_logits

    def forward(self, input_ids, attention_mask, trg_spans=None):
        """Run the full model.

        Args:
            input_ids, attention_mask: (B, n_parts, L).
            trg_spans: optional (B, N, 2) trigger spans to condition on. When
                omitted, spans are decoded from the predicted trigger
                start/end labels.

        Returns:
            ``(trg_start_logits, trg_end_logits, arg_start_logits,
            arg_end_logits, trg_spans)``.
        """
        hidden_states = self.encode(input_ids, attention_mask)

        trg_start_logits, trg_end_logits = self.get_trg_logits(hidden_states)

        if trg_spans is None:
            # Pair predicted start/end labels into spans with the same
            # decoder that extract_arguments relies on.
            trg_spans = decode_spans_batch(
                torch.argmax(trg_start_logits, dim=-1),
                torch.argmax(trg_end_logits, dim=-1),
            )

        trg_repr = get_span_repr(hidden_states, trg_spans)

        trg_repr = self.trg_repr_proj(trg_repr)
        arg_start_logits, arg_end_logits = self.get_arg_logits(hidden_states, trg_repr)

        return trg_start_logits, trg_end_logits, arg_start_logits, arg_end_logits, trg_spans
|
|
def test():
    """Smoke-test IEModel: build it, run one random batch, print output shapes."""
    model = nn.DataParallel(IEModel(backbone_model_name, 7, 5)).to(device)
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")

    vocab_size = model.module.encoder.config.vocab_size

    bz = 32
    input_ids = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
    attention_mask = torch.ones(bz, 5, 10).to(device)
    trg_spans = torch.ones(bz, 3, 2, dtype=torch.long).to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask, trg_spans)

    if isinstance(outputs, tuple):
        print([out.shape for out in outputs])
    else:
        print(outputs.shape)


test()
|
|
| |
def configure_optimizers(network, optim_params, scheduler_params):
    """Instantiate an optimizer and, optionally, an LR scheduler from configs.

    Each config is a dict whose ``"name"`` selects the class; the remaining
    entries are passed as keyword arguments. A class is looked up among this
    module's globals first, then in ``torch.optim`` (optimizer) or
    ``torch.optim.lr_scheduler`` (scheduler). The input dicts are not
    modified.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is ``None`` when the
        scheduler class cannot be resolved or no scheduler arguments remain
        after removing ``"name"``.

    Raises:
        ValueError: if ``"name"`` is missing from either config, or the
            optimizer class cannot be resolved.
    """
    optim_params = copy.copy(optim_params)
    scheduler_params = copy.copy(scheduler_params)

    # Only the two pops can signal a missing config key; keeping the try this
    # narrow stops unrelated KeyErrors from being mislabelled.
    try:
        optim_name = optim_params.pop('name')
        scheduler_name = scheduler_params.pop('name')
    except KeyError as e:
        raise ValueError(f"Missing {e} in config!!") from e

    optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
    scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)

    if optimizer_cls is None:
        raise ValueError(f"Optimizer '{optim_name}' is not available!")

    optimizer = optimizer_cls(network.parameters(), **optim_params)

    scheduler = None
    if scheduler_params and scheduler_cls:
        scheduler = scheduler_cls(optimizer, **scheduler_params)

    return optimizer, scheduler
| |
def freeze(self, model):
    """Disable gradients for every parameter of ``model`` and set eval mode.

    ``self`` is accepted for signature compatibility and is not used.
    """
    for weight in model.parameters():
        weight.requires_grad = False
    model.eval()
|
|
def unfreeze(self, model):
    """Enable gradients for every parameter of ``model`` and set train mode.

    ``self`` is accepted for signature compatibility and is not used.
    """
    for weight in model.parameters():
        weight.requires_grad = True
    model.train()
|
|
def reduce_batch_size(loader, ratio=0.5):
    """Return a new DataLoader like ``loader`` with a smaller batch size.

    The new batch size is ``int(loader.batch_size * ratio)``, never below 1.
    A loader that draws from a ``RandomSampler`` is rebuilt with
    ``shuffle=True``; any other sampler object is reused as-is. All other
    settings are copied from ``loader``.
    """
    # Imported here because the file-level torch.utils.data import does not
    # include RandomSampler.
    from torch.utils.data import RandomSampler

    new_bs = max(1, int(loader.batch_size * ratio))

    shuffle = isinstance(loader.sampler, RandomSampler)

    new_loader = DataLoader(
        dataset=loader.dataset,
        batch_size=new_bs,
        shuffle=shuffle,
        # shuffle and sampler are mutually exclusive in DataLoader.
        sampler=None if shuffle else loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
        generator=loader.generator,
        # prefetch_factor may only be set when worker processes are used.
        prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
        persistent_workers=loader.persistent_workers,
        pin_memory_device=loader.pin_memory_device
    )

    return new_loader
|
|
def list_to_tuple(x):
    """Recursively convert lists (and nested tuples) into tuples.

    Anything that is neither a list nor a tuple is returned unchanged.
    """
    if not isinstance(x, (list, tuple)):
        return x
    return tuple(map(list_to_tuple, x))
|
|
def fmt(x):
    """Round floats to 5 decimals, recursing into dict values and list items.

    Values of any other type (tuples included) are returned unchanged.
    """
    if isinstance(x, dict):
        return {key: fmt(value) for key, value in x.items()}
    if isinstance(x, list):
        return list(map(fmt, x))
    return round(x, 5) if isinstance(x, float) else x
| |
class ModelEmaV3Proxy(ModelEmaV3):
    """ModelEmaV3 that forwards unknown attribute lookups to ``self.module``."""

    def __getattr__(self, name):
        # Normal nn.Module resolution first; fall back to the wrapped module
        # only when that fails.
        try:
            value = super().__getattr__(name)
        except AttributeError:
            value = getattr(self.module, name)
        return value
| |
class DataParallelProxy(nn.DataParallel):
    """DataParallel that also exposes the wrapped module's attributes.

    Non-callable attributes are returned directly. Callable attributes are
    wrapped so that calling them scatters the arguments across
    ``device_ids``, runs the method on one replica per chunk, and gathers the
    results on ``output_device``.
    """

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            attr = getattr(self.module, name)

        if callable(attr):
            def wrapper(*args, **kwargs):
                return self._parallel_apply_method(name, *args, **kwargs)
            return wrapper

        return attr

    def _parallel_apply_method(self, method_name, *inputs, **kwargs):
        """Run ``module.<method_name>`` data-parallel over ``device_ids``."""
        if not self.device_ids:
            return getattr(self.module, method_name)(*inputs, **kwargs)

        inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids)

        # A batch smaller than the device count yields fewer chunks than
        # devices; replicate only as many times as there are chunks, because
        # parallel_apply needs one callable per chunk.
        replicas = self.replicate(self.module, self.device_ids[:len(inputs_scattered)])

        outputs = self.parallel_apply(
            [getattr(replica, method_name) for replica in replicas],
            inputs_scattered,
            kwargs_scattered
        )

        return self.gather(outputs, self.output_device)
|
|
def map_arg_labels(all_arg_labels, trg_spans, pred_spans):
    """Re-index gold argument labels so they line up with predicted spans.

    all_arg_labels: (B, N, L)   labels per gold trigger span
    trg_spans:      (B, N, 2)   gold trigger spans
    pred_spans:     (B, M, 2)   predicted trigger spans

    return:
        pred_arg_labels: (B, M, L). A predicted span equal to a gold span
        receives that span's labels (the first one if several match). An
        unmatched span receives zeros, with -100 wherever the first gold
        row of the sample holds -100.
    """
    B, N, L = all_arg_labels.shape
    _, M, _ = pred_spans.shape
    device = all_arg_labels.device

    # same_span[b, m, n] is True when predicted span m equals gold span n.
    same_span = (pred_spans[:, :, None, :] == trg_spans[:, None, :, :]).all(dim=-1)
    matched = same_span.any(dim=2)
    gold_index = same_span.float().argmax(dim=2)

    picked = torch.gather(
        all_arg_labels,
        dim=1,
        index=gold_index[:, :, None].expand(-1, -1, L),
    )

    # Labels used for predicted spans without a gold counterpart.
    ignored_positions = (all_arg_labels[:, 0] == -100).unsqueeze(1)
    fallback = torch.zeros((B, M, L), dtype=torch.long, device=device)
    fallback = fallback.masked_fill(ignored_positions, -100)

    pred_arg_labels = torch.where(matched.unsqueeze(-1), picked, fallback)
    return pred_arg_labels.long()
|
|
def decode_spans(start_labels, end_labels):
    """Pair start labels with end labels into spans, scanning left to right.

    Each non-zero start label at position ``s`` is matched with the nearest
    position ``e >= s`` whose end label equals it and that has not already
    been taken by an earlier span. Starts without a partner are dropped.

    start_labels/end_labels: (L,)
    return: [(s, e, label_id)]
    """
    length = len(start_labels)
    taken_ends = set()
    spans = []

    for s in range(length):
        label = start_labels[s]
        if label == 0:
            continue

        for e in range(s, length):
            if e not in taken_ends and end_labels[e] == label:
                taken_ends.add(e)
                spans.append((s, e, label))
                break

    return spans
|
|
def decode_spans_batch(start_labels, end_labels):
    """Pair start/end labels into spans for every row of a batch.

    Within a row, each non-zero start label at position ``s`` is matched with
    the nearest position ``e >= s`` carrying the same end label that no
    earlier span of that row has used; unmatched starts are dropped.

    Args:
        start_labels: (B, L)
        end_labels:   (B, L)

    Returns:
        spans_tensor: (B, N, 2) long tensor on ``start_labels.device``

        N = the largest number of spans in the batch
        padding = (0, 0)
    """
    B, L = start_labels.shape

    # One bulk transfer per tensor instead of an .item() call per element.
    start_rows = start_labels.tolist()
    end_rows = end_labels.tolist()

    all_spans = []
    for bidx in range(B):
        starts, ends = start_rows[bidx], end_rows[bidx]
        used_end = set()
        spans = []

        for s in range(L):
            s_label = starts[s]
            if s_label == 0:
                continue

            for e in range(s, L):
                if e not in used_end and ends[e] == s_label:
                    used_end.add(e)
                    spans.append((s, e))
                    break

        all_spans.append(spans)

    max_n = max((len(spans) for spans in all_spans), default=0)

    spans_tensor = torch.zeros(
        (B, max_n, 2),
        dtype=torch.long,
        device=start_labels.device
    )

    for bidx, spans in enumerate(all_spans):
        if spans:
            spans_tensor[bidx, :len(spans)] = torch.tensor(
                spans, dtype=torch.long, device=start_labels.device
            )

    return spans_tensor
|
|
def extract_arguments(
    input_ids,
    trg_start_logits,
    trg_end_logits,
    arg_start_logits,
    arg_end_logits,
    pred_trg_spans,
    id2label
):
    """Turn model outputs into (sample, trigger, argument) records.

    input_ids: (B, L)

    trg_start_logits: (B, L, C_trg)
    trg_end_logits:   (B, L, C_trg)

    arg_start_logits: (B, N, L, C_arg)
    arg_end_logits:   (B, N, L, C_arg)

    pred_trg_spans: (B, N, 2)

    id2label = {
        'Trg': {id: label},
        'Arg': {id: label}
    }

    Returns a list of
    ``(bidx, (trigger_token_ids, trigger_label), (arg_token_ids, arg_label))``.
    A trigger slot is skipped when it is the (0, 0) padding span or when its
    span is not among the spans decoded from the trigger start/end labels.
    """
    B, L = input_ids.shape

    trg_start_ids = torch.argmax(trg_start_logits, dim=-1)
    trg_end_ids = torch.argmax(trg_end_logits, dim=-1)

    # Typed trigger spans per sample, decoded from the argmax labels.
    decoded_triggers = [
        decode_spans(trg_start_ids[bidx].tolist(), trg_end_ids[bidx].tolist())
        for bidx in range(B)
    ]

    num_slots = pred_trg_spans.shape[1]
    results = []

    for bidx, triggers in enumerate(decoded_triggers):
        span2label = {
            (s, e): id2label['Trg'][t_id]
            for (s, e, t_id) in triggers
        }

        for n in range(num_slots):
            s_trg = pred_trg_spans[bidx, n, 0].item()
            e_trg = pred_trg_spans[bidx, n, 1].item()
            key = (s_trg, e_trg)

            # (0, 0) marks a padded slot.
            if key == (0, 0) or key not in span2label:
                continue

            trigger = (
                tuple(input_ids[bidx, s_trg:e_trg + 1].tolist()),
                span2label[key],
            )

            arg_spans = decode_spans(
                torch.argmax(arg_start_logits[bidx, n], dim=-1).tolist(),
                torch.argmax(arg_end_logits[bidx, n], dim=-1).tolist(),
            )

            for s_arg, e_arg, arg_label_id in arg_spans:
                arg_label = id2label['Arg'][arg_label_id]
                arg_tokens = input_ids[bidx, s_arg:e_arg + 1].tolist()
                results.append((bidx, trigger, (tuple(arg_tokens), arg_label)))

    return results
| |
| class Trainer: |
| def __init__( |
| self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0, |
| logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu', |
| schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True |
| ): |
| self.ema_net = None |
| |
| self.training_time = self._time_str_to_seconds(training_time) |
| self.mode = eval_mode |
| self.topk = topk |
| self.device = device |
| self.logging = logging if logging < epochs else 1 |
| self.logging_file = logging_file |
| self.checkpoints_dir = checkpoints_dir |
| self.early_stopping = early_stopping |
| self.eval_from_ratio = eval_from_ratio |
| self.eval_every = eval_every |
| self.save_name = save_name |
| self.save_best = save_best |
| self.save_last = save_last |
| self.return_best = return_best |
| self.return_last = return_last |
| self.max_grad_norm = max_grad_norm |
| self.schedule_in_step = schedule_in_step |
| self.use_ema = use_ema |
| self.ema_from_ratio = ema_from_ratio |
| self.ema_decay = ema_decay |
| |
| self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]] |
| self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0) |
|
|
| def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None): |
| if eval_fn is None: |
| if self.mode == "max": |
| eval_fn = lambda *x: -loss_fn(*x) |
| else: |
| eval_fn = lambda *x: loss_fn(*x) |
|
|
| if torch.cuda.device_count() > 1: |
| network = DataParallelProxy(network) |
| network = network.to(self.device) |
| |
| if not start_training_time: |
| start_training_time = time.time() |
| |
| start_ema = int(epochs * self.ema_from_ratio) |
| start_eval = int(epochs * self.eval_from_ratio) |
| |
| if val_loader is None: |
| print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!') |
| else: |
| model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc' |
| start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else '' |
| print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str) |
|
|
| training_log = {} |
| for epoch in range(start_epoch, epochs+start_epoch): |
| if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema: |
| self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device) |
| |
| try: |
| teaching_rate = math.cos(math.pi / 2 * epoch / epochs) |
| train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate) |
| logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch} |
| logging_dict.update(train_loss_epoch_dict) |
| |
| if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0: |
| eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network |
| |
| val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label) |
| update = self._update_best_network(eval_net, val_score, epoch) |
| logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update}) |
| logging_dict.update(val_score_dict) |
| if not self.schedule_in_step and scheduler: |
| scheduler.step() |
| |
| except RuntimeError as e: |
| if "out of memory" in str(e).lower(): |
| print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...") |
| torch.cuda.empty_cache() |
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.synchronize() |
| print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared") |
| |
| train_loader = reduce_batch_size(train_loader, ratio=0.5) |
| if val_loader is not None: |
| val_loader = reduce_batch_size(val_loader, ratio=0.5) |
| |
| logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')} |
| else: |
| raise |
| |
| training_log[epoch] = logging_dict |
| if self.is_early_stopping(epoch): |
| print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...') |
| break |
| if self.logging: |
| if epoch % self.logging == 0: |
| print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict)) |
| else: |
| print(f'{epoch}...', end=' ') |
| |
| if self._at_time_limit(start_training_time): |
| print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}') |
| break |
| |
| if self.logging_file: |
| os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True) |
| with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f: |
| f.write(json.dumps(training_log)) |
| |
| if self.use_ema and self.ema_net is not None: |
| self._save_state_dict(self.ema_net.module) |
| else: |
| self._save_state_dict(network) |
| print(f'[Trainer CallBack] 📢 Kết thúc training.\n') |
|
|
| best_model, last_model = None, None |
| eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network |
| if self.return_best : |
| best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict() |
| best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()} |
| if self.return_last: |
| last_model = eval_net.state_dict() |
| last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()} |
|
|
| del network |
| torch.cuda.empty_cache() |
| gc.collect() |
| return training_log, best_model, last_model |
|
|
| def _time_str_to_seconds(self, time_str): |
| days, hours, minutes, seconds = map(int, time_str.split(":")) |
| return days * 86400 + hours * 3600 + minutes * 60 + seconds |
|
|
| def _update_best_network(self, network, val_score, epoch): |
| topk = max(1, self.topk) |
| self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}]) |
| self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk] |
| if val_score in [x[0] for x in self.best_stage]: |
| return True |
| return False |
|
|
| def is_early_stopping(self, epoch): |
| if self.best_stage[0][1] is None: |
| return False |
| if not self.early_stopping: |
| return False |
| return epoch - self.best_stage[0][1] >= self.early_stopping |
|
|
| def _at_time_limit(self, start_training_time): |
| return time.time() - start_training_time >= self.training_time |
|
|
def _save_state_dict(self, network):
    """Write the retained best checkpoints and/or the last weights to disk.

    Nothing is written when ``self.topk <= 0``.  When ``self.save_best`` is
    set, one ``r{rank}s`` directory per rank is created under
    ``self.checkpoints_dir`` and every entry of ``self.best_stage`` that
    carries a state dict is saved there.  When ``self.save_last`` is set,
    the current ``network`` weights are saved under ``lasts``.

    Args:
        network: Module whose current ``state_dict()`` is stored as the
            "last" checkpoint.
    """
    if self.topk <= 0:
        return

    # File names carry an "ema" suffix whenever an EMA wrapper exists.
    ema_tag = "ema" if self.ema_net is not None else ""

    def _portable(state):
        # Drop the "module." prefix (as added by parallel wrappers) and
        # detach every tensor onto the CPU before serialising.
        return {key.replace("module.", ""): value.detach().cpu().clone() for key, value in state.items()}

    if self.save_best:
        for rank in range(1, self.topk + 1):
            os.makedirs(f'{self.checkpoints_dir}/r{rank}s', exist_ok=True)

        for rank, (score, _epoch, weights) in enumerate(self.best_stage, start=1):
            if weights is None:
                continue
            target = f'{self.checkpoints_dir}/r{rank}s/{self.save_name}_r{rank}_vs{score:.5f}_{ema_tag}.pth'
            torch.save(_portable(weights), target)

    if self.save_last:
        os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
        target = f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{ema_tag}.pth'
        torch.save(_portable(network.state_dict()), target)
| |
| def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate): |
| network.train() |
| total_loss = 0 |
| total_loss_dict = {} |
| for batch_idx, batch in enumerate(train_loader): |
| optimizer.zero_grad() |
| with torch.autocast(device_type=self.device, dtype=torch.float16): |
| loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate) |
|
|
| for k, v in loss_dict.items(): |
| t = total_loss_dict.get(k, 0) |
| total_loss_dict[k] = t + v |
| self.grad_scaler.scale(loss).backward() |
| self.grad_scaler.unscale_(optimizer) |
| grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm) |
| |
| self.grad_scaler.step(optimizer) |
| self.grad_scaler.update() |
| if self.schedule_in_step and scheduler: |
| scheduler.step() |
| if self.use_ema and self.ema_net is not None: |
| self.ema_net.update(network) |
| total_loss += loss |
| return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()} |
|
|
def _eval_epoch(self, network, val_loader, eval_fn, id2label):
    """Evaluate ``network`` on every batch of ``val_loader`` without gradients.

    Args:
        network: Model to evaluate; switched to eval mode.
        val_loader: Sized iterable of batches.
        eval_fn: Forwarded unchanged to ``self._cal_val_score``.
        id2label: Forwarded unchanged to ``self._cal_val_score``.

    Returns:
        ``(mean_score, mean_score_parts, object_arrays)`` where the first two
        are per-batch averages and ``object_arrays`` holds, for every
        position of the per-batch ``objects`` list, the batch tensors
        concatenated along dim 0 as NumPy arrays (empty list when no batch
        returned objects).
    """
    network.eval()
    score_sum = 0.0
    part_sums = {}
    collected = None

    with torch.no_grad():
        for step, batch in enumerate(val_loader):
            score, score_dict, objects = self._cal_val_score(network, batch, step, eval_fn, id2label)
            score_sum += score

            for name, value in score_dict.items():
                part_sums[name] = part_sums.get(name, 0) + value

            if not objects:
                continue
            if collected is None:
                # One bucket per returned object, sized from the first batch that has any.
                collected = [[] for _ in objects]
            for slot, obj in enumerate(objects):
                collected[slot].append(obj.detach())

    if collected is None:
        object_arrays = []
    else:
        object_arrays = [torch.concat(chunks, dim=0).cpu().numpy() for chunks in collected]

    n_batches = len(val_loader)
    mean_score = score_sum / n_batches
    mean_parts = {name: total / n_batches for name, total in part_sums.items()}
    return mean_score, mean_parts, object_arrays
|
|
def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate):
    """Compute the training loss for one batch.

    Trigger logits are produced first.  With probability ``teaching_rate``
    the gold trigger spans from the batch are used to build the trigger
    representation; otherwise spans decoded from the arg-max of the trigger
    logits are used.  Gold argument labels are re-mapped onto the chosen
    spans through ``map_arg_labels`` before the loss is evaluated.

    Args:
        network: Model exposing ``encode``, ``get_trg_logits``,
            ``trg_repr_proj`` and ``get_arg_logits``.
        batch: Mapping with the tensors read below.
        batch_idx: Index of the batch (unused here).
        loss_fn: Callable returning a dict that contains a ``'total'`` entry.
        teaching_rate: Probability of using gold trigger spans.

    Returns:
        ``(total_loss, loss_dict)``.
    """
    def on_device(key):
        return batch[key].to(self.device)

    input_ids = on_device('input_ids')
    attention_mask = on_device('attention_mask')
    gold_trg_spans = on_device('trg_spans')
    trg_start_labels = on_device('trg_start_labels')
    trg_end_labels = on_device('trg_end_labels')
    gold_arg_start_labels = on_device('all_arg_start_labels')
    gold_arg_end_labels = on_device('all_arg_end_labels')

    hidden_states = network.encode(input_ids, attention_mask)
    trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states)

    if random.random() < teaching_rate:
        # Teacher forcing: condition the argument head on the gold triggers.
        active_trg_spans = gold_trg_spans
    else:
        active_trg_spans = decode_spans_batch(
            torch.argmax(trg_start_logits, dim=-1),
            torch.argmax(trg_end_logits, dim=-1),
        )

    trg_repr = network.trg_repr_proj(get_span_repr(hidden_states, active_trg_spans))
    arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr)

    arg_start_targets = map_arg_labels(gold_arg_start_labels, gold_trg_spans, active_trg_spans)
    arg_end_targets = map_arg_labels(gold_arg_end_labels, gold_trg_spans, active_trg_spans)

    loss_dict = loss_fn(
        trg_start_logits, trg_start_labels,
        trg_end_logits, trg_end_labels,
        arg_start_logits, arg_start_targets,
        arg_end_logits, arg_end_targets,
    )
    return loss_dict['total'], loss_dict
|
|
def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
    """Score the model's predicted events for one validation batch.

    Trigger spans are decoded from the arg-max of the trigger logits, the
    argument head is conditioned on them, and the events extracted by
    ``extract_arguments`` are compared with ``batch['gold_events']``.

    Args:
        network: Model exposing ``encode``, ``get_trg_logits``,
            ``trg_repr_proj`` and ``get_arg_logits``.
        batch: Mapping with ``input_ids`` (3-D tensor), ``attention_mask``
            and ``gold_events``.
        batch_idx: Index of the batch (unused here).
        eval_fn: Callable returning a dict that contains an ``'f1'`` entry.
        id2label: Forwarded to ``extract_arguments``.

    Returns:
        ``(f1, score_dict, [])`` — the trailing list is always empty.
    """
    input_ids = batch['input_ids'].to(self.device)
    attention_mask = batch['attention_mask'].to(self.device)
    gold_events = batch['gold_events']

    n_samples, _, _ = input_ids.shape

    hidden_states = network.encode(input_ids, attention_mask)
    trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states)

    pred_trg_spans = decode_spans_batch(
        torch.argmax(trg_start_logits, dim=-1),
        torch.argmax(trg_end_logits, dim=-1),
    )
    trg_repr = network.trg_repr_proj(get_span_repr(hidden_states, pred_trg_spans))
    arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr)

    # The extractor receives one flat token row per sample.
    flat_ids = input_ids.reshape(n_samples, -1)
    predicted = extract_arguments(
        flat_ids,
        trg_start_logits, trg_end_logits,
        arg_start_logits, arg_end_logits,
        pred_trg_spans, id2label,
    )
    pred_ids = list_to_tuple(predicted)
    gold_ids = list_to_tuple(gold_events)

    score_dict = eval_fn(pred_ids, gold_ids)
    return score_dict['f1'], score_dict, []
|
|
| |
class PhoBERTSpanAligner:
    """Map character-offset event annotations onto subword token positions.

    The text is split into words, every word is tokenised on its own to learn
    how many subwords it occupies, and character spans are translated to
    word indices and then to subword indices.  Subword indices start at 1,
    leaving position 0 free (presumably for the tokenizer's leading special
    token — verify against the tokenizer in use).
    """

    def __init__(self, tokenizer, max_len):
        """Store the tokenizer and the maximum encoded sequence length."""
        self.tokenizer = tokenizer
        self.max_len = max_len

    def extract_spans(self, sample):
        """Collect trigger and argument character spans from ``sample["events"]``.

        Returns:
            ``(trigger_spans, arg_spans)``: one ``{"spans", "label"}`` dict per
            event, and for every event the list of such dicts for its arguments.
        """
        trigger_spans, arg_spans = [], []
        for event in sample["events"]:
            trigger_spans.append({
                "spans": [tuple(event["offset"])],
                "label": event["label"],
            })
            arg_spans.append([
                {"spans": [tuple(arg["offset"])], "label": arg["role"]}
                for arg in event['arguments']
            ])
        return trigger_spans, arg_spans

    def build_word_offsets(self, text, words):
        """Locate every word of ``words`` in ``text``, scanning left to right.

        Returns:
            One ``(start, end)`` character range per word.  A word that cannot
            be found after the previous match gets ``(-1, -1)``, a range that
            never contains any character position.
        """
        offsets = []
        pointer = 0
        for word in words:
            start = text.find(word, pointer)
            if start < 0:
                # ``str.find`` signals "not found" with -1.  Recording
                # (-1, len(word) - 1) and rewinding the scan pointer would
                # corrupt the offsets of every following word, so emit a
                # sentinel and keep scanning from the same position.
                offsets.append((-1, -1))
                continue
            end = start + len(word)
            offsets.append((start, end))
            pointer = end
        return offsets

    def char_span_to_word_span(self, word_offsets, start, end):
        """Return indices of the words containing ``start`` and ``end``.

        ``end`` is exclusive.  Either index is ``None`` when no word covers
        the position; when several words match, the last one is returned.
        """
        start_word = None
        end_word = None
        for i, (w_start, w_end) in enumerate(word_offsets):
            if w_start <= start < w_end:
                start_word = i
            if w_start < end <= w_end:
                end_word = i
        return start_word, end_word

    def word_to_subword_map(self, words):
        """Return the inclusive ``(first, last)`` subword index of every word."""
        mapping = []
        subword_index = 1  # position 0 is left unused
        for word in words:
            n_sub = len(self.tokenizer.tokenize(word))
            mapping.append((subword_index, subword_index + n_sub - 1))
            subword_index += n_sub
        return mapping

    def span_to_subword(self, word_offsets, word_subword_map, spans):
        """Translate character spans into inclusive subword spans.

        Spans whose start or end does not fall inside any word are dropped.
        """
        sub_spans = []
        for span_start, span_end in spans:
            w_start, w_end = self.char_span_to_word_span(word_offsets, span_start, span_end)
            if w_start is None or w_end is None:
                continue
            sub_spans.append((word_subword_map[w_start][0], word_subword_map[w_end][1]))
        return sub_spans

    def extract_valid_spans(self, sub_spans):
        """Keep the spans that are ordered and lie within ``[0, max_len)``."""
        valid_spans = []
        for s, e in sub_spans:
            if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
                continue
            valid_spans.append((s, e))
        return valid_spans

    def encode(self, sample):
        """Tokenise ``sample["text"]`` and align its events with the tokens.

        Events whose trigger has no valid subword span are skipped together
        with their arguments, so both returned span lists stay parallel.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (1-D tensors padded
            or truncated to ``max_len``), ``triggers_gold_spans`` (list of
            ``(spans, label)``) and ``arguments_gold_spans`` (per kept trigger,
            a tuple of ``(spans, label)``).
        """
        text = sample["text"]
        triggers, arguments = self.extract_spans(sample)

        words = word_tokenize(text)
        sentence = " ".join(words)

        word_offsets = self.build_word_offsets(text, words)
        word_subword_map = self.word_to_subword_map(words)

        encoding = self.tokenizer(
            sentence,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"][0]
        attention_mask = encoding["attention_mask"][0]

        triggers_gold_spans = []
        arguments_gold_spans = []

        for trg, args in zip(triggers, arguments):
            trg_valid = self.extract_valid_spans(
                self.span_to_subword(word_offsets, word_subword_map, trg["spans"])
            )
            if len(trg_valid) == 0:
                continue
            triggers_gold_spans.append((tuple(trg_valid), trg["label"]))

            trg_args_gold_spans = []
            for arg in args:
                arg_valid = self.extract_valid_spans(
                    self.span_to_subword(word_offsets, word_subword_map, arg["spans"])
                )
                if len(arg_valid) == 0:
                    continue
                trg_args_gold_spans.append((tuple(arg_valid), arg["label"]))
            arguments_gold_spans.append(tuple(trg_args_gold_spans))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "triggers_gold_spans": triggers_gold_spans,
            "arguments_gold_spans": arguments_gold_spans,
        }
|
|
def generate_candidate_spans(seq_len, max_span_len):
    """Enumerate candidate ``(start, end)`` spans over 1-based positions.

    Args:
        seq_len: Highest start position to consider.
        max_span_len: Maximum number of positions a span may cover.

    Returns:
        List of inclusive ``(start, end)`` pairs with ``1 <= start <= seq_len``
        and ``start <= end < min(start + max_span_len, seq_len + 1)``, ordered
        by start and then by end.
    """
    upper = seq_len + 1
    return [
        (start, end)
        for start in range(1, upper)
        for end in range(start, min(start + max_span_len, upper))
    ]
| |
class KLTNDataset(Dataset):
    """Dataset producing token tensors and start/end label tensors for events.

    Every sample is encoded to ``max_len * max_n_parts`` tokens by a
    ``PhoBERTSpanAligner`` and then reshaped into ``max_n_parts`` rows of
    ``max_len`` tokens; rows beyond the attended length are dropped.
    """

    def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
        """
        Args:
            all_data: Indexable collection of raw samples.
            using_idxes: Indices into ``all_data`` that make up this dataset.
            label2id: ``{'Trg': {label: id}, 'Arg': {label: id}}``.
            tokenizer: Tokenizer handed to the aligner.
            max_len: Number of tokens per part.
            max_n_parts: Maximum number of parts per sample.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)
        self.all_data = all_data
        self.using_idxes = using_idxes
        self.label2id = label2id
        self.max_len = max_len
        self.max_n_parts = max_n_parts

    def __len__(self):
        return len(self.using_idxes)

    def __getitem__(self, idx):
        """Encode one sample.

        Returns:
            Dict with ``input_ids`` / ``attention_mask`` of shape
            ``(n_valid_parts, max_len)``, ``trg_spans`` ``(n_triggers, 2)``,
            flat trigger start/end labels of length
            ``n_valid_parts * max_len``, per-trigger argument start/end labels
            ``(n_triggers, n_valid_parts * max_len)`` and ``gold_events``
            (per trigger: ``[(trigger_token_ids, label), (arg_token_ids, role), ...]``).
            Label tensors hold 0 on attended positions without a label and
            -100 on padding.
        """
        ridx = self.using_idxes[idx]
        sample = self.all_data[ridx]
        result = self.aligner.encode(sample)

        input_ids = result["input_ids"].squeeze(0)
        attention_mask = result["attention_mask"].squeeze(0)
        triggers_gold_spans = result["triggers_gold_spans"]
        arguments_gold_spans = result["arguments_gold_spans"]

        if triggers_gold_spans:
            # Only the first span of every trigger is used.
            all_trg_spans = torch.tensor(
                [list(trg_spans[0]) for trg_spans, _ in triggers_gold_spans],
                dtype=torch.long,
            )
        else:
            all_trg_spans = torch.empty(0, 2, dtype=torch.long)

        # 0 where the mask is 1, -100 where it is 0.
        blank_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)

        gold_events = []
        trg_start_labels = blank_labels.clone()
        trg_end_labels = blank_labels.clone()
        all_arg_start_labels, all_arg_end_labels = [], []
        for (trg_spans, trg_label), args in zip(triggers_gold_spans, arguments_gold_spans):
            s, e = trg_spans[0]
            trg_id = self.label2id['Trg'][f'{trg_label}']
            trg_start_labels[s] = trg_id
            trg_end_labels[e] = trg_id

            event = [(tuple(input_ids[s:e+1].tolist()), trg_label)]

            arg_start_labels = blank_labels.clone()
            arg_end_labels = blank_labels.clone()
            for arg_spans, arg_label in args:
                s, e = arg_spans[0]
                arg_id = self.label2id['Arg'][f'{arg_label}']
                arg_start_labels[s] = arg_id
                arg_end_labels[e] = arg_id
                event.append((tuple(input_ids[s:e+1].tolist()), arg_label))
            all_arg_start_labels.append(arg_start_labels)
            all_arg_end_labels.append(arg_end_labels)

            gold_events.append(event)

        input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
        attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)

        n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
        kept_len = n_valid_parts*self.max_len
        input_ids = input_ids[:n_valid_parts]
        attention_mask = attention_mask[:n_valid_parts]
        trg_start_labels = trg_start_labels[:kept_len]
        trg_end_labels = trg_end_labels[:kept_len]

        def _stack_or_empty(label_rows):
            if label_rows:
                return torch.stack([row[:kept_len] for row in label_rows], dim=0)
            # Keep the integer label dtype for samples without events; a
            # default (float32) empty tensor would promote the whole padded
            # batch of labels to float when mixed with real label tensors.
            return torch.empty(0, kept_len, dtype=blank_labels.dtype)

        all_arg_start_labels = _stack_or_empty(all_arg_start_labels)
        all_arg_end_labels = _stack_or_empty(all_arg_end_labels)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "trg_spans": all_trg_spans,
            "trg_start_labels": trg_start_labels,
            "trg_end_labels": trg_end_labels,
            "all_arg_start_labels": all_arg_start_labels,
            "all_arg_end_labels": all_arg_end_labels,
            "gold_events": gold_events,
        }
|
|
def _pad_batch(tensor_list, pad_value=0):
    """Pad a list of 3-D tensors to a common shape and stack them.

    Args:
        tensor_list: Tensors of shape ``(Nk, n_parts_i, max_len_i)``; the
            three sizes may differ between tensors.
        pad_value: Fill value for every padded cell.

    Returns:
        Tensor of shape ``(B, max_Nk, max_n_parts, max_len)`` where each size
        is the maximum over ``tensor_list``; original values sit in the
        leading corner of every item.
    """
    target_rows = max(t.size(0) for t in tensor_list)
    target_parts = max(t.size(1) for t in tensor_list)
    target_len = max(t.size(2) for t in tensor_list)

    padded = []
    for t in tensor_list:
        rows, parts, length = t.shape
        # Allocate the final shape once and copy the data into its corner.
        canvas = t.new_full((target_rows, target_parts, target_len), pad_value)
        canvas[:rows, :parts, :length] = t
        padded.append(canvas)

    return torch.stack(padded)
|
|
def collate_fn(batch):
    """Merge dataset items into one padded batch.

    Gold events are flattened to ``[sample_index, trigger, argument]``
    triples; an event without arguments contributes a single triple whose
    argument is the placeholder ``((), 0)``.  Every tensor field is padded to
    the largest shape in the batch with ``_pad_batch`` (0 for ids, masks and
    spans, -100 for labels).

    Args:
        batch: List of dicts as produced by the dataset's ``__getitem__``.

    Returns:
        Dict with the same keys, holding batched tensors and the flattened
        ``gold_events`` list.
    """
    gold_events = []
    for sample_idx, item in enumerate(batch):
        for event in item['gold_events']:
            trigger, arguments = event[0], event[1:]
            if arguments:
                gold_events.extend([sample_idx, trigger, argument] for argument in arguments)
            else:
                gold_events.append([sample_idx, trigger, (tuple([]), 0)])

    def _batched(key, extra_dims, pad_value):
        # _pad_batch expects 3-D inputs: add trailing singleton axes, pad,
        # then remove the same axes again.
        expanded = []
        for item in batch:
            tensor = item[key]
            for _ in range(extra_dims):
                tensor = tensor.unsqueeze(-1)
            expanded.append(tensor)
        merged = _pad_batch(expanded, pad_value=pad_value)
        for _ in range(extra_dims):
            merged = merged.squeeze(-1)
        return merged

    return {
        "input_ids": _batched("input_ids", 1, 0),
        "attention_mask": _batched("attention_mask", 1, 0),
        "trg_spans": _batched("trg_spans", 1, 0),
        "trg_start_labels": _batched("trg_start_labels", 2, -100),
        "trg_end_labels": _batched("trg_end_labels", 2, -100),
        "all_arg_start_labels": _batched("all_arg_start_labels", 1, -100),
        "all_arg_end_labels": _batched("all_arg_end_labels", 1, -100),
        "gold_events": gold_events,
    }
|
|
| |
def shift_bidx(spans, batch_idx):
    """Turn in-batch sample indices into dataset-wide indices.

    Args:
        spans: Iterable of ``(bidx, trg, arg)`` triples.
        batch_idx: Position of the batch within the loader.

    Returns:
        List of ``(bidx + batch_idx * batch_size, trg, arg)`` tuples, where
        ``batch_size`` is the module-level setting.
    """
    return [
        (bidx + batch_idx * batch_size, trg, arg)
        for bidx, trg, arg in spans
    ]
|
|
def refactor_events(events, save_dict):
    """Derive the six evaluation views from flat event triples.

    Args:
        events: Iterable of ``(bidx, (trg_ids, trg_label), (arg_ids, arg_label))``.
            All components must be hashable.
        save_dict: Dict of lists with the keys ``'Trg-I'``, ``'Trg-C'``,
            ``'Arg-I'``, ``'Arg-C'``, ``'Soft-Event'`` and ``'Strict-Event'``;
            it is extended in place.

    The first five views are de-duplicated in first-seen order:
        * ``Trg-I``: ``(bidx, trg_ids)``
        * ``Trg-C``: ``(bidx, (trg_ids, trg_label))``
        * ``Arg-I``: ``(bidx, trg_ids, arg_ids)``
        * ``Arg-C``: ``(bidx, trg_ids, (arg_ids, arg_label))``
        * ``Soft-Event``: ``(bidx, trigger, argument)``
    ``Strict-Event`` holds one ``(bidx, trigger, frozenset(arguments))`` per
    trigger, grouped by sample.
    """
    # Dicts act as insertion-ordered sets: O(1) membership instead of
    # scanning a list for every event.
    trg_i, trg_c, arg_i, arg_c, soft = {}, {}, {}, {}, {}
    strict_dict = {}
    for bidx, (trg_ids, trg_lb), (arg_k_ids, arg_k_lb) in events:
        trigger = (trg_ids, trg_lb)
        argument = (arg_k_ids, arg_k_lb)

        trg_i.setdefault((bidx, trg_ids))
        trg_c.setdefault((bidx, trigger))
        arg_i.setdefault((bidx, trg_ids, arg_k_ids))
        arg_c.setdefault((bidx, trg_ids, argument))
        soft.setdefault((bidx, trigger, argument))

        strict_dict.setdefault(bidx, {}).setdefault(trigger, []).append(argument)

    strict = [
        (bidx, trigger, frozenset(arguments))
        for bidx, per_trigger in strict_dict.items()
        for trigger, arguments in per_trigger.items()
    ]

    save_dict['Trg-I'].extend(trg_i)
    save_dict['Trg-C'].extend(trg_c)
    save_dict['Arg-I'].extend(arg_i)
    save_dict['Arg-C'].extend(arg_c)
    save_dict['Soft-Event'].extend(soft)
    save_dict['Strict-Event'].extend(strict)
|
|
def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
    """Evaluate an ensemble of checkpoints on ``test_loader``.

    For every batch the trigger logits of all checkpoints are averaged and
    decoded into trigger spans; each checkpoint then produces argument
    logits conditioned on those spans from its own hidden states, and the
    averaged argument logits are decoded into events.

    Args:
        network: Model instance that the state dicts are loaded into.
        state_dicts: Sequence of state dicts forming the ensemble.
        test_loader: Iterable of batches with ``input_ids`` (3-D),
            ``attention_mask`` and ``gold_events``.
        eval_fn: Callable ``(predictions, gold) -> score``.
        analyzer: Object exposing ``analyze(predictions, gold)``.
        device: Device the model and the inputs are moved to.
        id2label: Forwarded to ``extract_arguments``.
        tokenizer: Used to decode token ids back to text.

    Returns:
        ``(final_score, analyze_result, predictions)``: scores per evaluation
        type, the analyzer output for the ``Trg-I`` view, and per sample a
        list starting with the decoded text followed by ``(text, label)``
        pairs of the predicted triggers and arguments.
    """
    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        network = DataParallelProxy(network)
    network = network.to(device)
    network.eval()
    # With the proxy in place the weights live on the wrapped module.
    load_target = network.module if multi_gpu else network

    active_sd = None

    def _activate(sd):
        # Skip the reload when the requested weights are already in place, so
        # a single-checkpoint run loads its weights once, not twice per batch.
        nonlocal active_sd
        if active_sd is not sd:
            load_target.load_state_dict(sd)
            active_sd = sd

    eval_types = ['Trg-I', 'Trg-C', 'Arg-I', 'Arg-C', 'Soft-Event', 'Strict-Event']

    all_pred = {eval_type: [] for eval_type in eval_types}
    all_gold = {eval_type: [] for eval_type in eval_types}

    list_input_ids = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gold_events = batch['gold_events']

            B, _, _ = input_ids.shape
            # Global index of this batch's first sample.  Counting the samples
            # already seen keeps ``predictions[bidx]`` aligned whatever the
            # loader's batch size is.
            offset = len(list_input_ids)
            list_input_ids.extend(input_ids.reshape(B, -1).tolist())

            list_trg_start_logits = []
            list_trg_end_logits = []
            list_hidden_states = []
            list_arg_start_logits = []
            list_arg_end_logits = []

            for sd in state_dicts:
                _activate(sd)
                hidden_states = network.encode(input_ids, attention_mask)
                trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states)
                list_trg_start_logits.append(trg_start_logits)
                list_trg_end_logits.append(trg_end_logits)
                list_hidden_states.append(hidden_states)

            ensemble_trg_start_logits = torch.stack(list_trg_start_logits, dim=0).mean(dim=0)
            ensemble_trg_end_logits = torch.stack(list_trg_end_logits, dim=0).mean(dim=0)
            trg_start_ids = torch.argmax(ensemble_trg_start_logits, dim=-1)
            trg_end_ids = torch.argmax(ensemble_trg_end_logits, dim=-1)
            pred_trg_spans = decode_spans_batch(trg_start_ids, trg_end_ids)

            # Second pass: every checkpoint scores arguments for the shared
            # (ensembled) trigger spans using its own hidden states.
            for sd, hidden_states in zip(state_dicts, list_hidden_states):
                _activate(sd)
                trg_repr = get_span_repr(hidden_states, pred_trg_spans)
                trg_repr = network.trg_repr_proj(trg_repr)
                arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr)
                list_arg_start_logits.append(arg_start_logits)
                list_arg_end_logits.append(arg_end_logits)

            ensemble_arg_start_logits = torch.stack(list_arg_start_logits, dim=0).mean(dim=0)
            ensemble_arg_end_logits = torch.stack(list_arg_end_logits, dim=0).mean(dim=0)

            pred_events = extract_arguments(
                input_ids.reshape(B, -1),
                ensemble_trg_start_logits, ensemble_trg_end_logits,
                ensemble_arg_start_logits, ensemble_arg_end_logits,
                pred_trg_spans, id2label
            )
            pred_events = [(bidx + offset, trg, arg) for bidx, trg, arg in pred_events]
            refactor_events(pred_events, all_pred)

            gold_events = [(bidx + offset, trg, arg) for bidx, trg, arg in gold_events]
            refactor_events(gold_events, all_gold)

    final_score = {}
    for eval_type in eval_types:
        final_score[eval_type] = eval_fn(
            list_to_tuple(all_pred[eval_type]),
            list_to_tuple(all_gold[eval_type]),
        )

    analyze_result = analyzer.analyze(list_to_tuple(all_pred['Trg-I']), list_to_tuple(all_gold['Trg-I']))

    def _decode(token_ids):
        return tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    # One entry per sample: [text, (trigger, label), (argument, role), ...]
    predictions = [[_decode(token_ids)] for token_ids in list_input_ids]
    for bidx, (trg_ids, trg_lb), arg_infos in all_pred['Strict-Event']:
        predictions[bidx].append((_decode(trg_ids), trg_lb))
        for arg_ids, arg_lb in arg_infos:
            predictions[bidx].append((_decode(arg_ids), arg_lb))

    return final_score, analyze_result, predictions
|
|
| |
# Load the train and test annotations (JSON documents under train_dir / test_dir).
with open(f'{train_dir}/train.json', encoding="utf-8") as f:
    data_train = json.loads(f.read())

with open(f'{test_dir}/test.json', encoding="utf-8") as f:
    data_test = json.loads(f.read())

print('Train:', len(data_train))
print('Test:', len(data_test))
|
|
| |
# Label vocabularies are built from train and test together; 'O' always gets id 0.
trigger_types = ['O'] + sorted({
    event['label']
    for doc in data_train + data_test
    for event in doc['events']
})

trigger_label2id = {label: idx for idx, label in enumerate(trigger_types)}
trigger_id2label = {idx: label for label, idx in trigger_label2id.items()}

argument_types = ['O'] + sorted({
    argument['role']
    for doc in data_train + data_test
    for event in doc['events']
    for argument in event['arguments']
})

argument_label2id = {label: idx for idx, label in enumerate(argument_types)}
argument_id2label = {idx: label for label, idx in argument_label2id.items()}

label2id = dict(Trg=trigger_label2id, Arg=argument_label2id)

id2label = dict(Trg=trigger_id2label, Arg=argument_id2label)
|
|
| |
# Down-sample training documents that contain no event: keep at most
# `zero_events_rate` times as many of them as there are documents with events.
zero_events_idxes = []
for idx, d in enumerate(data_train):
    if len(d['events']) == 0:
        zero_events_idxes.append(idx)

n_zero_events_samples = len(zero_events_idxes)
n_has_events_samples = len(data_train) - n_zero_events_samples

# Fixed seed so the retained subset is the same on every run.
random.seed(42)
k = min(int(n_has_events_samples * zero_events_rate), len(zero_events_idxes))
sampled_zero_events_idxes = random.sample(zero_events_idxes, k)

# Set lookup keeps the filter linear; testing membership in the sampled list
# would cost O(k) per document.
_kept_zero_events = set(sampled_zero_events_idxes)
new_data_train = []
for idx, d in enumerate(data_train):
    if len(d['events']) > 0 or idx in _kept_zero_events:
        new_data_train.append(d)
data_train = new_data_train

print('Train:', len(data_train))
|
|
| |
# Debug runs work on the first 20 documents of each split only.
if debug_only:
    data_train, data_test = data_train[:20], data_test[:20]

print('Train:', len(data_train))
print('Test:', len(data_test))
|
|
| |
tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)

print('Experiment name:', state_dict_save_name)

# K-fold training: one model per (seed, fold); `only_fold_idx` restricts the
# run to a single fold when it is a non-negative integer.
if not test_only:
    full_idxes = np.arange(len(data_train))
    training_logs, best_models, last_models = [], [], []
    start_training_time = time.time()
    for seed in SEEDS:
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
        for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
            fold_selected = only_fold_idx is None or only_fold_idx < 0 or only_fold_idx == fold_idx
            if not fold_selected:
                continue
            set_seed(seed)

            train_idxes = full_idxes[tr_idx]
            val_idxes = full_idxes[va_idx]

            trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
            valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)

            # Both loaders share one seeded generator.
            generator = torch.Generator()
            generator.manual_seed(seed)
            train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
            val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)

            my_model = IEModel(
                num_trg_labels=len(trigger_label2id),
                num_arg_labels=len(argument_label2id),
                **model_params
            )
            total_params = sum(p.numel() for p in my_model.parameters())
            print(f"Total params: {total_params:,}")

            # Two parameter groups: the encoder trains with lr 2e-5, every
            # other parameter with the optimizer default of 5e-4.
            encoder_params = {id(p) for p in my_model.encoder.parameters()}
            other_params = [p for p in my_model.parameters() if id(p) not in encoder_params]
            param_groups = [
                {"params": my_model.encoder.parameters(), "lr": 2e-5},
                {"params": other_params},
            ]
            optimizer = optim.AdamW(param_groups, lr=5e-4)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)

            loss_fn = CustomLoss(**loss_func_params)
            eval_fn = CustomEvalFn(**eval_func_params)
            trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
            trainer = Trainer(**trainer_params)

            print(f'Start Training Fold {fold_idx}...')
            training_log, best_model, last_model = trainer.fit(
                my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
                start_epoch=1, start_training_time=start_training_time, id2label=id2label
            )

            training_logs.append(training_log)
            best_models.append(best_model)
            last_models.append(last_model)
|
|
| |
def load_all_state_dicts(folder):
    """Load every ``.pt`` / ``.pth`` checkpoint in ``folder``, ordered by fold.

    The fold number is read from the ``_s<seed>_f<fold>_`` part that precedes
    the ``r<rank>_vs`` or ``last_`` suffix of a checkpoint name; names without
    that part fall back to the first ``f<digits>`` occurrence.  Files without
    any fold number are ignored.

    Args:
        folder: Directory containing the checkpoint files.

    Returns:
        List of state dicts loaded onto the CPU, sorted by fold number and,
        within a fold, by file name.
    """
    # Anchoring on the seed/fold suffix avoids picking up an "f<digits>" that
    # happens to be part of the experiment name.
    strict_fold = re.compile(r"_s\d+_f(\d+)_(?:r\d+_vs|last_)")
    loose_fold = re.compile(r"f(\d+)")

    files = []
    for file in os.listdir(folder):
        if not file.endswith((".pt", ".pth")):
            continue
        m = strict_fold.search(file) or loose_fold.search(file)
        if m:
            files.append((int(m.group(1)), file))

    # os.listdir order is arbitrary; sorting on (fold, name) makes the
    # ensemble order reproducible.
    files.sort()

    state_dicts = []
    for fold, file in files:
        path = os.path.join(folder, file)
        print(f"Loading fold {fold}: {file}")
        state_dicts.append(torch.load(path, map_location="cpu"))

    return state_dicts
|
|
# In test-only mode the checkpoints come from the model hub instead of a
# training run performed earlier in this session.
if test_only:
    snapshot_download(
        repo_id=repo_name,
        local_dir="",
        repo_type="model",
        allow_patterns=[f"{state_dict_save_name}/**"],
    )
    get_ipython().system('rm -rf .cache .gitattributes')

    best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
    last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")

# Test-time objects: dataset, loader, metric, error analyzer and a fresh model
# that receives the checkpoint weights inside `test`.
os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
generator = torch.Generator()
test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
eval_fn = CustomEvalFn(**eval_func_params)
analyzer = SpanErrorAnalyzer()
my_model = IEModel(
    num_trg_labels=len(trigger_label2id),
    num_arg_labels=len(argument_label2id),
    **model_params
)
total_params = sum(param.numel() for param in my_model.parameters())
print(f"Total params: {total_params:,}")
|
|
| |
# Evaluate the best and the last checkpoints on the test split and store the
# scores, the error analysis and the decoded predictions as JSON files.
start_time = time.time()

_shared_test_args = (test_loader, eval_fn, analyzer, device, id2label, tokenizer)
best_score, best_analyze_result, best_pred_test = test(my_model, best_models, *_shared_test_args)
last_score, last_analyze_result, last_pred_test = test(my_model, last_models, *_shared_test_args)

result_test = {"Best model": best_score, "Last model": last_score}
analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
analyze_result_sumary = {
    "Best model": best_analyze_result['summary'],
    "Last model": last_analyze_result['summary'],
}
pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}

for _suffix, _payload in (
    ("test", result_test),
    ("error_analyze_result", analyze_result),
    ("pred_test", pred_test),
):
    with open(f"{checkpoints_dir}/results/{state_dict_save_name}_{_suffix}.json", "w", encoding="utf-8") as f:
        json.dump(_payload, f, ensure_ascii=False, indent=2)

print('Test:', time.time() - start_time, 's --> Done!')
print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))

# Notebook previews of the first ten predictions.
best_pred_test[:10]

last_pred_test[:10]
|
|
| |
def dict_to_df(data):
    """Arrange nested evaluation scores as a table.

    Args:
        data: ``{model_name: {evaluation: {"precision", "recall", "f1"}}}``.
            The evaluations listed under the first model define the rows and
            must be present for every model.

    Returns:
        DataFrame indexed by evaluation name (index name ``"evaluation"``)
        with ``(model_name, metric)`` MultiIndex columns.  Rows are sorted in
        descending order by the ``f1`` of "Best model", then of "Last model",
        for whichever of these columns exist.
    """
    metrics = ("precision", "recall", "f1")
    reference_scores = next(iter(data.values()))
    eval_keys = list(reference_scores.keys())

    rows = [
        {
            (model_name, metric): model_scores[eval_key][metric]
            for model_name, model_scores in data.items()
            for metric in metrics
        }
        for eval_key in eval_keys
    ]

    df = pd.DataFrame(rows)
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    df.index = pd.Index(eval_keys, name="evaluation")

    preferred = (("Best model", "f1"), ("Last model", "f1"))
    sort_keys = [column for column in preferred if column in df.columns]
    if sort_keys:
        df = df.sort_values(by=sort_keys, ascending=False)

    return df
|
|
# Full score table, written to Excel and shown in the notebook.
result_test_df = dict_to_df(result_test)
_results_prefix = f"{checkpoints_dir}/results/{state_dict_save_name}"
result_test_df.to_excel(f"{_results_prefix}_test_df.xlsx")
result_test_df

# For every evaluation keep the row with the highest "Best model" f1.
key = ("Best model", "f1")
_ranked = result_test_df.sort_values(by=key, ascending=False)
result_test_df_best = _ranked.groupby(level="evaluation").head(1)
result_test_df_best.to_excel(f"{_results_prefix}_test_df_best.xlsx")
result_test_df_best
|
|
| |
def get_avg_best_score(logs):
    """Average the ``'best_score'`` of the last entry of every training log.

    Args:
        logs: Iterable of non-empty dicts whose values are per-epoch dicts
            containing ``'best_score'``; "last" means last in insertion order.

    Returns:
        The mean as a Python ``float``.
    """
    final_scores = []
    for log in logs:
        last_entry = tuple(log.values())[-1]
        final_scores.append(last_entry['best_score'])
    return float(np.mean(final_scores))
| |
def get_avg_log(logs, epochs):
    """Average ``train_loss`` and ``val_score`` per epoch over several runs.

    Args:
        logs: List of per-run dicts keyed by epoch, either as ``int`` or as
            ``str`` (the form produced by a JSON round trip).
        epochs: Highest epoch number to look at (epochs start at 1).

    Returns:
        ``{epoch: {'train_loss': mean, 'val_score': mean}}`` for the epochs
        present in at least one run.  A summed ``val_score`` of exactly 0 is
        reported as ``inf``.
    """
    avg_log = {}

    for epoch in range(1, epochs + 1):
        records = []
        for run in logs:
            record = run.get(epoch, run.get(str(epoch)))
            if record is not None:
                records.append(record)

        if not records:
            continue

        score_total = 0.0
        loss_total = 0.0
        for record in records:
            score_total += record.get('val_score', 0.0)
            loss_total += record.get('train_loss', 0.0)

        n_runs = len(records)
        if score_total != 0:
            mean_score = score_total / n_runs
        else:
            mean_score = float('inf')
        avg_log[epoch] = {'train_loss': loss_total / n_runs, 'val_score': mean_score}

    return avg_log
|
|
def parse_label_key(label: str):
    """Build a numeric sort key from an experiment label.

    The key is made of the number before the first underscore and the number
    after the last underscore at the end of the label, e.g.
    ``"1.5_abc_20" -> (1.5, 20.0)``.

    Args:
        label: Label text.

    Returns:
        ``(first, last)`` as floats, or ``(0, 0)`` when the label does not
        follow that pattern (or is not a string).
    """
    try:
        first = float(label.split('_', 1)[0])
        last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
    # Only failures caused by the label's shape fall back to the neutral key;
    # a bare ``except`` would also swallow KeyboardInterrupt / SystemExit.
    except (AttributeError, IndexError, TypeError, ValueError):
        return (0, 0)
    return first, last
|
|
def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
    """Plot training loss and validation score curves side by side.

    Args:
        logs_dict: ``{name: {epoch: {'train_loss': ..., 'val_score': ...}}}``.
        save_path: Optional image path; missing parent directories are created.
        figsize: Figure size forwarded to ``plt.subplots``.

    The legend is attached to the left panel, placed to its right, and sorted
    with ``parse_label_key``.  It is omitted when there is nothing to plot.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    panels = (
        (axes[0], 'train_loss', 'Train Loss', 'Training Loss'),
        (axes[1], 'val_score', 'Validation Score', 'Validation Score'),
    )
    for ax, field, y_label, title in panels:
        for name, log in logs_dict.items():
            epochs = sorted(log.keys())
            ax.plot(epochs, [log[e][field] for e in epochs], label=name)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True)

    handles, labels = axes[0].get_legend_handles_labels()
    # Without any curve ``zip(*[])`` yields nothing to unpack, so the legend
    # is only built when at least one handle exists.
    if handles:
        pairs_sorted = sorted(zip(handles, labels), key=lambda pair: parse_label_key(pair[1]))
        handles_sorted, labels_sorted = zip(*pairs_sorted)
        axes[0].legend(
            handles_sorted,
            labels_sorted,
            loc='center left',
            bbox_to_anchor=(1.01, 0.5),
            borderaxespad=0.
        )

    plt.tight_layout(rect=[0, 0, 1, 1])

    if save_path is not None:
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
|
|
| |
# Fetch the JSON logs of earlier experiments from the model hub.
if not test_only:
    snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*.json"])
    get_ipython().system('rm -rf .cache .gitattributes')

# Average the per-fold logs of every experiment found in `pretrained_dir`
# and add the logs of the run performed in this session.
if not test_only:
    experiments = {}
    for experiment in os.listdir(pretrained_dir):
        if '.virtual_documents' in experiment:
            continue
        experiment_logs = []
        try:
            for seed in SEEDS:
                for fold_idx in range(nfolds):
                    with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
                        experiment_log = json.load(f)
                    experiment_logs.append(experiment_log)
        except (OSError, ValueError) as err:
            # Best effort: a missing or unreadable log (json.JSONDecodeError is
            # a ValueError) ends the scan of this experiment and the logs read
            # so far are kept.  Other exceptions are no longer swallowed.
            print(f"Skipping remaining logs of '{experiment}': {err}")
        experiments[experiment] = get_avg_log(experiment_logs, 1000)
    experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)

if not test_only:
    score = get_avg_best_score(training_logs)
# `score` only exists after a training run, so the preview is guarded too.
(state_dict_save_name, score) if not test_only else None

if not test_only:
    plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
|
|
|
|