# %% [code] get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch]') # %% [code] import warnings warnings.filterwarnings('ignore') import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, TensorDataset, DataLoader import torch.nn.functional as F import albumentations as albu from transformers import AutoTokenizer, AutoModel import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from positional_encodings.torch_encodings import PositionalEncoding1D from sklearn.metrics import f1_score from sklearn.preprocessing import MinMaxScaler, StandardScaler from scipy.spatial.transform import Rotation as R from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold from sklearn.metrics import precision_recall_fscore_support from timm.utils import ModelEmaV3 import timm import os import gc import json from pathlib import Path import pickle from tqdm.auto import tqdm import copy import numpy as np import pandas as pd import polars as pl from PIL import Image import time from tqdm import tqdm from matplotlib import pyplot as plt import seaborn as sns from multiprocessing import Manager as MemoryManager from functools import lru_cache import shutil import glob import cv2 import random import re import joblib import math from huggingface_hub import HfApi, snapshot_download import evaluate from underthesea import word_tokenize as vi_tokenize_tool import spacy en_tokenize_tool = spacy.load("en_core_web_sm") from collections import defaultdict, Counter # %% [code] # Global config SEEDS = [26092004] topk = 1 nfolds = 5 only_fold_idx = 0 test_only = 0 debug_only = 0 # Config thư mục dataset = 'kltn/only_actions' # vhe, bkee, casie, kltn/only_issues, kltn/only_actions root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Thư mục chứa file train, val, test train_dir = f'{root_dir}' # val_dir = f'{root_dir}/val' test_dir = f'{root_dir}' # Config checkpoints # Config training epochs = 18 if not debug_only else 2 batch_size = 32 device = "cuda" if torch.cuda.is_available() else "cpu" # # Thêm biến toàn cục nào đó vào đây repo_name = 'SS3M/kltn-experiments' state_dict_save_name = "1_pointer_base_actions_4" checkpoints_dir = state_dict_save_name pretrained_dir = "/kaggle/working" os.makedirs(f'{checkpoints_dir}', exist_ok=True) backbone_model_name = "bert-base-uncased" if dataset == "casie" else "vinai/phobert-base" word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset == "casie" else vi_tokenize_tool(text) max_len_dict = { 'kltn/only_issues': 52, 'kltn/only_actions': 69, 'vhe': 51, 'bkee': 62, 'casie': 40, } zero_events_rate_dict = { 'kltn/only_issues': 0, 'kltn/only_actions': 0.2, 'vhe': 1000, # mean keep all zero-events samples 'bkee': 1000, 'casie': 1000, } max_len = max_len_dict[dataset] max_n_parts = 1 max_span_len = 14 zero_events_rate = zero_events_rate_dict[dataset] # Trainer trainer_params = { "training_time": "00:11:30:00", "eval_mode": "max", "topk": topk, "save_name": state_dict_save_name, "save_best": True, "save_last": True, "device": device, "logging": True, "logging_file": True, "checkpoints_dir": checkpoints_dir, "early_stopping": 30, "eval_from_ratio": 0.4, "eval_every": 1, "schedule_in_step": False, "use_ema": True, "ema_from_ratio": 0.3, "ema_decay": 0.9995, "max_grad_norm": 200.0, "return_best": True, "return_last": True, } # Memory train_memory_params = { 'max_len': max_len, 'max_n_parts': max_n_parts, } val_memory_params = { 'max_len': max_len, 'max_n_parts': max_n_parts, } # Data Loader def seed_worker(worker_id): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) train_loader_params = { 'batch_size': batch_size, 'shuffle': True, 'pin_memory':True, 'num_workers': 2, 'drop_last': False, 'worker_init_fn': seed_worker, 'persistent_workers': False, } val_loader_params = { 'batch_size': batch_size, 'shuffle': False, 'pin_memory':True, 'num_workers': 1, 'drop_last': False, 'worker_init_fn': seed_worker, 'persistent_workers': False, } # Model model_params = { 'backbone_model_name': backbone_model_name, } # Loss Func loss_func_params = { 'lambda_trg_ce': 1.0, 'lambda_arg_ce': 1.0, } eval_func_params = {} # Optim optim_params = { 'name': 'AdamW', 'lr': 1e-4, 'weight_decay': 1e-4, } scheduler_params = { 'name': 'CosineAnnealingLR', 'T_max': 20, # Số epoch để hoàn thành một chu kỳ giảm LR 'eta_min': 1e-6 # Learning rate nhỏ nhất trong chu kỳ } # %% [code] def set_seed(seed=42): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if using multi-GPU torch.use_deterministic_algorithms(False) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False os.environ['PYTHONHASHSEED'] = str(seed) # %% [code] class CustomLoss(nn.Module): def __init__( self, lambda_trg_ce=1.0, lambda_arg_ce=1.0, ): super().__init__() self.lambda_trg_ce = lambda_trg_ce self.lambda_arg_ce = lambda_arg_ce self.ce = nn.CrossEntropyLoss(ignore_index=-100) def forward( self, trg_start_logits, trg_start_labels, trg_end_logits, trg_end_labels, arg_start_logits, pred_arg_start_labels, arg_end_logits, pred_arg_end_labels, ): device = trg_start_logits.device # ===== TRG START CE ===== B, N, C = trg_start_logits.shape trg_start_logits_flat = trg_start_logits.view(B * N, C) trg_start_labels_flat = trg_start_labels.view(-1) trg_start_loss = self.ce( trg_start_logits_flat, trg_start_labels_flat ) # ===== TRG END CE ===== B, N, C = trg_end_logits.shape trg_end_logits_flat = trg_end_logits.view(B * N, C) trg_end_labels_flat = trg_end_labels.view(-1) trg_end_loss = self.ce( trg_end_logits_flat, trg_end_labels_flat ) # ===== ARG CE ===== B, K, M, C = arg_start_logits.shape arg_start_logits_flat = arg_start_logits.view(B * K * M, C) arg_start_labels_flat = pred_arg_start_labels.view(-1) arg_mask = (arg_start_labels_flat != -100) if arg_mask.sum() == 0: arg_start_loss = torch.tensor(0.0, device=device) else: arg_start_loss = self.ce(arg_start_logits_flat, arg_start_labels_flat) # (B*K*M,) B, K, M, C = arg_end_logits.shape arg_end_logits_flat = arg_end_logits.view(B * K * M, C) arg_end_labels_flat = pred_arg_end_labels.view(-1) arg_mask = (arg_end_labels_flat != -100) if arg_mask.sum() == 0: arg_end_loss = torch.tensor(0.0, device=device) else: arg_end_loss = self.ce(arg_end_logits_flat, arg_end_labels_flat) # (B*K*M,) # ===== TOTAL ===== total_loss = ( self.lambda_trg_ce * (trg_start_loss + trg_end_loss) + self.lambda_arg_ce * (arg_start_loss + arg_end_loss) ) return { "total": total_loss, "trg_start_loss": trg_start_loss, "trg_end_loss": trg_end_loss, "arg_start_loss": arg_start_loss, "arg_end_loss": arg_end_loss, } # %% [code] ## Viết eval_fn vào đây # Bỏ hết eval_fn và trọng số vào đây class CustomEvalFn(nn.Module): def __init__(self): super().__init__() def compute_f1(self, tp, fp, fn): precision = tp / (tp + fp + 1e-8) recall = tp / (tp + fn + 1e-8) f1 = 2 * precision * recall / (precision + recall + 1e-8) return precision, recall, f1 def forward(self, pred, gold): pred_set = set(pred) gold_set = set(gold) tp = len(pred_set & gold_set) fp = len(pred_set - gold_set) fn = len(gold_set - pred_set) precision, recall, f1 = self.compute_f1(tp, fp, fn) return { f"precision": precision, f"recall": recall, f"f1": f1, } class SpanErrorAnalyzer: def __init__(self, pad_token_id=0): self.pad_token_id = pad_token_id # ===== helper ===== def _to_set(self, data): """ data: list of (b, tuple(ids)) -> dict[b] = set(tuple(ids)) """ res = defaultdict(set) for b, ids in data: ids = tuple([i for i in ids if i != self.pad_token_id]) if len(ids) > 0: res[b].add(ids) return res def _iou(self, a, b): """ a, b: tuple(ids) """ set_a, set_b = set(a), set(b) inter = len(set_a & set_b) union = len(set_a | set_b) if union == 0: return 0.0 return inter / union def _boundary_error(self, pred, gold): """ đo lệch boundary dựa trên overlap prefix/suffix """ # left match left = 0 for i in range(min(len(pred), len(gold))): if pred[i] == gold[i]: left += 1 else: break # right match right = 0 for i in range(1, min(len(pred), len(gold)) + 1): if pred[-i] == gold[-i]: right += 1 else: break return { "left_match": left, "right_match": right, "pred_len": len(pred), "gold_len": len(gold), } # ===== main ===== def analyze(self, preds, golds): pred_map = self._to_set(preds) gold_map = self._to_set(golds) all_batches = set(pred_map.keys()) | set(gold_map.keys()) stats = Counter() detailed_errors = [] for b in all_batches: pset = pred_map.get(b, set()) gset = gold_map.get(b, set()) matched_gold = set() # ===== check predictions ===== for p in pset: if p in gset: stats["exact_match"] += 1 matched_gold.add(p) else: # tìm gold gần nhất best_iou = 0 best_g = None for g in gset: iou = self._iou(p, g) if iou > best_iou: best_iou = iou best_g = g if best_iou > 0: stats["partial_match"] += 1 boundary = self._boundary_error(p, best_g) detailed_errors.append({ "type": "boundary_error", "batch": b, "pred": p, "gold": best_g, "iou": best_iou, **boundary }) else: if b not in gold_map: stats["no_event_sample"] += 1 err_type = "no_event_sample" else: stats["completely_wrong"] += 1 err_type = "completely_wrong" detailed_errors.append({ "type": err_type, "batch": b, "pred": p }) # ===== check missing ===== for g in gset: if g not in matched_gold: # check if any pred overlaps overlap = any(self._iou(p, g) > 0 for p in pset) if overlap: stats["miss_with_overlap"] += 1 else: stats["miss"] += 1 detailed_errors.append({ "type": "miss", "batch": b, "gold": g }) return { "summary": { "exact_match": (stats["exact_match"], stats["exact_match"] / len(preds)), "partial_match": (stats["partial_match"], stats["partial_match"] / len(preds)), "no_event_sample": (stats["no_event_sample"], stats["no_event_sample"] / len(preds)), "completely_wrong": (stats["completely_wrong"], stats["completely_wrong"] / len(preds)), "miss": (stats["miss"], stats["miss"] / len(golds)), "miss_with_overlap": (stats["miss_with_overlap"], stats["miss_with_overlap"] / len(golds)), }, "details": detailed_errors } # %% [code] ## Viết cấu trúc model vào đây def fix_bio_ids_batch(label_ids): """ label_ids: (B, L) return: (B, L) fixed """ B, L = label_ids.shape fixed = label_ids.clone() for b in range(B): for i in range(L): tag = fixed[b, i].item() if tag == 0: continue # I- (even) if tag % 2 == 0: if i == 0 or fixed[b, i-1].item() == 0: fixed[b, i] = tag - 1 else: prev_tag = fixed[b, i-1].item() if prev_tag == 0: fixed[b, i] = tag - 1 else: prev_type = (prev_tag - 1) // 2 curr_type = (tag - 1) // 2 if prev_type != curr_type: fixed[b, i] = tag - 1 return fixed def extract_trigger_spans_batch_tensor(label_ids): """ label_ids: (B, L) return: spans_tensor: (B, N, 2) # (s, e), pad = (0,0) """ B, L = label_ids.shape device = label_ids.device all_spans = [] max_n = 0 # ===== extract spans (list trước) ===== for b in range(B): spans = [] i = 0 while i < L: tag = label_ids[b, i].item() if tag == 0: i += 1 continue # B- (odd) if tag % 2 == 1: type_id = (tag - 1) // 2 s = i e = i i += 1 while i < L: next_tag = label_ids[b, i].item() if next_tag == 0: break next_type = (next_tag - 1) // 2 if next_tag % 2 == 0 and next_type == type_id: e = i i += 1 else: break spans.append((s, e)) else: i += 1 all_spans.append(spans) max_n = max(max_n, len(spans)) # ===== build tensor ===== if max_n == 0: # không có span nào → return tensor rỗng đúng shape return torch.zeros((B, 0, 2), dtype=torch.long, device=device) spans_tensor = torch.zeros((B, max_n, 2), dtype=torch.long, device=device) for b in range(B): for i, (s, e) in enumerate(all_spans[b]): spans_tensor[b, i, 0] = s spans_tensor[b, i, 1] = e return spans_tensor def get_span_repr(hidden, spans): B, L, H = hidden.size() K = spans.size(1) device = hidden.device start = spans[:, :, 0] # (B, K) end = spans[:, :, 1] # (B, K) h_s = torch.gather(hidden, 1, start.unsqueeze(-1).expand(-1, -1, H)) h_e = torch.gather(hidden, 1, end.unsqueeze(-1).expand(-1, -1, H)) h_diff = h_s - h_e h_prod = h_s * h_e # ===== 6. concat ===== span_repr = torch.cat( [h_s, h_e, h_diff, h_prod], dim=-1 ) return span_repr class MLP(nn.Module): def __init__(self, in_size, hid_size, out_size): super().__init__() self.model = nn.Sequential( nn.Linear(in_size, hid_size), nn.ReLU(), nn.Linear(hid_size, out_size) ) def forward(self, x): return self.model(x) class IEModel(nn.Module): def __init__(self, backbone_model_name, num_trg_labels, num_arg_labels): super().__init__() self.encoder = AutoModel.from_pretrained(backbone_model_name) hidden_size = self.encoder.config.hidden_size self.trg_start_classifier = MLP(hidden_size, hidden_size, num_trg_labels) self.trg_end_classifier = MLP(hidden_size, hidden_size, num_trg_labels) self.trg_repr_proj = MLP(hidden_size*4, hidden_size, hidden_size) self.arg_start_classifier = MLP(hidden_size*2, hidden_size, num_arg_labels) self.arg_end_classifier = MLP(hidden_size*2, hidden_size, num_arg_labels) def encode(self, input_ids, attention_mask): B, n_parts, L = input_ids.shape input_ids = input_ids.view(-1, L) attention_mask = attention_mask.view(-1, L) outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) hidden_states = outputs.last_hidden_state # B * n_parts, L, H hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1) # B, L, H return hidden_states def get_trg_logits(self, hidden_states): trg_start_logits = self.trg_start_classifier(hidden_states) # B, N, trg_classes trg_end_logits = self.trg_end_classifier(hidden_states) # B, N, trg_classes return trg_start_logits, trg_end_logits def get_arg_logits(self, hidden_states, trg_repr): B, L, H = hidden_states.shape _, N, _ = trg_repr.shape hidden_expand = hidden_states.unsqueeze(1).expand(-1, N, -1, -1) trg_expand = trg_repr.unsqueeze(2).expand(-1, -1, L, -1) hidden_trg_repr = torch.cat([hidden_expand, trg_expand], dim=-1) # (B, N, L, 2H) arg_start_logits = self.arg_start_classifier(hidden_trg_repr) # (B, N, L, C) arg_end_logits = self.arg_end_classifier(hidden_trg_repr) # (B, N, L, C) return arg_start_logits, arg_end_logits def forward(self, input_ids, attention_mask, trg_spans=None): hidden_states = self.encode(input_ids, attention_mask) trg_start_logits, trg_end_logits = self.get_trg_logits(hidden_states) if trg_spans is None: trg_labels = torch.argmax(trg_logits, dim=-1) trg_labels = fix_bio_ids_batch(trg_labels) trg_spans = extract_trigger_spans_batch_tensor(trg_labels) trg_repr = get_span_repr(hidden_states, trg_spans) # B, N, 4H trg_repr = self.trg_repr_proj(trg_repr) # B, N, H arg_start_logits, arg_end_logits = self.get_arg_logits(hidden_states, trg_repr) return trg_start_logits, trg_end_logits, arg_start_logits, arg_end_logits, trg_spans def test(): model = nn.DataParallel(IEModel(backbone_model_name, 7, 5)).to(device) model.eval() total_params = sum(p.numel() for p in model.parameters()) print(f"Total params: {total_params:,}") vocab_size = model.module.encoder.config.vocab_size max_len = model.module.encoder.config.max_position_embeddings bz = 32 i = torch.randint(0, vocab_size, (bz, 5, 10)).to(device) a = torch.ones(bz, 5, 10).to(device) g = torch.ones(bz, 3, 2, dtype=torch.long).to(device) with torch.no_grad(): r = model(i, a, g) if type(r) == tuple: print([r[i].shape for i in range(len(r))]) else: print(r.shape) test() # %% [code] def configure_optimizers(network, optim_params, scheduler_params): try: optim_params = copy.copy(optim_params) scheduler_params = copy.copy(scheduler_params) optim_name = optim_params.pop('name') scheduler_name = scheduler_params.pop('name') optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None) scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None) if optimizer_cls is None: raise ValueError(f"Optimizer '{optim_name}' is not available!") optimizer = optimizer_cls(network.parameters(), **optim_params) scheduler = None if scheduler_params and scheduler_cls: # Chỉ tạo scheduler nếu có tham số scheduler = scheduler_cls(optimizer, **scheduler_params) return optimizer, scheduler except KeyError as e: raise ValueError(f"Missing {e} in config!!") def freeze(self, model): model.eval() for param in model.parameters(): param.requires_grad = False def unfreeze(self, model): model.train() for param in model.parameters(): param.requires_grad = True def reduce_batch_size(loader, ratio=0.5): new_bs = max(1, int(loader.batch_size * ratio)) shuffle = isinstance(loader.sampler, RandomSampler) new_loader = DataLoader( dataset=loader.dataset, batch_size=new_bs, shuffle=shuffle, sampler=None if shuffle else loader.sampler, num_workers=loader.num_workers, collate_fn=loader.collate_fn, pin_memory=loader.pin_memory, drop_last=loader.drop_last, timeout=loader.timeout, worker_init_fn=loader.worker_init_fn, multiprocessing_context=loader.multiprocessing_context, generator=loader.generator, prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None, persistent_workers=loader.persistent_workers, pin_memory_device=loader.pin_memory_device ) return new_loader def list_to_tuple(x): if isinstance(x, (list, tuple)): return tuple(list_to_tuple(i) for i in x) return x def fmt(x): if isinstance(x, float): return round(x, 5) if isinstance(x, dict): return {k: fmt(v) for k, v in x.items()} if isinstance(x, list): return [fmt(v) for v in x] return x class ModelEmaV3Proxy(ModelEmaV3): def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: return getattr(self.module, name) class DataParallelProxy(nn.DataParallel): def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: attr = getattr(self.module, name) if callable(attr): def wrapper(*args, **kwargs): return self._parallel_apply_method(name, *args, **kwargs) return wrapper return attr def _parallel_apply_method(self, method_name, *inputs, **kwargs): if not self.device_ids: return getattr(self.module, method_name)(*inputs, **kwargs) inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids) replicas = self.replicate(self.module, self.device_ids) outputs = self.parallel_apply( [getattr(replica, method_name) for replica in replicas], inputs_scattered, kwargs_scattered ) return self.gather(outputs, self.output_device) def map_arg_labels(all_arg_labels, trg_spans, pred_spans): """ all_arg_labels: (B, N, L) trg_spans: (B, N, 2) pred_spans: (B, M, 2) return: pred_arg_labels: (B, M, L) """ B, N, L = all_arg_labels.shape _, M, _ = pred_spans.shape device = all_arg_labels.device # ===== match (B, M, N) ===== match = ( (pred_spans.unsqueeze(2) == trg_spans.unsqueeze(1)) .all(dim=-1) ) # ===== index match ===== match_idx = match.float().argmax(dim=2) # (B, M) has_match = match.any(dim=2) # (B, M) # ===== gather ===== gather_idx = match_idx.unsqueeze(-1).expand(-1, -1, L) # (B, M, L) gathered = torch.gather( all_arg_labels, dim=1, index=gather_idx ) # (B, M, L) # ===== build output ===== # base = 0 nhưng giữ -100 base = torch.zeros((B, M, L), dtype=torch.long, device=device) # mask vị trí -100 từ source (lấy từ n=0 cũng được vì mask thường giống nhau) ignore_mask = (all_arg_labels[:, 0] == -100).unsqueeze(1).expand(-1, M, -1) base[ignore_mask] = -100 # ===== fill match ===== pred_arg_labels = torch.where( has_match.unsqueeze(-1), # (B, M, 1) gathered, base ) return pred_arg_labels.long() def decode_spans(start_labels, end_labels): """ start_labels/end_labels: (L,) return: [(s, e, label_id)] """ L = len(start_labels) used_start = set() used_end = set() spans = [] for s in range(L): s_label = start_labels[s] if s_label == 0: continue if s in used_start: continue nearest_e = None for e in range(s, L): if e in used_end: continue e_label = end_labels[e] if e_label == s_label: nearest_e = e break if nearest_e is None: continue used_start.add(s) used_end.add(nearest_e) spans.append((s, nearest_e, s_label)) return spans def decode_spans_batch(start_labels, end_labels): """ Args: start_labels: (B, L) end_labels: (B, L) Returns: spans_tensor: (B, N, 2) N = số span lớn nhất trong batch padding = (0, 0) """ B, L = start_labels.shape all_spans = [] max_n = 0 for bidx in range(B): used_start = set() used_end = set() spans = [] for s in range(L): s_label = start_labels[bidx, s].item() if s_label == 0: continue if s in used_start: continue nearest_e = None for e in range(s, L): if e in used_end: continue e_label = end_labels[bidx, e].item() if e_label == s_label: nearest_e = e break if nearest_e is None: continue used_start.add(s) used_end.add(nearest_e) spans.append((s, nearest_e)) all_spans.append(spans) max_n = max(max_n, len(spans)) # ===== padding ===== spans_tensor = torch.zeros( (B, max_n, 2), dtype=torch.long, device=start_labels.device ) for bidx, spans in enumerate(all_spans): for n, (s, e) in enumerate(spans): spans_tensor[bidx, n, 0] = s spans_tensor[bidx, n, 1] = e return spans_tensor def extract_arguments( input_ids, trg_start_logits, trg_end_logits, arg_start_logits, arg_end_logits, pred_trg_spans, id2label ): """ input_ids: (B, L) trg_start_logits: (B, L, C_trg) trg_end_logits: (B, L, C_trg) arg_start_logits: (B, N, L, C_arg) arg_end_logits: (B, N, L, C_arg) pred_trg_spans: (B, N, 2) id2label = { 'Trg': {id: label}, 'Arg': {id: label} } """ B, L = input_ids.shape # ===== decode trigger ===== trg_start_ids = torch.argmax(trg_start_logits, dim=-1) # (B, L) trg_end_ids = torch.argmax(trg_end_logits, dim=-1) # (B, L) # ===== extract trigger spans ===== trg_spans = [] for bidx in range(B): spans = decode_spans( trg_start_ids[bidx].tolist(), trg_end_ids[bidx].tolist() ) trg_spans.append(spans) results = [] for bidx in range(B): # map span -> label span2label = { (s, e): id2label['Trg'][t_id] for (s, e, t_id) in trg_spans[bidx] } for n in range(pred_trg_spans.shape[1]): s_trg = pred_trg_spans[bidx, n, 0].item() e_trg = pred_trg_spans[bidx, n, 1].item() # skip padding if s_trg == 0 and e_trg == 0: continue if (s_trg, e_trg) not in span2label: continue trg_label = span2label[(s_trg, e_trg)] trg_tokens = input_ids[ bidx, s_trg:e_trg + 1 ].tolist() # ===== argument ===== arg_start_ids = torch.argmax( arg_start_logits[bidx, n], dim=-1 ).tolist() arg_end_ids = torch.argmax( arg_end_logits[bidx, n], dim=-1 ).tolist() arg_spans = decode_spans( arg_start_ids, arg_end_ids ) for s_arg, e_arg, arg_label_id in arg_spans: arg_label = id2label['Arg'][arg_label_id] arg_tokens = input_ids[ bidx, s_arg:e_arg + 1 ].tolist() results.append(( bidx, (tuple(trg_tokens), trg_label), (tuple(arg_tokens), arg_label) )) return results class Trainer: def __init__( self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0, logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu', schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True ): self.ema_net = None self.training_time = self._time_str_to_seconds(training_time) self.mode = eval_mode self.topk = topk self.device = device self.logging = logging if logging < epochs else 1 self.logging_file = logging_file self.checkpoints_dir = checkpoints_dir self.early_stopping = early_stopping self.eval_from_ratio = eval_from_ratio self.eval_every = eval_every self.save_name = save_name self.save_best = save_best self.save_last = save_last self.return_best = return_best self.return_last = return_last self.max_grad_norm = max_grad_norm self.schedule_in_step = schedule_in_step self.use_ema = use_ema self.ema_from_ratio = ema_from_ratio self.ema_decay = ema_decay self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]] self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0) def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None): if eval_fn is None: if self.mode == "max": eval_fn = lambda *x: -loss_fn(*x) else: eval_fn = lambda *x: loss_fn(*x) if torch.cuda.device_count() > 1: network = DataParallelProxy(network) network = network.to(self.device) if not start_training_time: start_training_time = time.time() start_ema = int(epochs * self.ema_from_ratio) start_eval = int(epochs * self.eval_from_ratio) if val_loader is None: print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!') else: model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc' start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else '' print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str) training_log = {} for epoch in range(start_epoch, epochs+start_epoch): if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema: self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device) try: teaching_rate = math.cos(math.pi / 2 * epoch / epochs) train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate) logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch} logging_dict.update(train_loss_epoch_dict) if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0: eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label) update = self._update_best_network(eval_net, val_score, epoch) logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update}) logging_dict.update(val_score_dict) if not self.schedule_in_step and scheduler: scheduler.step() except RuntimeError as e: if "out of memory" in str(e).lower(): print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...") torch.cuda.empty_cache() gc.collect() if torch.cuda.is_available(): torch.cuda.synchronize() print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared") train_loader = reduce_batch_size(train_loader, ratio=0.5) if val_loader is not None: val_loader = reduce_batch_size(val_loader, ratio=0.5) logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')} else: raise training_log[epoch] = logging_dict if self.is_early_stopping(epoch): print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...') break if self.logging: if epoch % self.logging == 0: print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict)) else: print(f'{epoch}...', end=' ') if self._at_time_limit(start_training_time): print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}') break if self.logging_file: os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True) with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f: f.write(json.dumps(training_log)) if self.use_ema and self.ema_net is not None: self._save_state_dict(self.ema_net.module) else: self._save_state_dict(network) print(f'[Trainer CallBack] 📢 Kết thúc training.\n') best_model, last_model = None, None eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network if self.return_best : best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict() best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()} if self.return_last: last_model = eval_net.state_dict() last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()} del network torch.cuda.empty_cache() gc.collect() return training_log, best_model, last_model def _time_str_to_seconds(self, time_str): days, hours, minutes, seconds = map(int, time_str.split(":")) return days * 86400 + hours * 3600 + minutes * 60 + seconds def _update_best_network(self, network, val_score, epoch): topk = max(1, self.topk) self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}]) self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk] if val_score in [x[0] for x in self.best_stage]: return True return False def is_early_stopping(self, epoch): if self.best_stage[0][1] is None: return False if not self.early_stopping: return False return epoch - self.best_stage[0][1] >= self.early_stopping def _at_time_limit(self, start_training_time): return time.time() - start_training_time >= self.training_time def _save_state_dict(self, network): if self.topk <= 0: return if self.save_best: for r in range(self.topk): os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True) for rank, (score, epoch, state_dict) in enumerate(self.best_stage): if state_dict is None: continue state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()} torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth') if self.save_last: os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True) state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()} torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth') def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate): network.train() total_loss = 0 total_loss_dict = {} for batch_idx, batch in enumerate(train_loader): optimizer.zero_grad() with torch.autocast(device_type=self.device, dtype=torch.float16): loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate) for k, v in loss_dict.items(): t = total_loss_dict.get(k, 0) total_loss_dict[k] = t + v self.grad_scaler.scale(loss).backward() self.grad_scaler.unscale_(optimizer) grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm) # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu... self.grad_scaler.step(optimizer) self.grad_scaler.update() if self.schedule_in_step and scheduler: scheduler.step() if self.use_ema and self.ema_net is not None: self.ema_net.update(network) total_loss += loss return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()} def _eval_epoch(self, network, val_loader, eval_fn, id2label): network.eval() total_score = 0.0 total_score_dict = {} object_lists = None # sẽ init sau with torch.no_grad(): for batch_idx, batch in enumerate(val_loader): score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label) total_score += score for k, v in score_dict.items(): t = total_score_dict.get(k, 0) total_score_dict[k] = t + v if objects: if object_lists is None: object_lists = [[] for _ in range(len(objects))] for i, obj in enumerate(objects): object_lists[i].append(obj.detach()) if object_lists is not None: object_arrays = [ torch.concat(obj_list, dim=0).cpu().numpy() for obj_list in object_lists ] else: object_arrays = [] return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate): # Bạn cần override _cal_loss để tính loss input_ids = batch['input_ids'].to(self.device) attention_mask = batch['attention_mask'].to(self.device) trg_spans = batch['trg_spans'].to(self.device) # B, M, 2 trg_start_labels = batch['trg_start_labels'].to(self.device) # B, L trg_end_labels = batch['trg_end_labels'].to(self.device) # B, L all_arg_start_labels = batch['all_arg_start_labels'].to(self.device) # B, M, L all_arg_end_labels = batch['all_arg_end_labels'].to(self.device) # B, M, L hidden_states = network.encode(input_ids, attention_mask) trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states) choice = random.random() if choice < teaching_rate: pred_trg_spans = trg_spans else: trg_start_ids = torch.argmax(trg_start_logits, dim=-1) # (B, L) trg_end_ids = torch.argmax(trg_end_logits, dim=-1) # (B, L) pred_trg_spans = decode_spans_batch(trg_start_ids, trg_end_ids) trg_repr = get_span_repr(hidden_states, pred_trg_spans) # B, N, 4H trg_repr = network.trg_repr_proj(trg_repr) # B, N, H arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr) pred_arg_start_labels = map_arg_labels(all_arg_start_labels, trg_spans, pred_trg_spans) pred_arg_end_labels = map_arg_labels(all_arg_end_labels, trg_spans, pred_trg_spans) loss_dict = loss_fn( trg_start_logits, trg_start_labels, trg_end_logits, trg_end_labels, arg_start_logits, pred_arg_start_labels, arg_end_logits, pred_arg_end_labels, ) return loss_dict['total'], loss_dict def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label): # Bạn cần override _cal_val_score để tính val score, list bên cạnh là để trả về y hay pred gì đó (nếu cần) input_ids = batch['input_ids'].to(self.device) attention_mask = batch['attention_mask'].to(self.device) gold_events = batch['gold_events'] B, _, _ = input_ids.shape hidden_states = network.encode(input_ids, attention_mask) trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states) trg_start_ids = torch.argmax(trg_start_logits, dim=-1) # (B, L) trg_end_ids = torch.argmax(trg_end_logits, dim=-1) # (B, L) pred_trg_spans = decode_spans_batch(trg_start_ids, trg_end_ids) trg_repr = get_span_repr(hidden_states, pred_trg_spans) # B, N, 4H trg_repr = network.trg_repr_proj(trg_repr) # B, N, H arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr) pred_ids = extract_arguments(input_ids.reshape(B, -1), trg_start_logits, trg_end_logits, arg_start_logits, arg_end_logits, pred_trg_spans, id2label) pred_ids = list_to_tuple(pred_ids) gold_ids = list_to_tuple(gold_events) score_dict = eval_fn(pred_ids, gold_ids) return score_dict['f1'], score_dict, [] # %% [code] class PhoBERTSpanAligner: def __init__(self, tokenizer, max_len): self.tokenizer = tokenizer self.max_len = max_len # ===== 1. Extract discontinuous spans ===== def extract_spans(self, sample): trigger_spans, arg_spans = [], [] for event in sample["events"]: trigger_type = event["label"] spans = [tuple(event["offset"])] trigger_spans.append({ "spans": spans, "label": trigger_type }) event_arg_spans = [] for arg in event['arguments']: arg_type = arg["role"] spans = [tuple(arg["offset"])] event_arg_spans.append({ "spans": spans, "label": arg_type }) arg_spans.append(event_arg_spans) return trigger_spans, arg_spans # ===== 2. Word offsets ===== def build_word_offsets(self, text, words): offsets = [] pointer = 0 for word in words: start = text.find(word, pointer) end = start + len(word) offsets.append((start, end)) pointer = end return offsets # ===== 3. Char → word ===== def char_span_to_word_span(self, word_offsets, start, end): start_word = None end_word = None for i, (w_start, w_end) in enumerate(word_offsets): if w_start <= start < w_end: start_word = i if w_start < end <= w_end: end_word = i return start_word, end_word # ===== 4. Word → subword ===== def word_to_subword_map(self, words): mapping = [] subword_index = 1 # for word in words: sub_tokens = self.tokenizer.tokenize(word) start = subword_index end = subword_index + len(sub_tokens) - 1 mapping.append((start, end)) subword_index += len(sub_tokens) return mapping # ===== 5. Span → subword ===== def span_to_subword(self, word_offsets, word_subword_map, spans): sub_spans = [] for span_start, span_end in spans: w_start, w_end = self.char_span_to_word_span( word_offsets, span_start, span_end ) if w_start is None or w_end is None: continue sub_start = word_subword_map[w_start][0] sub_end = word_subword_map[w_end][1] sub_spans.append((sub_start, sub_end)) return sub_spans def extract_valid_spans(self, sub_spans): valid_spans = [] for s, e in sub_spans: if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e: continue valid_spans.append((s, e)) return valid_spans def encode(self, sample): text = sample["text"] triggers, arguments = self.extract_spans(sample) # ===== 1. Word tokenize ===== words = word_tokenize(text) sentence = " ".join(words) # ===== 2. Mapping ===== word_offsets = self.build_word_offsets(text, words) word_subword_map = self.word_to_subword_map(words) # ===== 3. Tokenize FULL ===== encoding = self.tokenizer( sentence, max_length=self.max_len, truncation=True, padding="max_length", return_tensors="pt" ) input_ids = encoding["input_ids"][0] attention_mask = encoding["attention_mask"][0] # ===== 5. Convert spans ===== triggers_gold_spans = [] arguments_gold_spans = [] for trg, args in zip(triggers, arguments): label = trg["label"] sub_spans = self.span_to_subword( word_offsets, word_subword_map, trg["spans"] ) valid_spans = self.extract_valid_spans(sub_spans) if len(valid_spans) == 0: continue triggers_gold_spans.append((tuple(valid_spans), label)) trg_args_gold_spans = [] for arg in args: label = arg["label"] sub_spans = self.span_to_subword( word_offsets, word_subword_map, arg["spans"] ) valid_spans = self.extract_valid_spans(sub_spans) if len(valid_spans) == 0: continue trg_args_gold_spans.append((tuple(valid_spans), label)) arguments_gold_spans.append(tuple(trg_args_gold_spans)) return { "input_ids": input_ids, "attention_mask": attention_mask, "triggers_gold_spans": triggers_gold_spans, "arguments_gold_spans": arguments_gold_spans, } def generate_candidate_spans(seq_len, max_span_len): spans = [] for i in range(1, seq_len+1): for j in range(i, min(i+max_span_len, seq_len+1)): spans.append((i, j)) return spans class KLTNDataset(Dataset): def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts): super().__init__() self.tokenizer = tokenizer self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts) self.all_data = all_data self.using_idxes = using_idxes self.label2id = label2id self.max_len = max_len self.max_n_parts = max_n_parts def __len__(self): return len(self.using_idxes) def __getitem__(self, idx): ridx = self.using_idxes[idx] sample = self.all_data[ridx] result = self.aligner.encode(sample) input_ids = result["input_ids"].squeeze(0) attention_mask = result["attention_mask"].squeeze(0) triggers_gold_spans = result["triggers_gold_spans"] arguments_gold_spans = result["arguments_gold_spans"] # Get event label all_trg_spans = torch.tensor([list(trg_spans[0]) for trg_spans, _ in triggers_gold_spans], dtype=torch.long) if triggers_gold_spans else torch.empty(0, 2, dtype=torch.long) gold_events = [] trg_start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100) trg_end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100) all_arg_start_labels, all_arg_end_labels = [], [] for (trg_spans, trg_label), args in zip(triggers_gold_spans, arguments_gold_spans): s, e = trg_spans[0] trg_start_labels[s] = self.label2id['Trg'][f'{trg_label}'] trg_end_labels[e] = self.label2id['Trg'][f'{trg_label}'] event = [(tuple(input_ids[s:e+1].tolist()), trg_label)] arg_start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100) arg_end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100) for arg_spans, arg_label in args: s, e = arg_spans[0] arg_start_labels[s] = self.label2id['Arg'][f'{arg_label}'] arg_end_labels[e] = self.label2id['Arg'][f'{arg_label}'] event.append((tuple(input_ids[s:e+1].tolist()), arg_label)) all_arg_start_labels.append(arg_start_labels) all_arg_end_labels.append(arg_end_labels) gold_events.append(event) input_ids = input_ids.reshape(self.max_n_parts, self.max_len) attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len) n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len) input_ids = input_ids[:n_valid_parts] attention_mask = attention_mask[:n_valid_parts] trg_start_labels = trg_start_labels[:n_valid_parts*self.max_len] trg_end_labels = trg_end_labels[:n_valid_parts*self.max_len] all_arg_start_labels = torch.stack([arg_labels[:n_valid_parts*self.max_len] for arg_labels in all_arg_start_labels], dim=0) if all_arg_start_labels else torch.empty(0, n_valid_parts*self.max_len) all_arg_end_labels = torch.stack([arg_labels[:n_valid_parts*self.max_len] for arg_labels in all_arg_end_labels], dim=0) if all_arg_end_labels else torch.empty(0, n_valid_parts*self.max_len) return { "input_ids": input_ids, "attention_mask": attention_mask, "trg_spans": all_trg_spans, "trg_start_labels": trg_start_labels, "trg_end_labels": trg_end_labels, "all_arg_start_labels": all_arg_start_labels, "all_arg_end_labels": all_arg_end_labels, "gold_events": gold_events, } def _pad_batch(tensor_list, pad_value=0): """ tensor_list: list of tensors mỗi tensor shape: (Nk, n_parts_i, max_len_i) return: padded tensor shape: (B, max_Nk, max_n_parts, max_len) """ # lấy max toàn batch max_Nk = max(t.size(0) for t in tensor_list) max_n_parts = max(t.size(1) for t in tensor_list) max_len = max(t.size(2) for t in tensor_list) padded = [] for t in tensor_list: Nk, n_parts_i, max_len_i = t.shape # pad chiều n_parts và max_len trước if n_parts_i < max_n_parts or max_len_i < max_len: new_t = t.new_full( (Nk, max_n_parts, max_len), pad_value ) new_t[:, :n_parts_i, :max_len_i] = t t = new_t # pad chiều Nk if Nk < max_Nk: pad_tensor = t.new_full( (max_Nk - Nk, max_n_parts, max_len), pad_value ) t = torch.cat([t, pad_tensor], dim=0) padded.append(t) return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len) def collate_fn(batch): gold_events = [] for bidx, b in enumerate(batch): for event in b['gold_events']: trg = event[0] if len(event) > 1: for arg in event[1:]: gold_events.append([bidx, trg, arg]) else: gold_events.append([bidx, trg, (tuple([]), 0)]) input_ids = [b["input_ids"].unsqueeze(-1) for b in batch] attention_mask = [b["attention_mask"].unsqueeze(-1) for b in batch] trg_spans = [b["trg_spans"].unsqueeze(-1) for b in batch] trg_start_labels = [b["trg_start_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch] trg_end_labels = [b["trg_end_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch] all_arg_start_labels = [b["all_arg_start_labels"].unsqueeze(-1) for b in batch] all_arg_end_labels = [b["all_arg_end_labels"].unsqueeze(-1) for b in batch] # pad theo Nk input_ids = _pad_batch(input_ids, pad_value=0).squeeze(-1) attention_mask = _pad_batch(attention_mask, pad_value=0).squeeze(-1) trg_spans = _pad_batch(trg_spans, pad_value=0).squeeze(-1) trg_start_labels = _pad_batch(trg_start_labels, pad_value=-100).squeeze(-1).squeeze(-1) trg_end_labels = _pad_batch(trg_end_labels, pad_value=-100).squeeze(-1).squeeze(-1) all_arg_start_labels = _pad_batch(all_arg_start_labels, pad_value=-100).squeeze(-1) all_arg_end_labels = _pad_batch(all_arg_end_labels, pad_value=-100).squeeze(-1) return { "input_ids": input_ids, "attention_mask": attention_mask, "trg_spans": trg_spans, "trg_start_labels": trg_start_labels, "trg_end_labels": trg_end_labels, "all_arg_start_labels": all_arg_start_labels, "all_arg_end_labels": all_arg_end_labels, "gold_events": gold_events, } # %% [code] def shift_bidx(spans, batch_idx): shifted = [] for bidx, trg, arg in spans: new_bidx = bidx + batch_idx * batch_size shifted.append((new_bidx, trg, arg)) return shifted def refactor_events(events, save_dict): trg_i, trg_c, arg_i, arg_c, soft, strict_dict = [], [], [], [], [], {} for bidx, (trg_ids, trg_lb), (arg_k_ids, arg_k_lb) in events: if (bidx, trg_ids) not in trg_i: trg_i.append((bidx, trg_ids)) if (bidx, (trg_ids, trg_lb)) not in trg_c: trg_c.append((bidx, (trg_ids, trg_lb))) if (bidx, trg_ids, arg_k_ids) not in arg_i: arg_i.append((bidx, trg_ids, arg_k_ids)) if (bidx, trg_ids, (arg_k_ids, arg_k_lb)) not in arg_c: arg_c.append((bidx, trg_ids, (arg_k_ids, arg_k_lb))) if (bidx, (trg_ids, trg_lb), (arg_k_ids, arg_k_lb)) not in soft: soft.append((bidx, (trg_ids, trg_lb), (arg_k_ids, arg_k_lb))) if bidx not in strict_dict: strict_dict[bidx] = {} if (trg_ids, trg_lb) not in strict_dict[bidx]: strict_dict[bidx][(trg_ids, trg_lb)] = [] strict_dict[bidx][(trg_ids, trg_lb)].append((arg_k_ids, arg_k_lb)) strict = [] for bidx, trg_dict in strict_dict.items(): for trg, args in trg_dict.items(): strict.append((bidx, trg, frozenset(args))) save_dict['Trg-I'].extend(trg_i) save_dict['Trg-C'].extend(trg_c) save_dict['Arg-I'].extend(arg_i) save_dict['Arg-C'].extend(arg_c) save_dict['Soft-Event'].extend(soft) save_dict['Strict-Event'].extend(strict) def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer): if torch.cuda.device_count() > 1: network = DataParallelProxy(network) network = network.to(device) network.eval() eval_types = ['Trg-I', 'Trg-C', 'Arg-I', 'Arg-C', 'Soft-Event', 'Strict-Event'] all_pred = {eval_type: [] for eval_type in eval_types} all_gold = {eval_type: [] for eval_type in eval_types} list_input_ids = [] with torch.no_grad(): for batch_idx, batch in enumerate(test_loader): input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) gold_events = batch['gold_events'] B, _, _ = input_ids.shape list_input_ids.extend(input_ids.reshape(B, -1).tolist()) list_trg_start_logits = [] list_trg_end_logits = [] list_hidden_states = [] list_arg_start_logits = [] list_arg_end_logits = [] for sd in state_dicts: if torch.cuda.device_count() > 1: network.module.load_state_dict(sd) else: network.load_state_dict(sd) hidden_states = network.encode(input_ids, attention_mask) trg_start_logits, trg_end_logits = network.get_trg_logits(hidden_states) list_trg_start_logits.append(trg_start_logits) list_trg_end_logits.append(trg_end_logits) list_hidden_states.append(hidden_states) ensemble_trg_start_logits = torch.stack(list_trg_start_logits, dim=0).mean(dim=0) ensemble_trg_end_logits = torch.stack(list_trg_end_logits, dim=0).mean(dim=0) trg_start_ids = torch.argmax(ensemble_trg_start_logits, dim=-1) # (B, L) trg_end_ids = torch.argmax(ensemble_trg_end_logits, dim=-1) # (B, L) pred_trg_spans = decode_spans_batch(trg_start_ids, trg_end_ids) for sd, hidden_states in zip(state_dicts, list_hidden_states): if torch.cuda.device_count() > 1: network.module.load_state_dict(sd) else: network.load_state_dict(sd) trg_repr = get_span_repr(hidden_states, pred_trg_spans) # B, N, 4H trg_repr = network.trg_repr_proj(trg_repr) # B, N, H arg_start_logits, arg_end_logits = network.get_arg_logits(hidden_states, trg_repr) list_arg_start_logits.append(arg_start_logits) list_arg_end_logits.append(arg_end_logits) ensemble_arg_start_logits = torch.stack(list_arg_start_logits, dim=0).mean(dim=0) ensemble_arg_end_logits = torch.stack(list_arg_end_logits, dim=0).mean(dim=0) pred_events = extract_arguments( input_ids.reshape(B, -1), ensemble_trg_start_logits, ensemble_trg_end_logits, ensemble_arg_start_logits, ensemble_arg_end_logits, pred_trg_spans, id2label ) pred_events = shift_bidx(pred_events, batch_idx) refactor_events(pred_events, all_pred) gold_events = shift_bidx(gold_events, batch_idx) refactor_events(gold_events, all_gold) # ===== GLOBAL EVAL ===== final_score = {} for eval_type in eval_types: score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type])) final_score[eval_type] = score analyze_result = analyzer.analyze(list_to_tuple(all_pred['Trg-I']), list_to_tuple(all_gold['Trg-I'])) # ===== PREDICT ===== predictions = [] for input_ids in list_input_ids: predictions.append([tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)]) for event in all_pred['Strict-Event']: bidx = event[0] trg = tokenizer.decode(event[1][0], skip_special_tokens=True, clean_up_tokenization_spaces=True) trg_lb = event[1][1] predictions[bidx].append((trg, trg_lb)) for arg_infor in event[2]: arg = tokenizer.decode(arg_infor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) arg_lb = arg_infor[1] predictions[bidx].append((arg, arg_lb)) return final_score, analyze_result, predictions # %% [code] with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f: data_train = json.load(f) with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f: data_test = json.load(f) print('Train:', len(data_train)) print('Test:', len(data_test)) # %% [code] trigger_types = ['O'] + sorted(list(set([e['label'] for d in data_train + data_test for e in d['events']]))) # NBR : Neighbor relation # bio_trigger_types = [f'{prefix}-{trg}' for trg in trigger_types for prefix in ['B', 'I']] trigger_label2id = {l: i for i, l in enumerate(trigger_types)} trigger_id2label = {i: l for l, i in trigger_label2id.items()} argument_types = ['O'] + sorted(list(set([a['role'] for d in data_train + data_test for e in d['events'] for a in e['arguments']]))) # bio_argument_types = [f'{prefix}-{arg}' for arg in argument_types for prefix in ['B', 'I']] argument_label2id = {l: i for i, l in enumerate(argument_types)} argument_id2label = {i: l for l, i in argument_label2id.items()} label2id = { 'Trg': trigger_label2id, 'Arg': argument_label2id, } id2label = { 'Trg': trigger_id2label, 'Arg': argument_id2label, } # %% [code] zero_events_idxes = [] for idx, d in enumerate(data_train): if len(d['events']) == 0: zero_events_idxes.append(idx) n_zero_events_samples = len(zero_events_idxes) n_has_events_samples = len(data_train) - n_zero_events_samples random.seed(42) k = min(int(n_has_events_samples * zero_events_rate), len(zero_events_idxes)) sampled_zero_events_idxes = random.sample(zero_events_idxes, k) new_data_train = [] for idx, d in enumerate(data_train): if len(d['events']) == 0: if idx in sampled_zero_events_idxes: new_data_train.append(d) else: new_data_train.append(d) data_train = new_data_train print('Train:', len(data_train)) # %% [code] if debug_only: data_train = data_train[:20] data_test = data_test[:20] print('Train:', len(data_train)) print('Test:', len(data_test)) # %% [code] tokenizer = AutoTokenizer.from_pretrained(backbone_model_name) # %% [code] print('Experiment name:', state_dict_save_name) # %% [code] if not test_only: full_idxes = np.array(range(len(data_train))) training_logs, best_models, last_models = [], [], [] start_training_time = time.time() for seed in SEEDS: kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed) for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)): if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx: continue set_seed(seed) train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx] trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params) valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params) generator = torch.Generator() generator.manual_seed(seed) train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params) val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params) my_model = IEModel( num_trg_labels=len(trigger_label2id), num_arg_labels=len(argument_label2id), **model_params ) total_params = sum(p.numel() for p in my_model.parameters()) print(f"Total params: {total_params:,}") # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params) encoder_params = set(map(id, my_model.encoder.parameters())) other_params = [ p for p in my_model.parameters() if id(p) not in encoder_params ] optimizer = optim.AdamW([ {"params": my_model.encoder.parameters(), "lr": 2e-5}, {"params": other_params} ], lr=5e-4) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6) loss_fn = CustomLoss( **loss_func_params ) eval_fn = CustomEvalFn(**eval_func_params) trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}' trainer = Trainer(**trainer_params) print(f'Start Training Fold {fold_idx}...') training_log, best_model, last_model = trainer.fit( my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn, start_epoch=1, start_training_time=start_training_time, id2label=id2label ) training_logs.append(training_log) best_models.append(best_model) last_models.append(last_model) # %% [code] def load_all_state_dicts(folder): files = [] for file in os.listdir(folder): if file.endswith(".pt") or file.endswith(".pth"): m = re.search(r"f(\d+)", file) # tìm f if m: fold = int(m.group(1)) files.append((fold, file)) # sort theo fold files.sort(key=lambda x: x[0]) state_dicts = [] for fold, file in files: path = os.path.join(folder, file) print(f"Loading fold {fold}: {file}") state_dict = torch.load(path, map_location="cpu") state_dicts.append(state_dict) return state_dicts if test_only: snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"]) get_ipython().system('rm -rf .cache .gitattributes') best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s") last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts") # %% [code] os.makedirs(f'{checkpoints_dir}/results', exist_ok=True) testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params) generator = torch.Generator() test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params) eval_fn = CustomEvalFn(**eval_func_params) analyzer = SpanErrorAnalyzer() my_model = IEModel( num_trg_labels=len(trigger_label2id), num_arg_labels=len(argument_label2id), **model_params ) total_params = sum(p.numel() for p in my_model.parameters()) print(f"Total params: {total_params:,}") # %% [code] start_time = time.time() best_score, best_analyze_result, best_pred_test = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer) last_score, last_analyze_result, last_pred_test = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer) result_test = {"Best model": best_score, "Last model": last_score} analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result} analyze_result_sumary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']} pred_test = {"Best model": best_pred_test, "Last model": last_pred_test} with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f: json.dump(result_test, f, ensure_ascii=False, indent=2) with open(f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", "w", encoding="utf-8") as f: json.dump(analyze_result, f, ensure_ascii=False, indent=2) with open(f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", "w", encoding="utf-8") as f: json.dump(pred_test, f, ensure_ascii=False, indent=2) print('Test:', time.time() - start_time, 's --> Done!') print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4)) # %% [code] best_pred_test[:10] # %% [code] last_pred_test[:10] # %% [code] def dict_to_df(data): row_tuples = [] row_values = [] metrics = ["precision", "recall", "f1"] # Lấy model đầu tiên first_model = next(iter(data.values())) # eval_keys eval_keys = list(first_model.keys()) for eval_key in eval_keys: row_tuples.append(eval_key) row = {} for model_name, model_data in data.items(): for metric in metrics: row[(model_name, metric)] = model_data[eval_key][metric] row_values.append(row) # ===== DataFrame ===== df = pd.DataFrame(row_values) # MultiIndex columns df.columns = pd.MultiIndex.from_tuples(df.columns) # Index df.index = pd.Index(row_tuples, name="evaluation") # ===== Sort ===== sort_keys = [] if ("Best model", "f1") in df.columns: sort_keys.append(("Best model", "f1")) if ("Last model", "f1") in df.columns: sort_keys.append(("Last model", "f1")) if sort_keys: df = df.sort_values(by=sort_keys, ascending=False) return df result_test_df = dict_to_df(result_test) result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx") result_test_df # %% [code] key = ("Best model", "f1") result_test_df_best = result_test_df.sort_values(by=key, ascending=False).groupby(level="evaluation").head(1) result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx") result_test_df_best # %% [code] def get_avg_best_score(logs): return float(np.mean([list(log.values())[-1]['best_score'] for log in logs])) def get_avg_log(logs, epochs): avg_log = {} for epoch in range(1, epochs + 1): val_score = 0.0 train_loss = 0.0 n_eval = 0 for idx in range(len(logs)): log = logs[idx].get(epoch, logs[idx].get(str(epoch))) if log is None: continue val_score += log.get('val_score', 0.0) train_loss += log.get('train_loss', 0.0) n_eval += 1 if n_eval == 0: continue avg_log[epoch] = { 'train_loss': train_loss / n_eval, 'val_score': val_score / n_eval if val_score != 0 else float('inf') } return avg_log def parse_label_key(label: str): try: first = float(label.split('_', 1)[0]) # số đầu: trước dấu _ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0]) return first, last except: return (0, 0) def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)): fig, axes = plt.subplots(1, 2, figsize=figsize) # ===== Plot Train Loss ===== for name, log in logs_dict.items(): epochs = sorted(log.keys()) train_loss = [log[e]['train_loss'] for e in epochs] axes[0].plot(epochs, train_loss, label=name) axes[0].set_xlabel('Epoch') axes[0].set_ylabel('Train Loss') axes[0].set_title('Training Loss') axes[0].grid(True) # ===== Plot Validation Score ===== for name, log in logs_dict.items(): epochs = sorted(log.keys()) val_score = [log[e]['val_score'] for e in epochs] axes[1].plot(epochs, val_score, label=name) axes[1].set_xlabel('Epoch') axes[1].set_ylabel('Validation Score') axes[1].set_title('Validation Score') axes[1].grid(True) # ===== Shared Legend ===== handles, labels = axes[0].get_legend_handles_labels() pairs = list(zip(handles, labels)) pairs_sorted = sorted( pairs, key=lambda x: parse_label_key(x[1]) ) handles_sorted, labels_sorted = zip(*pairs_sorted) axes[0].legend( handles_sorted, labels_sorted, loc='center left', bbox_to_anchor=(1.01, 0.5), borderaxespad=0. ) plt.tight_layout(rect=[0, 0, 1, 1]) if save_path is not None: os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.show() # %% [code] if not test_only: snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*.json"]) get_ipython().system('rm -rf .cache .gitattributes') # %% [code] if not test_only: experiments = {} for experiment in os.listdir(pretrained_dir): if '.virtual_documents' in experiment: continue experiment_logs = [] try: for seed in SEEDS: for fold_idx in range(nfolds): with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f: experiment_log = json.load(f) experiment_logs.append(experiment_log) except: pass experiments[experiment] = get_avg_log(experiment_logs, 1000) experiments[state_dict_save_name] = get_avg_log(training_logs, 1000) # %% [code] if not test_only: score = get_avg_best_score(training_logs) state_dict_save_name, score # %% [code] if not test_only: plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))