| ####################################################################################################### | |
| # 原始版本 (Original version) — NOTE: the entire section below is a fully commented-out backup of this script; none of it executes. | |
| # import datetime | |
| # import logging | |
| # import json | |
| # import random | |
| # import time | |
| # import numpy as np | |
| # import os | |
| # import pickle | |
| # import sys | |
| # import torch | |
| # import torch.distributed as dist | |
| # import torch.nn.functional as F | |
| # import yaml | |
| # import transformers | |
| # import math | |
| # from torch.utils.data import DataLoader | |
| # from tqdm import tqdm | |
| # from transformers import HfArgumentParser, AutoConfig, AutoTokenizer | |
| # from datasets import Dataset, concatenate_datasets | |
| # from datasets.distributed import split_dataset_by_node | |
| # from src.arguments import ModelArguments, DataArguments, TrainingArguments | |
| # from src.data.collator.eval_collator import MultimodalEvalDataCollator | |
| # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset | |
| # from src.eval_utils.metrics import RankingMetrics | |
| # from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel | |
| # from src.model.processor import get_backbone_name, load_processor, COLPALI | |
| # from src.utils import batch_to_device, print_rank, print_master | |
| # # 引入分类器 | |
| # from src.classifier_utils_V2 import EarlyExitClassifier | |
| # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') | |
| # logger = logging.getLogger(__name__) | |
| # # =========================== | |
| # # Helper Functions (AOP Config) | |
| # # =========================== | |
| # def _parse_bool(v: str, default=False): | |
| # if v is None: return default | |
| # v = v.strip().lower() | |
| # return v in {"1","true","yes","y","t","on"} | |
| # def _parse_float(v: str, default=None): | |
| # try: return float(v) if v is not None else default | |
| # except: return default | |
| # def _parse_int(v: str, default=None): | |
| # try: return int(v) if v is not None else default | |
| # except: return default | |
| # def get_env_aop_config(): | |
| # enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False) | |
| # apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() | |
| # layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None) | |
| # mode = os.environ.get("AOP_MODE", "delta").strip().lower() | |
| # # Parameters | |
| # delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10) | |
| # khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0) | |
| # keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0) | |
| # min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64) | |
| # use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True) | |
| # # Specific | |
| # prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True) | |
| # prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False) | |
| # keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None) | |
| # keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None) | |
| # selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() | |
| # attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() | |
| # if layer_idx is None and enabled: | |
| # enabled = False | |
| # return { | |
| # "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode, | |
| # "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep, | |
| # "use_bias": use_bias, "prune_vision": prune_vision, "prune_text": prune_text, | |
| # "keep_ratio_vision": keep_ratio_v, "keep_ratio_text": keep_ratio_t, | |
| # "selection": selection, "attn_agg": attn_agg, | |
| # "margin_mid": None # Simplified | |
| # } | |
| # def get_env_ee_config(): | |
| # ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"} | |
| # layer = int(os.environ.get("EE_LAYER", "12")) | |
| # method = os.environ.get("EE_METHOD", "classifier").strip().lower() | |
| # threshold = float(os.environ.get("EE_THRESHOLD", "0.8")) | |
| # classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "") | |
| # return dict(enabled=ee_enabled, layer=layer, method=method, threshold=threshold, classifier_path=classifier_path) | |
| # def pad_dataset_to_divisible(dataset, world_size): | |
| # num_samples = len(dataset) | |
| # if num_samples % world_size == 0: return dataset, num_samples | |
| # num_to_add = world_size - (num_samples % world_size) | |
| # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) | |
| # return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add | |
| # # =========================== | |
| # # Core Inference Function | |
| # # =========================== | |
| # def run_early_exit_queries( | |
| # model: MMEBModel, | |
| # classifier: EarlyExitClassifier, | |
| # processor, | |
| # model_args: ModelArguments, | |
| # data_args: DataArguments, | |
| # training_args: TrainingArguments, | |
| # qry_dataset: Dataset, | |
| # cand_mid_dict: dict, | |
| # cand_last_dict: dict, | |
| # ee_cfg: dict, | |
| # dataset_name: str, | |
| # out_dir: str, | |
| # global_ranking: bool = True, | |
| # ): | |
| # device = training_args.device | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # is_main = (not dist.is_initialized()) or (local_rank == 0) | |
| # # ========================== | |
| # # Profiling 配置 | |
| # # ========================== | |
| # profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in { | |
| # "1", "true", "yes", "on", "y", "t" | |
| # } | |
| # topk_emb = int(os.environ.get("EE_TOPK_EMB", "5")) | |
| # # 运行时间统计(按样本数加权) | |
| # timing_stats = { | |
| # "mid_time_sum": 0.0, # encoder 到 mid 的总时间 * 样本数 | |
| # "mid_num": 0, # 样本数 | |
| # "tail_time_sum": 0.0, # mid->last 续跑部分的总时间 * 样本数 | |
| # "tail_num": 0, # 续跑样本数 | |
| # } | |
| # # embedding + 相似度记录(仅 rank0 保存) | |
| # analysis_records = [] if (profile_enabled and is_main) else None | |
| # # 1. 准备 Candidates | |
| # cand_ids = list(cand_mid_dict.keys()) | |
| # cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32) | |
| # cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32) | |
| # cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16) | |
| # cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16) | |
| # collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") | |
| # loader = DataLoader( | |
| # qry_dataset, | |
| # batch_size=training_args.per_device_eval_batch_size, | |
| # collate_fn=collator, | |
| # num_workers=training_args.dataloader_num_workers | |
| # ) | |
| # pred_dicts = [] | |
| # stats = {"exit": 0, "total": 0} | |
| # threshold = float(ee_cfg["threshold"]) | |
| # method = ee_cfg["method"] | |
| # target_layer_idx = int(ee_cfg["layer"]) | |
| # # 结果顺序 | |
| # results_dict = {} | |
| # global_sample_idx = 0 | |
| # # Local vs Global ranking | |
| # use_local = (not global_ranking) | |
| # if use_local: | |
| # print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)") | |
| # cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)} | |
| # else: | |
| # print_master(f"[INFO] Using GLOBAL ranking (full library)") | |
| # # --- AOP 配置初始化 --- | |
| # aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| # _orig_enabled = None | |
| # side_enable = True | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # _orig_enabled = aop_cfg.get("enabled", False) | |
| # apply_to = aop_cfg.get("apply_to", "qry") | |
| # side_enable = (apply_to == "both") or (apply_to == "qry") | |
| # model.eval() | |
| # if classifier: | |
| # classifier.eval() | |
| # classifier.to(device) | |
| # start_time = time.time() | |
| # for inputs, infos in tqdm( | |
| # loader, | |
| # desc=f"[EE+AOP] {dataset_name} (tau={threshold})", | |
| # disable=local_rank > 0, | |
| # ): | |
| # inputs = batch_to_device(inputs, device) | |
| # B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1 | |
| # batch_start_idx = global_sample_idx | |
| # global_sample_idx += B | |
| # stats["total"] += B | |
| # # --------------------------------------------------- | |
| # # 1. 前半程: Run to Mid Layer (含 AOP 动态控制) | |
| # # --------------------------------------------------- | |
| # orig_cfg = None | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # orig_cfg = dict(aop_cfg) | |
| # aop_layer = aop_cfg.get("layer_idx", None) | |
| # aop_on_mid = bool( | |
| # _orig_enabled | |
| # and side_enable | |
| # and (aop_layer is not None) | |
| # and (aop_layer < target_layer_idx) | |
| # ) | |
| # aop_cfg_mid = dict(aop_cfg) | |
| # aop_cfg_mid["enabled"] = aop_on_mid | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg_mid) | |
| # # 计时:encoder 到 mid 层 | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t0_mid = time.perf_counter() | |
| # with torch.no_grad(), torch.autocast( | |
| # device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| # ): | |
| # out_mid = model.encoder( | |
| # **inputs, | |
| # return_dict=True, | |
| # output_hidden_states=False, | |
| # stop_at_layer=target_layer_idx, | |
| # compute_lm_head=False, | |
| # ) | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t1_mid = time.perf_counter() | |
| # timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B | |
| # timing_stats["mid_num"] += B | |
| # # 恢复 AOP 配置 | |
| # if isinstance(orig_cfg, dict): | |
| # setattr(model.encoder, "aop_prune_config", orig_cfg) | |
| # # Hidden State & Mask | |
| # hs_mid = getattr(out_mid, "last_hidden_state", None) | |
| # if hs_mid is None: | |
| # hs_mid = out_mid.hidden_states[-1] | |
| # am_mid = getattr(out_mid, "attention_mask", None) | |
| # if am_mid is None: | |
| # am_mid = inputs.get("attention_mask", None) | |
| # reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16) | |
| # # 如果要记录最后一层 embedding,需要在 profiling 模式下额外跑一次全 forward | |
| # reps_last_full = None | |
| # if profile_enabled: | |
| # with torch.no_grad(), torch.autocast( | |
| # device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| # ): | |
| # out_full = model.encoder( | |
| # **inputs, | |
| # return_dict=True, | |
| # output_hidden_states=False, | |
| # stop_at_layer=None, | |
| # compute_lm_head=False, | |
| # ) | |
| # hs_last_full = getattr(out_full, "last_hidden_state", None) | |
| # if hs_last_full is None: | |
| # hs_last_full = out_full.hidden_states[-1] | |
| # am_last_full = getattr(out_full, "attention_mask", None) | |
| # if am_last_full is None: | |
| # am_last_full = inputs.get("attention_mask", None) | |
| # reps_last_full = ( | |
| # model._pooling(hs_last_full, am_last_full) | |
| # .detach() | |
| # .to(dtype=torch.bfloat16) | |
| # ) | |
| # # --------------------------------------------------- | |
| # # 2. 特征工程 + gating | |
| # # --------------------------------------------------- | |
| # exit_mask = np.zeros(B, dtype=bool) | |
| # p_need_last_batch = None # 仅 profiling 或 debug 时保存 | |
| # if method == "classifier" and classifier is not None: | |
| # with torch.no_grad(): | |
| # # 【关键修复】特征提取:qry_mid × cand_mid(表征空间对齐) | |
| # cos_mid = reps_mid_t @ cand_mid_t.T # [B, N] | |
| # backbone_ptr = ( | |
| # model.module if hasattr(model, "module") else model | |
| # ) | |
| # temp = getattr(backbone_ptr, "temperature", 0.02) | |
| # scores_mid = cos_mid / temp | |
| # probs_mid = torch.softmax(scores_mid, dim=1) # [B, N] | |
| # diag_cos = cos_mid.max(dim=1)[0] | |
| # sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True) | |
| # s2_cos = ( | |
| # sorted_cos[:, 1] | |
| # if sorted_cos.size(1) > 1 | |
| # else sorted_cos[:, 0] | |
| # ) | |
| # margin_mid = diag_cos - s2_cos | |
| # margin_mean = margin_mid.mean() | |
| # margin_std = margin_mid.std(unbiased=False) + 1e-6 | |
| # z_margin_mid = (margin_mid - margin_mean) / margin_std | |
| # margin_median = margin_mid.median() | |
| # mad = (margin_mid - margin_median).abs().median() + 1e-6 | |
| # mad_margin_mid = (margin_mid - margin_median) / mad | |
| # p1_mid = probs_mid.max(dim=1)[0] | |
| # H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1) | |
| # gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1) | |
| # TOPK = min(16, probs_mid.size(1)) | |
| # topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1) | |
| # topk_mean = topk_vals.mean(dim=1) | |
| # topk_std = topk_vals.std(dim=1, unbiased=False) | |
| # topk_cv = topk_std / (topk_mean + 1e-6) | |
| # centered = topk_vals - topk_mean.unsqueeze(1) | |
| # var = (centered ** 2).mean(dim=1) + 1e-6 | |
| # m4 = (centered ** 4).mean(dim=1) | |
| # topk_kurt = m4 / (var ** 2) | |
| # topk_med = topk_vals.median(dim=1).values | |
| # row_mean_cos = cos_mid.mean(dim=1) | |
| # row_med_cos = cos_mid.median(dim=1).values | |
| # s1_over_mean = diag_cos - row_mean_cos | |
| # s1_over_med = diag_cos - row_med_cos | |
| # sorted_probs, _ = torch.sort( | |
| # probs_mid, dim=1, descending=True | |
| # ) | |
| # p1 = sorted_probs[:, 0] | |
| # p2 = ( | |
| # sorted_probs[:, 1] | |
| # if sorted_probs.size(1) > 1 | |
| # else sorted_probs[:, 0] | |
| # ) | |
| # shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum( | |
| # dim=1 | |
| # ) | |
| # shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1) | |
| # R = min(10, sorted_probs.size(1)) | |
| # x = torch.arange( | |
| # R, device=device, dtype=sorted_probs.dtype | |
| # ) | |
| # x_centered = x - x.mean() | |
| # denom = (x_centered ** 2).sum() | |
| # y = torch.log(sorted_probs[:, :R] + 1e-6) | |
| # slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom | |
| # row_mean_p = probs_mid.mean(dim=1) | |
| # row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6 | |
| # z1 = (p1_mid - row_mean_p) / row_std_p | |
| # center_p = probs_mid - row_mean_p.unsqueeze(1) | |
| # m3 = (center_p ** 3).mean(dim=1) | |
| # skew = m3 / (row_std_p ** 3 + 1e-6) | |
| # s1_over_sk = p1_mid - skew | |
| # TAIL_K = min(10, sorted_probs.size(1)) | |
| # tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1) | |
| # HEAD_K = min(5, sorted_probs.size(1)) | |
| # head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1) | |
| # mask_ratio = torch.zeros_like(diag_cos) | |
| # mask_len = torch.zeros_like(diag_cos) | |
| # mask_runs = torch.zeros_like(diag_cos) | |
| # scalar_inputs = torch.stack( | |
| # [ | |
| # diag_cos, | |
| # s2_cos, | |
| # margin_mid, | |
| # z_margin_mid, | |
| # mad_margin_mid, | |
| # p1_mid, | |
| # H_mid, | |
| # gini_mid, | |
| # topk_mean, | |
| # topk_std, | |
| # topk_cv, | |
| # topk_kurt, | |
| # topk_med, | |
| # s1_over_mean, | |
| # s1_over_med, | |
| # p1, | |
| # p2, | |
| # shape_H, | |
| # shape_gini, | |
| # slope, | |
| # z1, | |
| # s1_over_sk, | |
| # tail_mean, | |
| # head5_mean, | |
| # mask_ratio, | |
| # mask_len, | |
| # mask_runs, | |
| # ], | |
| # dim=1, | |
| # ) | |
| # modality_idx = torch.zeros( | |
| # B, dtype=torch.long, device=device | |
| # ) | |
| # if "pixel_values" in inputs and inputs["pixel_values"] is not None: | |
| # pv = inputs["pixel_values"] | |
| # if isinstance(pv, list): | |
| # for i, item in enumerate(pv): | |
| # if item is not None: | |
| # modality_idx[i] = 1 | |
| # elif isinstance(pv, torch.Tensor) and pv.numel() > 0: | |
| # modality_idx.fill_(1) | |
| # logits = classifier(scalar_inputs, modality_idx) | |
| # p_need_last = torch.sigmoid(logits) # [B,1] | |
| # p_need_last_batch = p_need_last.squeeze(1) # [B] | |
| # should_exit = p_need_last_batch < threshold | |
| # exit_mask = should_exit.cpu().numpy() | |
| # if stats["total"] <= B * 3 and is_main: | |
| # print_master( | |
| # f"[EE Debug] Batch {stats['total']//B}: " | |
| # f"p_need_last mean={p_need_last_batch.mean().item():.4f}, " | |
| # f"std={p_need_last_batch.std().item():.4f}, " | |
| # f"Exit Rate={exit_mask.mean():.2%}, " | |
| # f"Top3 Feats: diag_cos={diag_cos.mean():.3f}, " | |
| # f"margin={margin_mid.mean():.3f}, H={H_mid.mean():.3f}" | |
| # ) | |
| # stats["exit"] += exit_mask.sum() | |
| # # --------------------------------------------------- | |
| # # 3. 分支执行 | |
| # # --------------------------------------------------- | |
| # exit_indices = np.where(exit_mask)[0] | |
| # cont_indices = np.where(~exit_mask)[0] | |
| # # A. 早停样本:用 mid→last 检索 | |
| # if len(exit_indices) > 0: | |
| # reps_exit = reps_mid_t[exit_indices] | |
| # if use_local: | |
| # for i, idx in enumerate(exit_indices): | |
| # cand_local = infos[idx].get("cand_names", []) | |
| # if not cand_local: | |
| # cids = [] | |
| # else: | |
| # rows = [ | |
| # cand_id2row.get(str(cid), -1) | |
| # for cid in cand_local | |
| # ] | |
| # rows = [r for r in rows if r >= 0] | |
| # if len(rows) == 0: | |
| # cids = [] | |
| # else: | |
| # # 【关键修复】早停检索:qry_mid × cand_mid(表征空间对齐) | |
| # cmat_t = cand_mid_t[rows] | |
| # scores_vec = (reps_exit[i : i + 1] @ cmat_t.T)[0] | |
| # order_local = ( | |
| # torch.argsort( | |
| # scores_vec, | |
| # dim=0, | |
| # descending=True, | |
| # ) | |
| # .cpu() | |
| # .numpy() | |
| # ) | |
| # cids = [str(cand_local[o]) for o in order_local] | |
| # label = ( | |
| # infos[idx].get("label_name") | |
| # or infos[idx].get("label") | |
| # or infos[idx].get("rel_docids") | |
| # ) | |
| # if not isinstance(label, list): | |
| # label = [label] | |
| # rel_scores = infos[idx].get("rel_scores", None) | |
| # global_idx = batch_start_idx + idx | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, | |
| # "label": label, | |
| # "rel_scores": rel_scores, | |
| # } | |
| # else: | |
| # # 【关键修复】早停检索:qry_mid × cand_mid(表征空间对齐) | |
| # scores_full = reps_exit @ cand_mid_t.T | |
| # _, topk_inds = torch.topk( | |
| # scores_full, k=min(200, len(cand_ids)), dim=1 | |
| # ) | |
| # topk_inds = topk_inds.cpu().numpy() | |
| # for i, idx in enumerate(exit_indices): | |
| # cids = [cand_ids[k] for k in topk_inds[i]] | |
| # label = ( | |
| # infos[idx].get("label_name") | |
| # or infos[idx].get("label") | |
| # or infos[idx].get("rel_docids") | |
| # ) | |
| # if not isinstance(label, list): | |
| # label = [label] | |
| # rel_scores = infos[idx].get("rel_scores", None) | |
| # global_idx = batch_start_idx + idx | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, | |
| # "label": label, | |
| # "rel_scores": rel_scores, | |
| # } | |
| # # B. 续跑样本:从 mid 继续到 last | |
| # if len(cont_indices) > 0: | |
| # interm = getattr(out_mid, "intermediate_state", None) | |
| # hs = interm["hidden_states"].detach() | |
| # am = interm["attention_mask"].detach() | |
| # pos = interm["position_ids"].detach() | |
| # vm = interm.get("vision_mask", None) | |
| # tm = interm.get("text_mask", None) | |
| # next_layer = int(interm["next_layer_idx"]) | |
| # resume_state_subset = { | |
| # "hidden_states": hs[cont_indices], | |
| # "attention_mask": am[cont_indices], | |
| # "position_ids": pos[:, cont_indices, :], | |
| # "vision_mask": vm[cont_indices] if vm is not None else None, | |
| # "text_mask": tm[cont_indices] if tm is not None else None, | |
| # "next_layer_idx": next_layer, | |
| # } | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # aop_resume = dict(aop_cfg) | |
| # aop_resume["enabled"] = bool(_orig_enabled and side_enable) | |
| # setattr(model.encoder, "aop_prune_config", aop_resume) | |
| # # 计时:mid -> last | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t0_tail = time.perf_counter() | |
| # with torch.no_grad(), torch.autocast( | |
| # device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| # ): | |
| # out_last = model.encoder( | |
| # return_dict=True, | |
| # output_hidden_states=False, | |
| # stop_at_layer=None, | |
| # resume_state=resume_state_subset, | |
| # compute_lm_head=False, | |
| # ) | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t1_tail = time.perf_counter() | |
| # timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len( | |
| # cont_indices | |
| # ) | |
| # timing_stats["tail_num"] += len(cont_indices) | |
| # hs_last = out_last.last_hidden_state | |
| # if hs_last is None: | |
| # hs_last = out_last.hidden_states[-1] | |
| # am_last = getattr(out_last, "attention_mask", None) | |
| # if am_last is None: | |
| # am_last = resume_state_subset["attention_mask"] | |
| # reps_last_t = ( | |
| # model._pooling(hs_last, am_last) | |
| # .detach() | |
| # .to(dtype=torch.bfloat16) | |
| # ) | |
| # if use_local: | |
| # for i, idx_global in enumerate(cont_indices): | |
| # cand_local = infos[idx_global].get("cand_names", []) | |
| # if not cand_local: | |
| # cids = [] | |
| # else: | |
| # rows = [ | |
| # cand_id2row.get(str(cid), -1) | |
| # for cid in cand_local | |
| # ] | |
| # rows = [r for r in rows if r >= 0] | |
| # if len(rows) == 0: | |
| # cids = [] | |
| # else: | |
| # cmat_last_t = cand_last_t[rows] | |
| # scores_vec_t = ( | |
| # reps_last_t[i : i + 1] @ cmat_last_t.T | |
| # )[0] | |
| # order_local = ( | |
| # torch.argsort( | |
| # scores_vec_t, | |
| # dim=0, | |
| # descending=True, | |
| # ) | |
| # .cpu() | |
| # .numpy() | |
| # ) | |
| # cids = [str(cand_local[o]) for o in order_local] | |
| # label = ( | |
| # infos[idx_global].get("label_name") | |
| # or infos[idx_global].get("label") | |
| # or infos[idx_global].get("rel_docids") | |
| # ) | |
| # if not isinstance(label, list): | |
| # label = [label] | |
| # rel_scores = infos[idx_global].get("rel_scores", None) | |
| # global_idx = batch_start_idx + idx_global | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, | |
| # "label": label, | |
| # "rel_scores": rel_scores, | |
| # } | |
| # else: | |
| # scores_last = reps_last_t @ cand_last_t.T | |
| # _, topk_inds = torch.topk( | |
| # scores_last, k=min(200, len(cand_ids)), dim=1 | |
| # ) | |
| # topk_inds = topk_inds.cpu().numpy() | |
| # for i, idx_global in enumerate(cont_indices): | |
| # cids = [cand_ids[k] for k in topk_inds[i]] | |
| # label = ( | |
| # infos[idx_global].get("label_name") | |
| # or infos[idx_global].get("label") | |
| # or infos[idx_global].get("rel_docids") | |
| # ) | |
| # if not isinstance(label, list): | |
| # label = [label] | |
| # rel_scores = infos[idx_global].get("rel_scores", None) | |
| # global_idx = batch_start_idx + idx_global | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, | |
| # "label": label, | |
| # "rel_scores": rel_scores, | |
| # } | |
| # # --------------------------------------------------- | |
| # # 4. Profiling: 保存 per-query 的 topK embedding & 相似度 | |
| # # --------------------------------------------------- | |
| # if profile_enabled and is_main: | |
| # K = min(topk_emb, cand_mid_t.size(0)) | |
| # # 转到 float32 + CPU 便于写盘 | |
| # q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D] | |
| # q_last_cpu = ( | |
| # reps_last_full.detach().float().cpu() | |
| # if reps_last_full is not None | |
| # else None | |
| # ) # [B, D] | |
| # cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D] | |
| # cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D] | |
| # # 【关键修复】mid2mid 相似度(表征空间对齐) | |
| # scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc] | |
| # topk_mid_vals, topk_mid_inds = torch.topk( | |
| # scores_mid_full, k=K, dim=1 | |
| # ) | |
| # # last2last 相似度(如果有) | |
| # if q_last_cpu is not None: | |
| # scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc] | |
| # topk_last_vals, topk_last_inds = torch.topk( | |
| # scores_last_full, k=K, dim=1 | |
| # ) | |
| # else: | |
| # topk_last_vals = None | |
| # topk_last_inds = None | |
| # for i in range(B): | |
| # qid = batch_start_idx + i | |
| # rec = { | |
| # "qid": int(qid), | |
| # "early_exit": bool(exit_mask[i]), | |
| # } | |
| # if p_need_last_batch is not None: | |
| # rec["p_need_last"] = float(p_need_last_batch[i].item()) | |
| # # 【关键修复】mid2mid TopK(表征空间对齐) | |
| # mid_inds = topk_mid_inds[i].tolist() | |
| # mid_scores = topk_mid_vals[i].tolist() | |
| # rec["mid_topk_scores"] = mid_scores | |
| # rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds] | |
| # rec["mid_q_emb"] = q_mid_cpu[i].tolist() | |
| # rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist() | |
| # # last2last TopK(如果有) | |
| # if topk_last_inds is not None: | |
| # last_inds = topk_last_inds[i].tolist() | |
| # last_scores = topk_last_vals[i].tolist() | |
| # rec["last_topk_scores"] = last_scores | |
| # rec["last_topk_cand_ids"] = [ | |
| # cand_ids[j] for j in last_inds | |
| # ] | |
| # rec["last_q_emb"] = ( | |
| # q_last_cpu[i].tolist() if q_last_cpu is not None else None | |
| # ) | |
| # rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist() | |
| # else: | |
| # rec["last_topk_scores"] = None | |
| # rec["last_topk_cand_ids"] = None | |
| # rec["last_q_emb"] = None | |
| # rec["last_cand_embs"] = None | |
| # analysis_records.append(rec) | |
| # # ===================================================== | |
| # # 5. 收集 & 保存结果 | |
| # # ===================================================== | |
| # for idx in sorted(results_dict.keys()): | |
| # pred_dicts.append(results_dict[idx]) | |
| # print_master( | |
| # f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} " | |
| # f"({stats['exit']/stats['total']:.2%})" | |
| # ) | |
| # metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] | |
| # score = RankingMetrics(metrics_to_report).evaluate(pred_dicts) | |
| # if is_main: | |
| # os.makedirs(out_dir, exist_ok=True) | |
| # with open( | |
| # os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), | |
| # "w", | |
| # ) as f: | |
| # json.dump(score, f, indent=4) | |
| # # 保存 profiling 信息 | |
| # if profile_enabled: | |
| # prof_dir = os.path.join(out_dir, "profiling") | |
| # os.makedirs(prof_dir, exist_ok=True) | |
| # # 时间统计:转换成均值(秒/样本) | |
| # mid_avg = ( | |
| # timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"]) | |
| # ) | |
| # tail_avg = ( | |
| # timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"]) | |
| # ) | |
| # timing_out = { | |
| # "mid_time_sum": timing_stats["mid_time_sum"], | |
| # "mid_num": timing_stats["mid_num"], | |
| # "tail_time_sum": timing_stats["tail_time_sum"], | |
| # "tail_num": timing_stats["tail_num"], | |
| # "avg_mid_time_per_query_sec": mid_avg, | |
| # "avg_tail_time_per_cont_query_sec": tail_avg, | |
| # "num_exit": int(stats["exit"]), | |
| # "num_total": int(stats["total"]), | |
| # } | |
| # with open( | |
| # os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w" | |
| # ) as f: | |
| # json.dump(timing_out, f, indent=2) | |
| # # embedding + 相似度记录(JSONL) | |
| # embed_path = os.path.join( | |
| # prof_dir, f"{dataset_name}_embeds.jsonl" | |
| # ) | |
| # with open(embed_path, "w") as f: | |
| # for rec in analysis_records: | |
| # f.write(json.dumps(rec) + "\n") | |
| # print_master( | |
| # f"[PROFILE] Saved timing to {prof_dir}, " | |
| # f"embeddings to {embed_path}" | |
| # ) | |
| # elapsed = time.time() - start_time | |
| # return score, elapsed | |
| # # =========================== | |
| # # Helper Functions (Pre-Computation) | |
| # # =========================== | |
| # def encode_candidates_both_layers( | |
| # model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, | |
| # model_args: ModelArguments, full_dataset: Dataset, mid_layer: int, | |
| # ): | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # model.eval() | |
| # all_mid, all_last, all_ids = [], [], [] | |
| # with torch.no_grad(): | |
| # for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0): | |
| # inputs = batch_to_device(inputs, training_args.device) | |
| # # 强制关闭 Cand 侧的 AOP | |
| # aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| # _orig = None | |
| # if isinstance(aop_cfg, dict): | |
| # _orig = aop_cfg.get("enabled", False) | |
| # aop_cfg["enabled"] = False | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): | |
| # out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None) | |
| # if isinstance(aop_cfg, dict) and _orig is not None: | |
| # aop_cfg["enabled"] = _orig # Restore | |
| # mid_hs = out.hidden_states[mid_layer] | |
| # last_hs = out.hidden_states[-1] | |
| # am = inputs.get("attention_mask", None) | |
| # if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device) | |
| # reps_mid = model._pooling(mid_hs, am) | |
| # reps_last = model._pooling(last_hs, am) | |
| # all_mid.append(reps_mid.detach().float().cpu()) | |
| # all_last.append(reps_last.detach().float().cpu()) | |
| # all_ids.extend([info["cand_name"] for info in dataset_info]) | |
| # if not all_mid: return np.array([]), np.array([]), [] | |
| # return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids | |
| # # =========================== | |
| # # Main | |
| # # =========================== | |
| # def main(): | |
| # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): | |
| # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) | |
| # model_args, data_args, training_args = parser.parse_args_into_dataclasses() | |
| # ee_cfg = get_env_ee_config() | |
| # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) | |
| # if not getattr(model_args, "model_backbone", None): | |
| # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) | |
| # setattr(model_args, 'model_backbone', model_backbone) | |
| # if local_rank == 0: | |
| # processor = load_processor(model_args, data_args) | |
| # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) | |
| # if torch.distributed.is_initialized(): torch.distributed.barrier() | |
| # if local_rank != 0: | |
| # processor = load_processor(model_args, data_args) | |
| # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) | |
| # model.eval() | |
| # model = model.to(training_args.device, dtype=torch.bfloat16) | |
| # # 调试:检查模型配置 | |
| # print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}") | |
| # # AOP 配置注入 | |
| # aop_cfg = get_env_aop_config() | |
| # if aop_cfg["enabled"]: | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| # model.set_inference_layers(qry_layers=None, tgt_layers=None) | |
| # print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}") | |
| # # 加载分类器 | |
| # classifier = None | |
| # if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]: | |
| # classifier_path = ee_cfg['classifier_path'] | |
| # print_master(f"[EE] Loading Classifier from {classifier_path}...") | |
| # # 【关键】使用27维特征,与训练代码完全一致 | |
| # classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64) | |
| # state_dict = None | |
| # # 尝试多种加载方式 | |
| # if os.path.isdir(classifier_path): | |
| # # 方式1: safetensors 格式(Trainer 默认) | |
| # safetensors_file = os.path.join(classifier_path, "model.safetensors") | |
| # if os.path.exists(safetensors_file): | |
| # from safetensors.torch import load_file | |
| # state_dict = load_file(safetensors_file) | |
| # print_master(f"[EE] Loaded from model.safetensors") | |
| # else: | |
| # # 方式2: pytorch_model.bin | |
| # pt_file = os.path.join(classifier_path, "pytorch_model.bin") | |
| # if os.path.exists(pt_file): | |
| # state_dict = torch.load(pt_file, map_location=training_args.device) | |
| # print_master(f"[EE] Loaded from pytorch_model.bin") | |
| # else: | |
| # # 方式3: 独立的 .pt 文件 | |
| # layer_idx = ee_cfg.get('layer', 12) | |
| # pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt") | |
| # if os.path.exists(pt_file): | |
| # state_dict = torch.load(pt_file, map_location=training_args.device) | |
| # print_master(f"[EE] Loaded from {os.path.basename(pt_file)}") | |
| # else: | |
| # raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}") | |
| # elif os.path.isfile(classifier_path): | |
| # # 方式4: 直接指定文件 | |
| # if classifier_path.endswith('.safetensors'): | |
| # from safetensors.torch import load_file | |
| # state_dict = load_file(classifier_path) | |
| # print_master(f"[EE] Loaded from .safetensors file") | |
| # else: | |
| # state_dict = torch.load(classifier_path, map_location=training_args.device) | |
| # print_master(f"[EE] Loaded from .pt file") | |
| # else: | |
| # raise FileNotFoundError(f"Classifier path not found: {classifier_path}") | |
| # classifier.load_state_dict(state_dict) | |
| # classifier.to(training_args.device) | |
| # classifier.eval() | |
| # print_master(f"[EE] Classifier loaded successfully. Threshold={ee_cfg['threshold']}") | |
| # with open(data_args.dataset_config, 'r') as yaml_file: | |
| # dataset_configs = yaml.safe_load(yaml_file) | |
| # for dataset_name, task_config in dataset_configs.items(): | |
| # if dist.is_initialized(): dist.barrier() | |
| # print_master(f"\n--- Evaluating {dataset_name} ---") | |
| # if data_args.data_basedir: | |
| # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: | |
| # if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) | |
| # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) | |
| # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) | |
| # mid_layer = int(ee_cfg["layer"]) | |
| # cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}") | |
| # cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast") | |
| # # 【关键修复】强制重新生成candidates以确保normalize设置一致 | |
| # force_regenerate = False # 已经重新生成过,设为False避免每次都重新计算 | |
| # # 预计算 Candidates | |
| # if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)): | |
| # print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...") | |
| # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") | |
| # eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) | |
| # cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer) | |
| # if local_rank == 0: | |
| # with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f) | |
| # with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f) | |
| # if dist.is_initialized(): dist.barrier() | |
| # if local_rank == 0: | |
| # with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f) | |
| # with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f) | |
| # # 【关键修复】从task_config读取eval_type,决定global/local ranking | |
| # rank_global = task_config.get("eval_type", "global") == "global" | |
| # print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking") | |
| # run_early_exit_queries( | |
| # model, classifier, processor, model_args, data_args, training_args, | |
| # full_eval_qry_dataset, cand_mid_dict, cand_last_dict, ee_cfg, dataset_name, data_args.encode_output_path, | |
| # global_ranking=rank_global | |
| # ) | |
| # if dist.is_initialized(): dist.barrier() | |
| # if __name__ == '__main__': | |
| # main() | |
| # ################################################################################### | |
| # # 添加早停率上下界 (adds global lower/upper bounds on the early-exit rate) | |
| # # 调用格式 (usage): set these env vars before launching, e.g. | |
| # #   export EE_EXIT_LOWER=0.05 | |
| # #   export EE_EXIT_UPPER=0.5 | |
| # import datetime | |
| # import logging | |
| # import json | |
| # import random | |
| # import time | |
| # import numpy as np | |
| # import os | |
| # import pickle | |
| # import sys | |
| # import torch | |
| # import torch.distributed as dist | |
| # import torch.nn.functional as F | |
| # import yaml | |
| # import transformers | |
| # import math | |
| # from torch.utils.data import DataLoader | |
| # from tqdm import tqdm | |
| # from transformers import HfArgumentParser, AutoConfig, AutoTokenizer | |
| # from datasets import Dataset, concatenate_datasets | |
| # from datasets.distributed import split_dataset_by_node | |
| # from src.arguments import ModelArguments, DataArguments, TrainingArguments | |
| # from src.data.collator.eval_collator import MultimodalEvalDataCollator | |
| # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset | |
| # from src.eval_utils.metrics import RankingMetrics | |
| # from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel | |
| # from src.model.processor import get_backbone_name, load_processor, COLPALI | |
| # from src.utils import batch_to_device, print_rank, print_master | |
| # # 引入分类器 | |
| # from src.classifier_utils_V2 import EarlyExitClassifier | |
| # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') | |
| # logger = logging.getLogger(__name__) | |
| # # =========================== | |
| # # Helper Functions (AOP Config) | |
| # # =========================== | |
| # def _parse_bool(v: str, default=False): | |
| # if v is None: return default | |
| # v = v.strip().lower() | |
| # return v in {"1","true","yes","y","t","on"} | |
| # def _parse_float(v: str, default=None): | |
| # try: return float(v) if v is not None else default | |
| # except: return default | |
| # def _parse_int(v: str, default=None): | |
| # try: return int(v) if v is not None else default | |
| # except: return default | |
| # def get_env_aop_config(): | |
| # enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False) | |
| # apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() | |
| # layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None) | |
| # mode = os.environ.get("AOP_MODE", "delta").strip().lower() | |
| # # Parameters | |
| # delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10) | |
| # khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0) | |
| # keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0) | |
| # min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64) | |
| # use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True) | |
| # # Specific | |
| # prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True) | |
| # prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False) | |
| # keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None) | |
| # keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None) | |
| # selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() | |
| # attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() | |
| # if layer_idx is None and enabled: | |
| # enabled = False | |
| # return { | |
| # "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode, | |
| # "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep, | |
| # "use_bias": use_bias, "prune_vision": prune_vision, "prune_text": prune_text, | |
| # "keep_ratio_vision": keep_ratio_v, "keep_ratio_text": keep_ratio_t, | |
| # "selection": selection, "attn_agg": attn_agg, | |
| # "margin_mid": None # Simplified | |
| # } | |
| # def get_env_ee_config(): | |
| # ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"} | |
| # layer = int(os.environ.get("EE_LAYER", "12")) | |
| # method = os.environ.get("EE_METHOD", "classifier").strip().lower() | |
| # threshold = float(os.environ.get("EE_THRESHOLD", "0.8")) | |
| # classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "") | |
| # # ===== 新增:早停率上下界(0~1) ===== | |
| # exit_lower = float(os.environ.get("EE_EXIT_LOWER", "0.0")) # 下界,默认 0 | |
| # exit_upper = float(os.environ.get("EE_EXIT_UPPER", "1.0")) # 上界,默认 1 | |
| # # 合法化:保证 0 <= lower <= upper <= 1 | |
| # exit_lower = max(0.0, min(exit_lower, 1.0)) | |
| # exit_upper = max(exit_lower, min(exit_upper, 1.0)) | |
| # return dict( | |
| # enabled=ee_enabled, | |
| # layer=layer, | |
| # method=method, | |
| # threshold=threshold, | |
| # classifier_path=classifier_path, | |
| # exit_lower=exit_lower, | |
| # exit_upper=exit_upper, | |
| # ) | |
| # def pad_dataset_to_divisible(dataset, world_size): | |
| # num_samples = len(dataset) | |
| # if num_samples % world_size == 0: return dataset, num_samples | |
| # num_to_add = world_size - (num_samples % world_size) | |
| # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) | |
| # return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add | |
| # # =========================== | |
| # # Core Inference Function | |
| # # =========================== | |
| # def run_early_exit_queries( | |
| # model: MMEBModel, | |
| # classifier: EarlyExitClassifier, | |
| # processor, | |
| # model_args: ModelArguments, | |
| # data_args: DataArguments, | |
| # training_args: TrainingArguments, | |
| # qry_dataset: Dataset, | |
| # cand_mid_dict: dict, | |
| # cand_last_dict: dict, | |
| # ee_cfg: dict, | |
| # dataset_name: str, | |
| # out_dir: str, | |
| # global_ranking: bool = True, | |
| # ): | |
| # device = training_args.device | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # is_main = (not dist.is_initialized()) or (local_rank == 0) | |
| # # ========================== | |
| # # 全局早停率上下界(按 bench) | |
| # # ========================== | |
| # total_samples = len(qry_dataset) | |
| # exit_lower = float(ee_cfg.get("exit_lower", 0.0)) | |
| # exit_upper = float(ee_cfg.get("exit_upper", 1.0)) | |
| # # 是否真正启用约束:如果是默认 [0,1] 就视为关闭,保持兼容 | |
| # use_exit_bounds = not (abs(exit_lower - 0.0) < 1e-6 and abs(exit_upper - 1.0) < 1e-6) | |
| # # 映射成「最少/最多早停的样本数」 | |
| # min_exit = int(math.ceil(exit_lower * total_samples)) | |
| # max_exit = int(math.floor(exit_upper * total_samples)) | |
| # # 安全裁剪 | |
| # min_exit = max(0, min(min_exit, total_samples)) | |
| # max_exit = max(min_exit, min(max_exit, total_samples)) | |
| # if is_main: | |
| # print_master( | |
| # f"[EE] Global exit-rate bounds for dataset '{dataset_name}': " | |
| # f"lower={exit_lower:.3f} ({min_exit}/{total_samples}), " | |
| # f"upper={exit_upper:.3f} ({max_exit}/{total_samples}), " | |
| # f"enabled={use_exit_bounds}" | |
| # ) | |
| # # ========================== | |
| # # Profiling 配置 | |
| # # ========================== | |
| # profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in { | |
| # "1", "true", "yes", "on", "y", "t" | |
| # } | |
| # topk_emb = int(os.environ.get("EE_TOPK_EMB", "5")) | |
| # # 运行时间统计(按样本数加权) | |
| # timing_stats = { | |
| # "mid_time_sum": 0.0, # encoder 到 mid 的总时间 * 样本数 | |
| # "mid_num": 0, # 样本数 | |
| # "tail_time_sum": 0.0, # mid->last 续跑部分的总时间 * 样本数 | |
| # "tail_num": 0, # 续跑样本数 | |
| # } | |
| # # embedding + 相似度记录(仅 rank0 保存) | |
| # analysis_records = [] if (profile_enabled and is_main) else None | |
| # # 1. 准备 Candidates | |
| # cand_ids = list(cand_mid_dict.keys()) | |
| # cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32) | |
| # cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32) | |
| # # 新增:作为检索打分使用的 FP32 Numpy 向量 | |
| # cand_mid_np = cand_mid # [Nc, D], float32 | |
| # cand_last_np = cand_last # [Nc, D], float32 | |
| # cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16) | |
| # cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16) | |
| # collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") | |
| # loader = DataLoader( | |
| # qry_dataset, | |
| # batch_size=training_args.per_device_eval_batch_size, | |
| # collate_fn=collator, | |
| # num_workers=training_args.dataloader_num_workers | |
| # ) | |
| # pred_dicts = [] | |
| # stats = {"exit": 0, "total": 0} | |
| # threshold = float(ee_cfg["threshold"]) | |
| # method = ee_cfg["method"] | |
| # target_layer_idx = int(ee_cfg["layer"]) | |
| # # 结果顺序 | |
| # results_dict = {} | |
| # global_sample_idx = 0 | |
| # # Local vs Global ranking | |
| # use_local = (not global_ranking) | |
| # if use_local: | |
| # print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)") | |
| # cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)} | |
| # else: | |
| # print_master(f"[INFO] Using GLOBAL ranking (full library)") | |
| # # --- AOP 配置初始化 --- | |
| # aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| # _orig_enabled = None | |
| # side_enable = True | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # _orig_enabled = aop_cfg.get("enabled", False) | |
| # apply_to = aop_cfg.get("apply_to", "qry") | |
| # side_enable = (apply_to == "both") or (apply_to == "qry") | |
| # model.eval() | |
| # if classifier: | |
| # classifier.eval() | |
| # classifier.to(device) | |
| # start_time = time.time() | |
| # for inputs, infos in tqdm( | |
| # loader, | |
| # desc=f"[EE+AOP] {dataset_name} (tau={threshold})", | |
| # disable=local_rank > 0, | |
| # ): | |
| # inputs = batch_to_device(inputs, device) | |
| # B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1 | |
| # batch_start_idx = global_sample_idx | |
| # global_sample_idx += B | |
| # stats["total"] += B | |
| # # --------------------------------------------------- | |
| # # 1. 前半程: Run to Mid Layer (含 AOP 动态控制) | |
| # # --------------------------------------------------- | |
| # orig_cfg = None | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # orig_cfg = dict(aop_cfg) | |
| # aop_layer = aop_cfg.get("layer_idx", None) | |
| # aop_on_mid = bool( | |
| # _orig_enabled | |
| # and side_enable | |
| # and (aop_layer is not None) | |
| # and (aop_layer < target_layer_idx) | |
| # ) | |
| # aop_cfg_mid = dict(aop_cfg) | |
| # aop_cfg_mid["enabled"] = aop_on_mid | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg_mid) | |
| # # 计时:encoder 到 mid 层 | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t0_mid = time.perf_counter() | |
| # with torch.no_grad(), torch.autocast( | |
| # device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| # ): | |
| # out_mid = model.encoder( | |
| # **inputs, | |
| # return_dict=True, | |
| # output_hidden_states=False, | |
| # stop_at_layer=target_layer_idx, | |
| # compute_lm_head=False, | |
| # ) | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t1_mid = time.perf_counter() | |
| # timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B | |
| # timing_stats["mid_num"] += B | |
| # # 恢复 AOP 配置 | |
| # if isinstance(orig_cfg, dict): | |
| # setattr(model.encoder, "aop_prune_config", orig_cfg) | |
| # # Hidden State & Mask | |
| # hs_mid = getattr(out_mid, "last_hidden_state", None) | |
| # if hs_mid is None: | |
| # hs_mid = out_mid.hidden_states[-1] | |
| # am_mid = getattr(out_mid, "attention_mask", None) | |
| # if am_mid is None: | |
| # am_mid = inputs.get("attention_mask", None) | |
| # reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16) | |
| # # 如果要记录最后一层 embedding,需要在 profiling 模式下额外跑一次全 forward | |
| # reps_last_full = None | |
| # if profile_enabled: | |
| # with torch.no_grad(), torch.autocast( | |
| # device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| # ): | |
| # out_full = model.encoder( | |
| # **inputs, | |
| # return_dict=True, | |
| # output_hidden_states=False, | |
| # stop_at_layer=None, | |
| # compute_lm_head=False, | |
| # ) | |
| # hs_last_full = getattr(out_full, "last_hidden_state", None) | |
| # if hs_last_full is None: | |
| # hs_last_full = out_full.hidden_states[-1] | |
| # am_last_full = getattr(out_full, "attention_mask", None) | |
| # if am_last_full is None: | |
| # am_last_full = inputs.get("attention_mask", None) | |
| # reps_last_full = ( | |
| # model._pooling(hs_last_full, am_last_full) | |
| # .detach() | |
| # .to(dtype=torch.bfloat16) | |
| # ) | |
| # # --------------------------------------------------- | |
| # # 2. 特征工程 + gating | |
| # # --------------------------------------------------- | |
| # exit_mask = np.zeros(B, dtype=bool) | |
| # p_need_last_batch = None # profiling / debug | |
| # orig_exit_mask = None # 原始按 threshold 判的结果(不带全局约束) | |
| # if method == "classifier" and classifier is not None: | |
| # with torch.no_grad(): | |
| # # 【关键修复】特征提取:qry_mid × cand_mid(表征空间对齐) | |
| # cos_mid = reps_mid_t @ cand_mid_t.T # [B, N] | |
| # backbone_ptr = ( | |
| # model.module if hasattr(model, "module") else model | |
| # ) | |
| # temp = getattr(backbone_ptr, "temperature", 0.02) | |
| # scores_mid = cos_mid / temp | |
| # probs_mid = torch.softmax(scores_mid, dim=1) # [B, N] | |
| # # ==== 以下是你原来的 27 维特征工程代码,原样保留 ==== | |
| # diag_cos = cos_mid.max(dim=1)[0] | |
| # sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True) | |
| # s2_cos = ( | |
| # sorted_cos[:, 1] | |
| # if sorted_cos.size(1) > 1 | |
| # else sorted_cos[:, 0] | |
| # ) | |
| # margin_mid = diag_cos - s2_cos | |
| # margin_mean = margin_mid.mean() | |
| # margin_std = margin_mid.std(unbiased=False) + 1e-6 | |
| # z_margin_mid = (margin_mid - margin_mean) / margin_std | |
| # margin_median = margin_mid.median() | |
| # mad = (margin_mid - margin_median).abs().median() + 1e-6 | |
| # mad_margin_mid = (margin_mid - margin_median) / mad | |
| # p1_mid = probs_mid.max(dim=1)[0] | |
| # H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1) | |
| # gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1) | |
| # TOPK = min(16, probs_mid.size(1)) | |
| # topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1) | |
| # topk_mean = topk_vals.mean(dim=1) | |
| # topk_std = topk_vals.std(dim=1, unbiased=False) | |
| # topk_cv = topk_std / (topk_mean + 1e-6) | |
| # centered = topk_vals - topk_mean.unsqueeze(1) | |
| # var = (centered ** 2).mean(dim=1) + 1e-6 | |
| # m4 = (centered ** 4).mean(dim=1) | |
| # topk_kurt = m4 / (var ** 2) | |
| # topk_med = topk_vals.median(dim=1).values | |
| # row_mean_cos = cos_mid.mean(dim=1) | |
| # row_med_cos = cos_mid.median(dim=1).values | |
| # s1_over_mean = diag_cos - row_mean_cos | |
| # s1_over_med = diag_cos - row_med_cos | |
| # sorted_probs, _ = torch.sort( | |
| # probs_mid, dim=1, descending=True | |
| # ) | |
| # p1 = sorted_probs[:, 0] | |
| # p2 = ( | |
| # sorted_probs[:, 1] | |
| # if sorted_probs.size(1) > 1 | |
| # else sorted_probs[:, 0] | |
| # ) | |
| # shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum( | |
| # dim=1 | |
| # ) | |
| # shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1) | |
| # R = min(10, sorted_probs.size(1)) | |
| # x = torch.arange( | |
| # R, device=device, dtype=sorted_probs.dtype | |
| # ) | |
| # x_centered = x - x.mean() | |
| # denom = (x_centered ** 2).sum() | |
| # y = torch.log(sorted_probs[:, :R] + 1e-6) | |
| # slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom | |
| # row_mean_p = probs_mid.mean(dim=1) | |
| # row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6 | |
| # z1 = (p1_mid - row_mean_p) / row_std_p | |
| # center_p = probs_mid - row_mean_p.unsqueeze(1) | |
| # m3 = (center_p ** 3).mean(dim=1) | |
| # skew = m3 / (row_std_p ** 3 + 1e-6) | |
| # s1_over_sk = p1_mid - skew | |
| # TAIL_K = min(10, sorted_probs.size(1)) | |
| # tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1) | |
| # HEAD_K = min(5, sorted_probs.size(1)) | |
| # head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1) | |
| # mask_ratio = torch.zeros_like(diag_cos) | |
| # mask_len = torch.zeros_like(diag_cos) | |
| # mask_runs = torch.zeros_like(diag_cos) | |
| # scalar_inputs = torch.stack( | |
| # [ | |
| # diag_cos, | |
| # s2_cos, | |
| # margin_mid, | |
| # z_margin_mid, | |
| # mad_margin_mid, | |
| # p1_mid, | |
| # H_mid, | |
| # gini_mid, | |
| # topk_mean, | |
| # topk_std, | |
| # topk_cv, | |
| # topk_kurt, | |
| # topk_med, | |
| # s1_over_mean, | |
| # s1_over_med, | |
| # p1, | |
| # p2, | |
| # shape_H, | |
| # shape_gini, | |
| # slope, | |
| # z1, | |
| # s1_over_sk, | |
| # tail_mean, | |
| # head5_mean, | |
| # mask_ratio, | |
| # mask_len, | |
| # mask_runs, | |
| # ], | |
| # dim=1, | |
| # ) | |
| # modality_idx = torch.zeros( | |
| # B, dtype=torch.long, device=device | |
| # ) | |
| # if "pixel_values" in inputs and inputs["pixel_values"] is not None: | |
| # pv = inputs["pixel_values"] | |
| # if isinstance(pv, list): | |
| # for i, item in enumerate(pv): | |
| # if item is not None: | |
| # modality_idx[i] = 1 | |
| # elif isinstance(pv, torch.Tensor) and pv.numel() > 0: | |
| # modality_idx.fill_(1) | |
| # logits = classifier(scalar_inputs, modality_idx) | |
| # p_need_last = torch.sigmoid(logits) # [B,1] | |
| # p_need_last_batch = p_need_last.squeeze(1) # [B] | |
| # # 原始按 threshold 的判定(仅作为参考) | |
| # should_exit = p_need_last_batch < threshold | |
| # orig_exit_mask = should_exit.cpu().numpy() | |
| # # ====== 应用全局早停率上下界(如果启用) ====== | |
| # if method == "classifier" and classifier is not None and orig_exit_mask is not None: | |
| # if not use_exit_bounds: | |
| # # 不启用上下界:行为与旧版本完全一致 | |
| # exit_mask = orig_exit_mask | |
| # else: | |
| # # 当前之前已经早停的数量 | |
| # E_prev = stats["exit"] | |
| # # 当前 batch 前已经处理的样本数(全局 index) | |
| # i_prev = batch_start_idx | |
| # B_cur = B | |
| # N_total = total_samples | |
| # # 最多还能早停多少个(不超过上界) | |
| # max_exit_remaining = max_exit - E_prev | |
| # if max_exit_remaining <= 0: | |
| # allowed_max = 0 | |
| # else: | |
| # allowed_max = min(B_cur, max_exit_remaining) | |
| # # 本 batch 之后还剩多少个样本没处理 | |
| # remaining_after = N_total - (i_prev + B_cur) | |
| # # 为了将来仍有机会达到下界,本 batch 至少需要早停多少个 | |
| # required_min = min( | |
| # B_cur, | |
| # max(0, min_exit - E_prev - remaining_after) | |
| # ) | |
| # # 分类器在当前 batch 原始建议早停个数(按 threshold) | |
| # n0 = int(orig_exit_mask.sum()) | |
| # if allowed_max <= 0: | |
| # n_exit = 0 | |
| # else: | |
| # # 在 [required_min, allowed_max] 区间内,尽量贴近原始 n0 | |
| # n_exit = max(required_min, min(n0, allowed_max)) | |
| # # 最终 exit_mask:在当前 batch 中选 p_need_last 最小的 n_exit 个样本早停 | |
| # exit_mask = np.zeros(B_cur, dtype=bool) | |
| # if n_exit > 0: | |
| # p_np = p_need_last_batch.detach().cpu().numpy() # [B] | |
| # order = np.argsort(p_np) # 升序(越小越容易早停) | |
| # chosen = order[:n_exit] | |
| # exit_mask[chosen] = True | |
| # # Debug:只在前几个 batch 打印 | |
| # if stats["total"] <= B * 3 and is_main: | |
| # raw_rate = orig_exit_mask.mean() | |
| # bounded_rate = exit_mask.mean() | |
| # print_master( | |
| # f"[EE Debug] Batch {stats['total']//B}: " | |
| # f"p_need_last mean={p_need_last_batch.mean().item():.4f}, " | |
| # f"std={p_need_last_batch.std().item():.4f}, " | |
| # f"raw_exit_rate={raw_rate:.2%}, " | |
| # f"bounded_exit_rate={bounded_rate:.2%}, " | |
| # f"required_min={required_min}, allowed_max={allowed_max}, " | |
| # f"n0={n0}, n_exit={n_exit}" | |
| # ) | |
| # # 非 classifier 或 classifier 不存在:保持 exit_mask=全 False(即无早停) | |
| # stats["exit"] += int(exit_mask.sum()) | |
| # # --------------------------------------------------- | |
| # # 3. 分支执行 | |
| # # --------------------------------------------------- | |
| # exit_indices = np.where(exit_mask)[0] | |
| # cont_indices = np.where(~exit_mask)[0] | |
| # # A. 早停样本:用 mid→last 检索 | |
| # if len(exit_indices) > 0: | |
| # reps_exit = reps_mid_t[exit_indices] | |
| # if use_local: | |
| # for i, idx in enumerate(exit_indices): | |
| # cand_local = infos[idx].get("cand_names", []) | |
| # if not cand_local: | |
| # cids = [] | |
| # else: | |
| # rows = [ | |
| # cand_id2row.get(str(cid), -1) | |
| # for cid in cand_local | |
| # ] | |
| # rows = [r for r in rows if r >= 0] | |
| # if len(rows) == 0: | |
| # cids = [] | |
| # else: | |
| # # CPU FP32:单个 query mid × 局部 cand_mid | |
| # cmat_np = cand_mid_np[rows] # [Nc_local, D] | |
| # qry_np = reps_exit[i].detach().float().cpu().numpy() # [D] | |
| # scores_vec = np.dot(qry_np, cmat_np.T) # [Nc_local] | |
| # top_k = min(200, len(rows)) | |
| # order_local = np.argsort(-scores_vec)[:top_k] | |
| # cids = [str(cand_local[o]) for o in order_local] | |
| # label = ( | |
| # infos[idx].get("label_name") | |
| # or infos[idx].get("label") | |
| # or infos[idx].get("rel_docids") | |
| # ) | |
| # if not isinstance(label, list): | |
| # label = [label] | |
| # rel_scores = infos[idx].get("rel_scores", None) | |
| # global_idx = batch_start_idx + idx | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, | |
| # "label": label, | |
| # "rel_scores": rel_scores, | |
| # } | |
| # else: | |
| # # 使用 CPU FP32 Numpy:reps_exit_np × cand_mid_np.T | |
| # reps_exit_np = reps_exit.detach().float().cpu().numpy() # [B_exit, D] | |
| # scores_full = np.dot(reps_exit_np, cand_mid_np.T) # [B_exit, Nc] | |
| # top_k = min(200, len(cand_ids)) | |
| # # 全排序后截到 top_k | |
| # topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k] # [B_exit, top_k] | |
| # for i, idx in enumerate(exit_indices): | |
| # cids = [cand_ids[k] for k in topk_inds[i]] | |
| # label = ( | |
| # infos[idx].get("label_name") | |
| # or infos[idx].get("label") | |
| # or infos[idx].get("rel_docids") | |
| # ) | |
| # if not isinstance(label, list): | |
| # label = [label] | |
| # rel_scores = infos[idx].get("rel_scores", None) | |
| # global_idx = batch_start_idx + idx | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, | |
| # "label": label, | |
| # "rel_scores": rel_scores, | |
| # } | |
| # # B. 续跑样本:从 mid 继续到 last | |
| # if len(cont_indices) > 0: | |
| # interm = getattr(out_mid, "intermediate_state", None) | |
| # hs = interm["hidden_states"].detach() | |
| # am = interm["attention_mask"].detach() | |
| # pos = interm["position_ids"].detach() | |
| # vm = interm.get("vision_mask", None) | |
| # tm = interm.get("text_mask", None) | |
| # next_layer = int(interm["next_layer_idx"]) | |
| # resume_state_subset = { | |
| # "hidden_states": hs[cont_indices], | |
| # "attention_mask": am[cont_indices], | |
| # "position_ids": pos[:, cont_indices, :], | |
| # "vision_mask": vm[cont_indices] if vm is not None else None, | |
| # "text_mask": tm[cont_indices] if tm is not None else None, | |
| # "next_layer_idx": next_layer, | |
| # } | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # aop_resume = dict(aop_cfg) | |
| # aop_resume["enabled"] = bool(_orig_enabled and side_enable) | |
| # setattr(model.encoder, "aop_prune_config", aop_resume) | |
| # # 计时:mid -> last | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t0_tail = time.perf_counter() | |
| # with torch.no_grad(), torch.autocast( | |
| # device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| # ): | |
| # out_last = model.encoder( | |
| # return_dict=True, | |
| # output_hidden_states=False, | |
| # stop_at_layer=None, | |
| # resume_state=resume_state_subset, | |
| # compute_lm_head=False, | |
| # ) | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t1_tail = time.perf_counter() | |
| # timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len( | |
| # cont_indices | |
| # ) | |
| # timing_stats["tail_num"] += len(cont_indices) | |
| # hs_last = out_last.last_hidden_state | |
| # if hs_last is None: | |
| # hs_last = out_last.hidden_states[-1] | |
| # am_last = getattr(out_last, "attention_mask", None) | |
| # if am_last is None: | |
| # am_last = resume_state_subset["attention_mask"] | |
| # reps_last_t = ( | |
| # model._pooling(hs_last, am_last) | |
| # .detach() | |
| # .to(dtype=torch.bfloat16) | |
| # ) | |
| # if use_local: | |
| # for i, idx_global in enumerate(cont_indices): | |
| # cand_local = infos[idx_global].get("cand_names", []) | |
| # if not cand_local: | |
| # cids = [] | |
| # else: | |
| # rows = [ | |
| # cand_id2row.get(str(cid), -1) | |
| # for cid in cand_local | |
| # ] | |
| # rows = [r for r in rows if r >= 0] | |
| # if len(rows) == 0: | |
| # cids = [] | |
| # else: | |
| # # CPU FP32:单个 query last × 局部 cand_last | |
| # cmat_last_np = cand_last_np[rows] # [Nc_local, D] | |
| # qry_last_np = reps_last_t[i].detach().float().cpu().numpy() | |
| # scores_vec = np.dot(qry_last_np, cmat_last_np.T) # [Nc_local] | |
| # top_k = min(200, len(rows)) | |
| # order_local = np.argsort(-scores_vec)[:top_k] | |
| # cids = [str(cand_local[o]) for o in order_local] | |
| # label = ( | |
| # infos[idx_global].get("label_name") | |
| # or infos[idx_global].get("label") | |
| # or infos[idx_global].get("rel_docids") | |
| # ) | |
| # if not isinstance(label, list): | |
| # label = [label] | |
| # rel_scores = infos[idx_global].get("rel_scores", None) | |
| # global_idx = batch_start_idx + idx_global | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, | |
| # "label": label, | |
| # "rel_scores": rel_scores, | |
| # } | |
| # else: | |
| # # CPU FP32:所有续跑 query last × 全量 cand_last | |
| # reps_last_np = reps_last_t.detach().float().cpu().numpy() # [B_cont, D] | |
| # scores_last = np.dot(reps_last_np, cand_last_np.T) # [B_cont, Nc] | |
| # top_k = min(200, len(cand_ids)) | |
| # topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k] # [B_cont, top_k] | |
| # for i, idx_global in enumerate(cont_indices): | |
| # cids = [cand_ids[k] for k in topk_inds[i]] | |
| # label = ( | |
| # infos[idx_global].get("label_name") | |
| # or infos[idx_global].get("label") | |
| # or infos[idx_global].get("rel_docids") | |
| # ) | |
| # if not isinstance(label, list): | |
| # label = [label] | |
| # rel_scores = infos[idx_global].get("rel_scores", None) | |
| # global_idx = batch_start_idx + idx_global | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, | |
| # "label": label, | |
| # "rel_scores": rel_scores, | |
| # } | |
| # # --------------------------------------------------- | |
| # # 4. Profiling: 保存 per-query 的 topK embedding & 相似度 | |
| # # --------------------------------------------------- | |
| # if profile_enabled and is_main: | |
| # K = min(topk_emb, cand_mid_t.size(0)) | |
| # # 转到 float32 + CPU 便于写盘 | |
| # q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D] | |
| # q_last_cpu = ( | |
| # reps_last_full.detach().float().cpu() | |
| # if reps_last_full is not None | |
| # else None | |
| # ) # [B, D] | |
| # cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D] | |
| # cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D] | |
| # # 【关键修复】mid2mid 相似度(表征空间对齐) | |
| # scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc] | |
| # topk_mid_vals, topk_mid_inds = torch.topk( | |
| # scores_mid_full, k=K, dim=1 | |
| # ) | |
| # # last2last 相似度(如果有) | |
| # if q_last_cpu is not None: | |
| # scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc] | |
| # topk_last_vals, topk_last_inds = torch.topk( | |
| # scores_last_full, k=K, dim=1 | |
| # ) | |
| # else: | |
| # topk_last_vals = None | |
| # topk_last_inds = None | |
| # for i in range(B): | |
| # qid = batch_start_idx + i | |
| # rec = { | |
| # "qid": int(qid), | |
| # "early_exit": bool(exit_mask[i]), | |
| # } | |
| # if p_need_last_batch is not None: | |
| # rec["p_need_last"] = float(p_need_last_batch[i].item()) | |
| # # 【关键修复】mid2mid TopK(表征空间对齐) | |
| # mid_inds = topk_mid_inds[i].tolist() | |
| # mid_scores = topk_mid_vals[i].tolist() | |
| # rec["mid_topk_scores"] = mid_scores | |
| # rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds] | |
| # rec["mid_q_emb"] = q_mid_cpu[i].tolist() | |
| # rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist() | |
| # # last2last TopK(如果有) | |
| # if topk_last_inds is not None: | |
| # last_inds = topk_last_inds[i].tolist() | |
| # last_scores = topk_last_vals[i].tolist() | |
| # rec["last_topk_scores"] = last_scores | |
| # rec["last_topk_cand_ids"] = [ | |
| # cand_ids[j] for j in last_inds | |
| # ] | |
| # rec["last_q_emb"] = ( | |
| # q_last_cpu[i].tolist() if q_last_cpu is not None else None | |
| # ) | |
| # rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist() | |
| # else: | |
| # rec["last_topk_scores"] = None | |
| # rec["last_topk_cand_ids"] = None | |
| # rec["last_q_emb"] = None | |
| # rec["last_cand_embs"] = None | |
| # analysis_records.append(rec) | |
| # # ===================================================== | |
| # # 5. 收集 & 保存结果 | |
| # # ===================================================== | |
| # for idx in sorted(results_dict.keys()): | |
| # pred_dicts.append(results_dict[idx]) | |
| # print_master( | |
| # f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} " | |
| # f"({stats['exit']/stats['total']:.2%})" | |
| # ) | |
| # metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] | |
| # score = RankingMetrics(metrics_to_report).evaluate(pred_dicts) | |
| # if is_main: | |
| # os.makedirs(out_dir, exist_ok=True) | |
| # with open( | |
| # os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), | |
| # "w", | |
| # ) as f: | |
| # json.dump(score, f, indent=4) | |
| # # 保存 profiling 信息 | |
| # if profile_enabled: | |
| # prof_dir = os.path.join(out_dir, "profiling") | |
| # os.makedirs(prof_dir, exist_ok=True) | |
| # # 时间统计:转换成均值(秒/样本) | |
| # mid_avg = ( | |
| # timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"]) | |
| # ) | |
| # tail_avg = ( | |
| # timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"]) | |
| # ) | |
| # timing_out = { | |
| # "mid_time_sum": timing_stats["mid_time_sum"], | |
| # "mid_num": timing_stats["mid_num"], | |
| # "tail_time_sum": timing_stats["tail_time_sum"], | |
| # "tail_num": timing_stats["tail_num"], | |
| # "avg_mid_time_per_query_sec": mid_avg, | |
| # "avg_tail_time_per_cont_query_sec": tail_avg, | |
| # "num_exit": int(stats["exit"]), | |
| # "num_total": int(stats["total"]), | |
| # } | |
| # with open( | |
| # os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w" | |
| # ) as f: | |
| # json.dump(timing_out, f, indent=2) | |
| # # embedding + 相似度记录(JSONL) | |
| # embed_path = os.path.join( | |
| # prof_dir, f"{dataset_name}_embeds.jsonl" | |
| # ) | |
| # with open(embed_path, "w") as f: | |
| # for rec in analysis_records: | |
| # f.write(json.dumps(rec) + "\n") | |
| # print_master( | |
| # f"[PROFILE] Saved timing to {prof_dir}, " | |
| # f"embeddings to {embed_path}" | |
| # ) | |
| # elapsed = time.time() - start_time | |
| # return score, elapsed | |
| # # =========================== | |
| # # Helper Functions (Pre-Computation) | |
| # # =========================== | |
| # def encode_candidates_both_layers( | |
| # model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, | |
| # model_args: ModelArguments, full_dataset: Dataset, mid_layer: int, | |
| # ): | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # model.eval() | |
| # all_mid, all_last, all_ids = [], [], [] | |
| # with torch.no_grad(): | |
| # for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0): | |
| # inputs = batch_to_device(inputs, training_args.device) | |
| # # 强制关闭 Cand 侧的 AOP | |
| # aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| # _orig = None | |
| # if isinstance(aop_cfg, dict): | |
| # _orig = aop_cfg.get("enabled", False) | |
| # aop_cfg["enabled"] = False | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): | |
| # out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None) | |
| # if isinstance(aop_cfg, dict) and _orig is not None: | |
| # aop_cfg["enabled"] = _orig # Restore | |
| # mid_hs = out.hidden_states[mid_layer] | |
| # last_hs = out.hidden_states[-1] | |
| # am = inputs.get("attention_mask", None) | |
| # if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device) | |
| # reps_mid = model._pooling(mid_hs, am) | |
| # reps_last = model._pooling(last_hs, am) | |
| # all_mid.append(reps_mid.detach().float().cpu()) | |
| # all_last.append(reps_last.detach().float().cpu()) | |
| # all_ids.extend([info["cand_name"] for info in dataset_info]) | |
| # if not all_mid: return np.array([]), np.array([]), [] | |
| # return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids | |
| # # =========================== | |
| # # Main | |
| # # =========================== | |
| # def main(): | |
| # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): | |
| # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) | |
| # model_args, data_args, training_args = parser.parse_args_into_dataclasses() | |
| # ee_cfg = get_env_ee_config() | |
| # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) | |
| # if not getattr(model_args, "model_backbone", None): | |
| # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) | |
| # setattr(model_args, 'model_backbone', model_backbone) | |
| # if local_rank == 0: | |
| # processor = load_processor(model_args, data_args) | |
| # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) | |
| # if torch.distributed.is_initialized(): torch.distributed.barrier() | |
| # if local_rank != 0: | |
| # processor = load_processor(model_args, data_args) | |
| # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) | |
| # model.eval() | |
| # model = model.to(training_args.device, dtype=torch.bfloat16) | |
| # # 调试:检查模型配置 | |
| # print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}") | |
| # # AOP 配置注入 | |
| # aop_cfg = get_env_aop_config() | |
| # if aop_cfg["enabled"]: | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| # model.set_inference_layers(qry_layers=None, tgt_layers=None) | |
| # print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}") | |
| # # 加载分类器 | |
| # classifier = None | |
| # if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]: | |
| # classifier_path = ee_cfg['classifier_path'] | |
| # print_master(f"[EE] Loading Classifier from {classifier_path}...") | |
| # # 【关键】使用27维特征,与训练代码完全一致 | |
| # classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64) | |
| # state_dict = None | |
| # # 尝试多种加载方式 | |
| # if os.path.isdir(classifier_path): | |
| # # 方式1: safetensors 格式(Trainer 默认) | |
| # safetensors_file = os.path.join(classifier_path, "model.safetensors") | |
| # if os.path.exists(safetensors_file): | |
| # from safetensors.torch import load_file | |
| # state_dict = load_file(safetensors_file) | |
| # print_master(f"[EE] Loaded from model.safetensors") | |
| # else: | |
| # # 方式2: pytorch_model.bin | |
| # pt_file = os.path.join(classifier_path, "pytorch_model.bin") | |
| # if os.path.exists(pt_file): | |
| # state_dict = torch.load(pt_file, map_location=training_args.device) | |
| # print_master(f"[EE] Loaded from pytorch_model.bin") | |
| # else: | |
| # # 方式3: 独立的 .pt 文件 | |
| # layer_idx = ee_cfg.get('layer', 12) | |
| # pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt") | |
| # if os.path.exists(pt_file): | |
| # state_dict = torch.load(pt_file, map_location=training_args.device) | |
| # print_master(f"[EE] Loaded from {os.path.basename(pt_file)}") | |
| # else: | |
| # raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}") | |
| # elif os.path.isfile(classifier_path): | |
| # # 方式4: 直接指定文件 | |
| # if classifier_path.endswith('.safetensors'): | |
| # from safetensors.torch import load_file | |
| # state_dict = load_file(classifier_path) | |
| # print_master(f"[EE] Loaded from .safetensors file") | |
| # else: | |
| # state_dict = torch.load(classifier_path, map_location=training_args.device) | |
| # print_master(f"[EE] Loaded from .pt file") | |
| # else: | |
| # raise FileNotFoundError(f"Classifier path not found: {classifier_path}") | |
| # classifier.load_state_dict(state_dict) | |
| # classifier.to(training_args.device) | |
| # classifier.eval() | |
| # print_master(f"[EE] Classifier loaded successfully. Threshold={ee_cfg['threshold']}") | |
| # with open(data_args.dataset_config, 'r') as yaml_file: | |
| # dataset_configs = yaml.safe_load(yaml_file) | |
| # for dataset_name, task_config in dataset_configs.items(): | |
| # if dist.is_initialized(): dist.barrier() | |
| # print_master(f"\n--- Evaluating {dataset_name} ---") | |
| # if data_args.data_basedir: | |
| # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: | |
| # if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) | |
| # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) | |
| # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) | |
| # mid_layer = int(ee_cfg["layer"]) | |
| # cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}") | |
| # cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast") | |
| # # 【关键修复】强制重新生成candidates以确保normalize设置一致 | |
| # force_regenerate = False # 已经重新生成过,设为False避免每次都重新计算 | |
| # # 预计算 Candidates | |
| # if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)): | |
| # print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...") | |
| # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") | |
| # eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) | |
| # cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer) | |
| # if local_rank == 0: | |
| # with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f) | |
| # with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f) | |
| # if dist.is_initialized(): dist.barrier() | |
| # if local_rank == 0: | |
| # with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f) | |
| # with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f) | |
| # # 【关键修复】从task_config读取eval_type,决定global/local ranking | |
| # rank_global = task_config.get("eval_type", "global") == "global" | |
| # print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking") | |
| # run_early_exit_queries( | |
| # model, classifier, processor, model_args, data_args, training_args, | |
| # full_eval_qry_dataset, cand_mid_dict, cand_last_dict, ee_cfg, dataset_name, data_args.encode_output_path, | |
| # global_ranking=rank_global | |
| # ) | |
| # if dist.is_initialized(): dist.barrier() | |
| # if __name__ == '__main__': | |
| # main() | |
| ####################################################################################################### | |
#Revised version: resolves the CPU/GPU computation precision mismatch
| import datetime | |
| import logging | |
| import json | |
| import random | |
| import time | |
| import numpy as np | |
| import os | |
| import pickle | |
| import sys | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn.functional as F | |
| import yaml | |
| import transformers | |
| import math | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from transformers import HfArgumentParser, AutoConfig, AutoTokenizer | |
| from datasets import Dataset, concatenate_datasets | |
| from datasets.distributed import split_dataset_by_node | |
| from src.arguments import ModelArguments, DataArguments, TrainingArguments | |
| from src.data.collator.eval_collator import MultimodalEvalDataCollator | |
| from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset | |
| from src.eval_utils.metrics import RankingMetrics | |
| from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel | |
| from src.model.processor import get_backbone_name, load_processor, COLPALI | |
| from src.utils import batch_to_device, print_rank, print_master | |
| # 引入分类器 | |
| from src.classifier_utils_V2 import EarlyExitClassifier | |
| logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') | |
| logger = logging.getLogger(__name__) | |
| # =========================== | |
# Per-dataset thresholds (example values)
# Each key must exactly match a dataset_name from the dataset_config YAML.
# Datasets not listed here fall back to EE_THRESHOLD as the default threshold.
| # =========================== | |
# Maps dataset name -> early-exit decision threshold (tau) applied to the
# classifier's p_need_last score; values were chosen per dataset offline.
PER_DATASET_THRESHOLDS = {
    # Examples (fill these in yourself from offline sweep results):
    "CIRR": 0,
    "EDIS": 5e-16,
    "FashionIQ": 0,
    "MSCOCO_i2t": 1e-16,
    "MSCOCO_t2i": 1e-5,
    "NIGHTS": 1,
    "OVEN": 1e-11,
    "VisDial": 1e-11,
    "VisualNews_i2t": 5e-28,
    "VisualNews_t2i": 1e-16,
    "WebQA": 5e-11,
    "Wiki-SS-NQ": 1e-16,
}
| # =========================== | |
| # Helper Functions (AOP Config) | |
| # =========================== | |
| def _parse_bool(v: str, default=False): | |
| if v is None: return default | |
| v = v.strip().lower() | |
| return v in {"1","true","yes","y","t","on"} | |
| def _parse_float(v: str, default=None): | |
| try: return float(v) if v is not None else default | |
| except: return default | |
| def _parse_int(v: str, default=None): | |
| try: return int(v) if v is not None else default | |
| except: return default | |
def get_env_aop_config():
    """Assemble the AOP (token-pruning) runtime config from environment variables.

    All knobs are read via AOP_* env vars with the same defaults as before.
    If no pruning layer is given, the feature is force-disabled since AOP
    cannot run without a target layer.
    """
    env = os.environ.get
    enabled = _parse_bool(env("AOP_ENABLED"), False)
    layer_idx = _parse_int(env("AOP_LAYER"), None)
    # A missing layer index makes pruning meaningless -> silently disable.
    if enabled and layer_idx is None:
        enabled = False
    return {
        "enabled": enabled,
        "apply_to": env("AOP_APPLY", "qry").strip().lower(),
        "layer_idx": layer_idx,
        "mode": env("AOP_MODE", "delta").strip().lower(),
        "delta": _parse_float(env("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env("AOP_USE_BIAS"), True),
        "prune_vision": _parse_bool(env("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env("AOP_PRUNE_TEXT"), False),
        "keep_ratio_vision": _parse_float(env("AOP_KEEP_RATIO_VISION"), None),
        "keep_ratio_text": _parse_float(env("AOP_KEEP_RATIO_TEXT"), None),
        "selection": env("AOP_SELECTION", "aop").strip().lower(),
        "attn_agg": env("AOP_ATTENTION_AGG", "mean").strip().lower(),
        "margin_mid": None  # Simplified: not configurable via env
    }
def get_env_ee_config():
    """Read the early-exit (EE) configuration from environment variables.

    Returns a dict with keys: enabled, layer, method, threshold,
    classifier_path — using the same EE_* env vars and defaults as before.
    """
    truthy = {"1", "true", "yes", "on", "y", "t"}
    return {
        "enabled": os.environ.get("EE_ENABLED", "0").strip().lower() in truthy,
        "layer": int(os.environ.get("EE_LAYER", "12")),
        "method": os.environ.get("EE_METHOD", "classifier").strip().lower(),
        "threshold": float(os.environ.get("EE_THRESHOLD", "0.8")),
        "classifier_path": os.environ.get("EE_CLASSIFIER_PATH", ""),
    }
def pad_dataset_to_divisible(dataset, world_size):
    """Pad *dataset* so its length is divisible by *world_size*.

    Extra rows are taken by cycling through the dataset from the start.
    Returns ``(padded_dataset, padded_length)``; when the length already
    divides evenly, the dataset is returned untouched.
    """
    n = len(dataset)
    remainder = n % world_size
    if remainder == 0:
        return dataset, n
    extra = world_size - remainder
    pad_rows = dataset.select([i % n for i in range(extra)])
    return concatenate_datasets([dataset, pad_rows]), n + extra
| # =========================== | |
| # Core Inference Function | |
| # =========================== | |
| def run_early_exit_queries( | |
| model: MMEBModel, | |
| classifier: EarlyExitClassifier, | |
| processor, | |
| model_args: ModelArguments, | |
| data_args: DataArguments, | |
| training_args: TrainingArguments, | |
| qry_dataset: Dataset, | |
| cand_mid_dict: dict, | |
| cand_last_dict: dict, | |
| ee_cfg: dict, | |
| dataset_name: str, | |
| out_dir: str, | |
| global_ranking: bool = True, | |
| ): | |
| device = training_args.device | |
| local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| is_main = (not dist.is_initialized()) or (local_rank == 0) | |
| # ========================== | |
| # Profiling 配置 | |
| # ========================== | |
| profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in { | |
| "1", "true", "yes", "on", "y", "t" | |
| } | |
| topk_emb = int(os.environ.get("EE_TOPK_EMB", "5")) | |
| # 运行时间统计(按样本数加权) | |
| timing_stats = { | |
| "mid_time_sum": 0.0, # encoder 到 mid 的总时间 * 样本数 | |
| "mid_num": 0, # 样本数 | |
| "tail_time_sum": 0.0, # mid->last 续跑部分的总时间 * 样本数 | |
| "tail_num": 0, # 续跑样本数 | |
| } | |
| # embedding + 相似度记录(仅 rank0 保存) | |
| analysis_records = [] if (profile_enabled and is_main) else None | |
| # 1. 准备 Candidates | |
| cand_ids = list(cand_mid_dict.keys()) | |
| cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32) | |
| cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32) | |
| # 新增:作为检索打分使用的 FP32 Numpy 向量 | |
| cand_mid_np = cand_mid # [Nc, D], float32 | |
| cand_last_np = cand_last # [Nc, D], float32 | |
| cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16) | |
| cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16) | |
| collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") | |
| loader = DataLoader( | |
| qry_dataset, | |
| batch_size=training_args.per_device_eval_batch_size, | |
| collate_fn=collator, | |
| num_workers=training_args.dataloader_num_workers | |
| ) | |
| pred_dicts = [] | |
| stats = {"exit": 0, "total": 0} | |
| threshold = float(ee_cfg["threshold"]) | |
| method = ee_cfg["method"] | |
| target_layer_idx = int(ee_cfg["layer"]) | |
| # 结果顺序 | |
| results_dict = {} | |
| global_sample_idx = 0 | |
| # Local vs Global ranking | |
| use_local = (not global_ranking) | |
| if use_local: | |
| print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)") | |
| cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)} | |
| else: | |
| print_master(f"[INFO] Using GLOBAL ranking (full library)") | |
| # --- AOP 配置初始化 --- | |
| aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| _orig_enabled = None | |
| side_enable = True | |
| if isinstance(aop_cfg, dict) and aop_cfg: | |
| _orig_enabled = aop_cfg.get("enabled", False) | |
| apply_to = aop_cfg.get("apply_to", "qry") | |
| side_enable = (apply_to == "both") or (apply_to == "qry") | |
| model.eval() | |
| if classifier: | |
| classifier.eval() | |
| classifier.to(device) | |
| start_time = time.time() | |
| for inputs, infos in tqdm( | |
| loader, | |
| desc=f"[EE+AOP] {dataset_name} (tau={threshold})", | |
| disable=local_rank > 0, | |
| ): | |
| inputs = batch_to_device(inputs, device) | |
| B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1 | |
| batch_start_idx = global_sample_idx | |
| global_sample_idx += B | |
| stats["total"] += B | |
| # --------------------------------------------------- | |
| # 1. 前半程: Run to Mid Layer (含 AOP 动态控制) | |
| # --------------------------------------------------- | |
| orig_cfg = None | |
| if isinstance(aop_cfg, dict) and aop_cfg: | |
| orig_cfg = dict(aop_cfg) | |
| aop_layer = aop_cfg.get("layer_idx", None) | |
| aop_on_mid = bool( | |
| _orig_enabled | |
| and side_enable | |
| and (aop_layer is not None) | |
| and (aop_layer < target_layer_idx) | |
| ) | |
| aop_cfg_mid = dict(aop_cfg) | |
| aop_cfg_mid["enabled"] = aop_on_mid | |
| setattr(model.encoder, "aop_prune_config", aop_cfg_mid) | |
| # 计时:encoder 到 mid 层 | |
| if profile_enabled: | |
| torch.cuda.synchronize() | |
| t0_mid = time.perf_counter() | |
| with torch.no_grad(), torch.autocast( | |
| device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| ): | |
| out_mid = model.encoder( | |
| **inputs, | |
| return_dict=True, | |
| output_hidden_states=False, | |
| stop_at_layer=target_layer_idx, | |
| compute_lm_head=False, | |
| ) | |
| if profile_enabled: | |
| torch.cuda.synchronize() | |
| t1_mid = time.perf_counter() | |
| timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B | |
| timing_stats["mid_num"] += B | |
| # 恢复 AOP 配置 | |
| if isinstance(orig_cfg, dict): | |
| setattr(model.encoder, "aop_prune_config", orig_cfg) | |
| # Hidden State & Mask | |
| hs_mid = getattr(out_mid, "last_hidden_state", None) | |
| if hs_mid is None: | |
| hs_mid = out_mid.hidden_states[-1] | |
| am_mid = getattr(out_mid, "attention_mask", None) | |
| if am_mid is None: | |
| am_mid = inputs.get("attention_mask", None) | |
| reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16) | |
| # 如果要记录最后一层 embedding,需要在 profiling 模式下额外跑一次全 forward | |
| reps_last_full = None | |
| if profile_enabled: | |
| with torch.no_grad(), torch.autocast( | |
| device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| ): | |
| out_full = model.encoder( | |
| **inputs, | |
| return_dict=True, | |
| output_hidden_states=False, | |
| stop_at_layer=None, | |
| compute_lm_head=False, | |
| ) | |
| hs_last_full = getattr(out_full, "last_hidden_state", None) | |
| if hs_last_full is None: | |
| hs_last_full = out_full.hidden_states[-1] | |
| am_last_full = getattr(out_full, "attention_mask", None) | |
| if am_last_full is None: | |
| am_last_full = inputs.get("attention_mask", None) | |
| reps_last_full = ( | |
| model._pooling(hs_last_full, am_last_full) | |
| .detach() | |
| .to(dtype=torch.bfloat16) | |
| ) | |
| # --------------------------------------------------- | |
| # 2. 特征工程 + gating | |
| # --------------------------------------------------- | |
| exit_mask = np.zeros(B, dtype=bool) | |
| p_need_last_batch = None # 仅 profiling 或 debug 时保存 | |
| if method == "classifier" and classifier is not None: | |
| with torch.no_grad(): | |
| # 【关键修复】特征提取:qry_mid × cand_mid(表征空间对齐) | |
| cos_mid = reps_mid_t @ cand_mid_t.T # [B, N] | |
| backbone_ptr = ( | |
| model.module if hasattr(model, "module") else model | |
| ) | |
| temp = getattr(backbone_ptr, "temperature", 0.02) | |
| scores_mid = cos_mid / temp | |
| probs_mid = torch.softmax(scores_mid, dim=1) # [B, N] | |
| diag_cos = cos_mid.max(dim=1)[0] | |
| sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True) | |
| s2_cos = ( | |
| sorted_cos[:, 1] | |
| if sorted_cos.size(1) > 1 | |
| else sorted_cos[:, 0] | |
| ) | |
| margin_mid = diag_cos - s2_cos | |
| margin_mean = margin_mid.mean() | |
| margin_std = margin_mid.std(unbiased=False) + 1e-6 | |
| z_margin_mid = (margin_mid - margin_mean) / margin_std | |
| margin_median = margin_mid.median() | |
| mad = (margin_mid - margin_median).abs().median() + 1e-6 | |
| mad_margin_mid = (margin_mid - margin_median) / mad | |
| p1_mid = probs_mid.max(dim=1)[0] | |
| H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1) | |
| gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1) | |
| TOPK = min(16, probs_mid.size(1)) | |
| topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1) | |
| topk_mean = topk_vals.mean(dim=1) | |
| topk_std = topk_vals.std(dim=1, unbiased=False) | |
| topk_cv = topk_std / (topk_mean + 1e-6) | |
| centered = topk_vals - topk_mean.unsqueeze(1) | |
| var = (centered ** 2).mean(dim=1) + 1e-6 | |
| m4 = (centered ** 4).mean(dim=1) | |
| topk_kurt = m4 / (var ** 2) | |
| topk_med = topk_vals.median(dim=1).values | |
| row_mean_cos = cos_mid.mean(dim=1) | |
| row_med_cos = cos_mid.median(dim=1).values | |
| s1_over_mean = diag_cos - row_mean_cos | |
| s1_over_med = diag_cos - row_med_cos | |
| sorted_probs, _ = torch.sort( | |
| probs_mid, dim=1, descending=True | |
| ) | |
| p1 = sorted_probs[:, 0] | |
| p2 = ( | |
| sorted_probs[:, 1] | |
| if sorted_probs.size(1) > 1 | |
| else sorted_probs[:, 0] | |
| ) | |
| shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum( | |
| dim=1 | |
| ) | |
| shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1) | |
| R = min(10, sorted_probs.size(1)) | |
| x = torch.arange( | |
| R, device=device, dtype=sorted_probs.dtype | |
| ) | |
| x_centered = x - x.mean() | |
| denom = (x_centered ** 2).sum() | |
| y = torch.log(sorted_probs[:, :R] + 1e-6) | |
| slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom | |
| row_mean_p = probs_mid.mean(dim=1) | |
| row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6 | |
| z1 = (p1_mid - row_mean_p) / row_std_p | |
| center_p = probs_mid - row_mean_p.unsqueeze(1) | |
| m3 = (center_p ** 3).mean(dim=1) | |
| skew = m3 / (row_std_p ** 3 + 1e-6) | |
| s1_over_sk = p1_mid - skew | |
| TAIL_K = min(10, sorted_probs.size(1)) | |
| tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1) | |
| HEAD_K = min(5, sorted_probs.size(1)) | |
| head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1) | |
| mask_ratio = torch.zeros_like(diag_cos) | |
| mask_len = torch.zeros_like(diag_cos) | |
| mask_runs = torch.zeros_like(diag_cos) | |
| scalar_inputs = torch.stack( | |
| [ | |
| diag_cos, | |
| s2_cos, | |
| margin_mid, | |
| z_margin_mid, | |
| mad_margin_mid, | |
| p1_mid, | |
| H_mid, | |
| gini_mid, | |
| topk_mean, | |
| topk_std, | |
| topk_cv, | |
| topk_kurt, | |
| topk_med, | |
| s1_over_mean, | |
| s1_over_med, | |
| p1, | |
| p2, | |
| shape_H, | |
| shape_gini, | |
| slope, | |
| z1, | |
| s1_over_sk, | |
| tail_mean, | |
| head5_mean, | |
| mask_ratio, | |
| mask_len, | |
| mask_runs, | |
| ], | |
| dim=1, | |
| ) | |
| modality_idx = torch.zeros( | |
| B, dtype=torch.long, device=device | |
| ) | |
| if "pixel_values" in inputs and inputs["pixel_values"] is not None: | |
| pv = inputs["pixel_values"] | |
| if isinstance(pv, list): | |
| for i, item in enumerate(pv): | |
| if item is not None: | |
| modality_idx[i] = 1 | |
| elif isinstance(pv, torch.Tensor) and pv.numel() > 0: | |
| modality_idx.fill_(1) | |
| logits = classifier(scalar_inputs, modality_idx) | |
| p_need_last = torch.sigmoid(logits) # [B,1] | |
| p_need_last_batch = p_need_last.squeeze(1) # [B] | |
| should_exit = p_need_last_batch < threshold | |
| exit_mask = should_exit.cpu().numpy() | |
| if stats["total"] <= B * 3 and is_main: | |
| print_master( | |
| f"[EE Debug] Batch {stats['total']//B}: " | |
| f"p_need_last mean={p_need_last_batch.mean().item():.4f}, " | |
| f"std={p_need_last_batch.std().item():.4f}, " | |
| f"Exit Rate={exit_mask.mean():.2%}, " | |
| f"Top3 Feats: diag_cos={diag_cos.mean():.3f}, " | |
| f"margin={margin_mid.mean():.3f}, H={H_mid.mean():.3f}" | |
| ) | |
| stats["exit"] += exit_mask.sum() | |
| # --------------------------------------------------- | |
| # 3. 分支执行 | |
| # --------------------------------------------------- | |
| exit_indices = np.where(exit_mask)[0] | |
| cont_indices = np.where(~exit_mask)[0] | |
| # A. 早停样本:用 mid→last 检索 | |
| if len(exit_indices) > 0: | |
| reps_exit = reps_mid_t[exit_indices] | |
| if use_local: | |
| for i, idx in enumerate(exit_indices): | |
| cand_local = infos[idx].get("cand_names", []) | |
| if not cand_local: | |
| cids = [] | |
| else: | |
| rows = [ | |
| cand_id2row.get(str(cid), -1) | |
| for cid in cand_local | |
| ] | |
| rows = [r for r in rows if r >= 0] | |
| if len(rows) == 0: | |
| cids = [] | |
| else: | |
| # CPU FP32:单个 query mid × 局部 cand_mid | |
| cmat_np = cand_mid_np[rows] # [Nc_local, D] | |
| qry_np = reps_exit[i].detach().float().cpu().numpy() # [D] | |
| scores_vec = np.dot(qry_np, cmat_np.T) # [Nc_local] | |
| top_k = min(200, len(rows)) | |
| order_local = np.argsort(-scores_vec)[:top_k] | |
| cids = [str(cand_local[o]) for o in order_local] | |
| label = ( | |
| infos[idx].get("label_name") | |
| or infos[idx].get("label") | |
| or infos[idx].get("rel_docids") | |
| ) | |
| if not isinstance(label, list): | |
| label = [label] | |
| rel_scores = infos[idx].get("rel_scores", None) | |
| global_idx = batch_start_idx + idx | |
| results_dict[global_idx] = { | |
| "prediction": cids, | |
| "label": label, | |
| "rel_scores": rel_scores, | |
| } | |
| else: | |
| # 使用 CPU FP32 Numpy:reps_exit_np × cand_mid_np.T | |
| reps_exit_np = reps_exit.detach().float().cpu().numpy() # [B_exit, D] | |
| scores_full = np.dot(reps_exit_np, cand_mid_np.T) # [B_exit, Nc] | |
| top_k = min(200, len(cand_ids)) | |
| # 全排序后截到 top_k | |
| topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k] # [B_exit, top_k] | |
| for i, idx in enumerate(exit_indices): | |
| cids = [cand_ids[k] for k in topk_inds[i]] | |
| label = ( | |
| infos[idx].get("label_name") | |
| or infos[idx].get("label") | |
| or infos[idx].get("rel_docids") | |
| ) | |
| if not isinstance(label, list): | |
| label = [label] | |
| rel_scores = infos[idx].get("rel_scores", None) | |
| global_idx = batch_start_idx + idx | |
| results_dict[global_idx] = { | |
| "prediction": cids, | |
| "label": label, | |
| "rel_scores": rel_scores, | |
| } | |
| # B. 续跑样本:从 mid 继续到 last | |
| if len(cont_indices) > 0: | |
| interm = getattr(out_mid, "intermediate_state", None) | |
| hs = interm["hidden_states"].detach() | |
| am = interm["attention_mask"].detach() | |
| pos = interm["position_ids"].detach() | |
| vm = interm.get("vision_mask", None) | |
| tm = interm.get("text_mask", None) | |
| next_layer = int(interm["next_layer_idx"]) | |
| resume_state_subset = { | |
| "hidden_states": hs[cont_indices], | |
| "attention_mask": am[cont_indices], | |
| "position_ids": pos[:, cont_indices, :], | |
| "vision_mask": vm[cont_indices] if vm is not None else None, | |
| "text_mask": tm[cont_indices] if tm is not None else None, | |
| "next_layer_idx": next_layer, | |
| } | |
| if isinstance(aop_cfg, dict) and aop_cfg: | |
| aop_resume = dict(aop_cfg) | |
| aop_resume["enabled"] = bool(_orig_enabled and side_enable) | |
| setattr(model.encoder, "aop_prune_config", aop_resume) | |
| # 计时:mid -> last | |
| if profile_enabled: | |
| torch.cuda.synchronize() | |
| t0_tail = time.perf_counter() | |
| with torch.no_grad(), torch.autocast( | |
| device_type="cuda", dtype=torch.bfloat16, enabled=True | |
| ): | |
| out_last = model.encoder( | |
| return_dict=True, | |
| output_hidden_states=False, | |
| stop_at_layer=None, | |
| resume_state=resume_state_subset, | |
| compute_lm_head=False, | |
| ) | |
| if profile_enabled: | |
| torch.cuda.synchronize() | |
| t1_tail = time.perf_counter() | |
| timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len( | |
| cont_indices | |
| ) | |
| timing_stats["tail_num"] += len(cont_indices) | |
| hs_last = out_last.last_hidden_state | |
| if hs_last is None: | |
| hs_last = out_last.hidden_states[-1] | |
| am_last = getattr(out_last, "attention_mask", None) | |
| if am_last is None: | |
| am_last = resume_state_subset["attention_mask"] | |
| reps_last_t = ( | |
| model._pooling(hs_last, am_last) | |
| .detach() | |
| .to(dtype=torch.bfloat16) | |
| ) | |
| if use_local: | |
| for i, idx_global in enumerate(cont_indices): | |
| cand_local = infos[idx_global].get("cand_names", []) | |
| if not cand_local: | |
| cids = [] | |
| else: | |
| rows = [ | |
| cand_id2row.get(str(cid), -1) | |
| for cid in cand_local | |
| ] | |
| rows = [r for r in rows if r >= 0] | |
| if len(rows) == 0: | |
| cids = [] | |
| else: | |
| # CPU FP32:单个 query last × 局部 cand_last | |
| cmat_last_np = cand_last_np[rows] # [Nc_local, D] | |
| qry_last_np = reps_last_t[i].detach().float().cpu().numpy() | |
| scores_vec = np.dot(qry_last_np, cmat_last_np.T) # [Nc_local] | |
| top_k = min(200, len(rows)) | |
| order_local = np.argsort(-scores_vec)[:top_k] | |
| cids = [str(cand_local[o]) for o in order_local] | |
| label = ( | |
| infos[idx_global].get("label_name") | |
| or infos[idx_global].get("label") | |
| or infos[idx_global].get("rel_docids") | |
| ) | |
| if not isinstance(label, list): | |
| label = [label] | |
| rel_scores = infos[idx_global].get("rel_scores", None) | |
| global_idx = batch_start_idx + idx_global | |
| results_dict[global_idx] = { | |
| "prediction": cids, | |
| "label": label, | |
| "rel_scores": rel_scores, | |
| } | |
| else: | |
| # CPU FP32:所有续跑 query last × 全量 cand_last | |
| reps_last_np = reps_last_t.detach().float().cpu().numpy() # [B_cont, D] | |
| scores_last = np.dot(reps_last_np, cand_last_np.T) # [B_cont, Nc] | |
| top_k = min(200, len(cand_ids)) | |
| topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k] # [B_cont, top_k] | |
| for i, idx_global in enumerate(cont_indices): | |
| cids = [cand_ids[k] for k in topk_inds[i]] | |
| label = ( | |
| infos[idx_global].get("label_name") | |
| or infos[idx_global].get("label") | |
| or infos[idx_global].get("rel_docids") | |
| ) | |
| if not isinstance(label, list): | |
| label = [label] | |
| rel_scores = infos[idx_global].get("rel_scores", None) | |
| global_idx = batch_start_idx + idx_global | |
| results_dict[global_idx] = { | |
| "prediction": cids, | |
| "label": label, | |
| "rel_scores": rel_scores, | |
| } | |
| # --------------------------------------------------- | |
| # 4. Profiling: 保存 per-query 的 topK embedding & 相似度 | |
| # --------------------------------------------------- | |
| if profile_enabled and is_main: | |
| K = min(topk_emb, cand_mid_t.size(0)) | |
| # 转到 float32 + CPU 便于写盘 | |
| q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D] | |
| q_last_cpu = ( | |
| reps_last_full.detach().float().cpu() | |
| if reps_last_full is not None | |
| else None | |
| ) # [B, D] | |
| cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D] | |
| cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D] | |
| # 【关键修复】mid2mid 相似度(表征空间对齐) | |
| scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc] | |
| topk_mid_vals, topk_mid_inds = torch.topk( | |
| scores_mid_full, k=K, dim=1 | |
| ) | |
| # last2last 相似度(如果有) | |
| if q_last_cpu is not None: | |
| scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc] | |
| topk_last_vals, topk_last_inds = torch.topk( | |
| scores_last_full, k=K, dim=1 | |
| ) | |
| else: | |
| topk_last_vals = None | |
| topk_last_inds = None | |
| for i in range(B): | |
| qid = batch_start_idx + i | |
| rec = { | |
| "qid": int(qid), | |
| "early_exit": bool(exit_mask[i]), | |
| } | |
| if p_need_last_batch is not None: | |
| rec["p_need_last"] = float(p_need_last_batch[i].item()) | |
| # 【关键修复】mid2mid TopK(表征空间对齐) | |
| mid_inds = topk_mid_inds[i].tolist() | |
| mid_scores = topk_mid_vals[i].tolist() | |
| rec["mid_topk_scores"] = mid_scores | |
| rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds] | |
| rec["mid_q_emb"] = q_mid_cpu[i].tolist() | |
| rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist() | |
| # last2last TopK(如果有) | |
| if topk_last_inds is not None: | |
| last_inds = topk_last_inds[i].tolist() | |
| last_scores = topk_last_vals[i].tolist() | |
| rec["last_topk_scores"] = last_scores | |
| rec["last_topk_cand_ids"] = [ | |
| cand_ids[j] for j in last_inds | |
| ] | |
| rec["last_q_emb"] = ( | |
| q_last_cpu[i].tolist() if q_last_cpu is not None else None | |
| ) | |
| rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist() | |
| else: | |
| rec["last_topk_scores"] = None | |
| rec["last_topk_cand_ids"] = None | |
| rec["last_q_emb"] = None | |
| rec["last_cand_embs"] = None | |
| analysis_records.append(rec) | |
| # ===================================================== | |
| # 5. 收集 & 保存结果 | |
| # ===================================================== | |
| for idx in sorted(results_dict.keys()): | |
| pred_dicts.append(results_dict[idx]) | |
| print_master( | |
| f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} " | |
| f"({stats['exit']/stats['total']:.2%})" | |
| ) | |
| metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] | |
| score = RankingMetrics(metrics_to_report).evaluate(pred_dicts) | |
| if is_main: | |
| os.makedirs(out_dir, exist_ok=True) | |
| with open( | |
| os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), | |
| "w", | |
| ) as f: | |
| json.dump(score, f, indent=4) | |
| # 保存 profiling 信息 | |
| if profile_enabled: | |
| prof_dir = os.path.join(out_dir, "profiling") | |
| os.makedirs(prof_dir, exist_ok=True) | |
| # 时间统计:转换成均值(秒/样本) | |
| mid_avg = ( | |
| timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"]) | |
| ) | |
| tail_avg = ( | |
| timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"]) | |
| ) | |
| timing_out = { | |
| "mid_time_sum": timing_stats["mid_time_sum"], | |
| "mid_num": timing_stats["mid_num"], | |
| "tail_time_sum": timing_stats["tail_time_sum"], | |
| "tail_num": timing_stats["tail_num"], | |
| "avg_mid_time_per_query_sec": mid_avg, | |
| "avg_tail_time_per_cont_query_sec": tail_avg, | |
| "num_exit": int(stats["exit"]), | |
| "num_total": int(stats["total"]), | |
| } | |
| with open( | |
| os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w" | |
| ) as f: | |
| json.dump(timing_out, f, indent=2) | |
| # embedding + 相似度记录(JSONL) | |
| embed_path = os.path.join( | |
| prof_dir, f"{dataset_name}_embeds.jsonl" | |
| ) | |
| with open(embed_path, "w") as f: | |
| for rec in analysis_records: | |
| f.write(json.dumps(rec) + "\n") | |
| print_master( | |
| f"[PROFILE] Saved timing to {prof_dir}, " | |
| f"embeddings to {embed_path}" | |
| ) | |
| elapsed = time.time() - start_time | |
| return score, elapsed | |
| # =========================== | |
| # Helper Functions (Pre-Computation) | |
| # =========================== | |
def encode_candidates_both_layers(
    model: MMEBModel, loader: DataLoader, training_args: TrainingArguments,
    model_args: ModelArguments, full_dataset: Dataset, mid_layer: int,
):
    """Encode all candidates through the full backbone, pooling embeddings
    from both the mid layer and the last layer in a single pass.

    AOP token pruning is force-disabled on the candidate side for the whole
    pass and restored afterwards — now inside try/finally, so the caller's
    setting survives even if a forward pass raises (the original toggled the
    flag per batch and could leave it disabled on error).

    Args:
        model: MMEB model wrapper exposing ``encoder`` and ``_pooling``.
        loader: yields ``(inputs, dataset_info)`` batches; each info dict must
            contain a ``"cand_name"`` key.
        training_args: supplies the target device.
        model_args: unused; kept for signature compatibility with callers.
        full_dataset: unused; kept for signature compatibility with callers.
        mid_layer: index into ``hidden_states`` used as the early-exit layer.

    Returns:
        ``(mid_embs, last_embs, cand_ids)`` — two float32 numpy arrays of
        shape [N, D] plus the candidate-id list; empty array/list when the
        loader yields no batches.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid, all_last, all_ids = [], [], []
    # Force-disable candidate-side AOP once (loop-invariant); restore in `finally`.
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    _orig = None
    if isinstance(aop_cfg, dict):
        _orig = aop_cfg.get("enabled", False)
        aop_cfg["enabled"] = False
        setattr(model.encoder, "aop_prune_config", aop_cfg)
    try:
        with torch.no_grad():
            for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
                inputs = batch_to_device(inputs, training_args.device)
                with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                    out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None)
                mid_hs = out.hidden_states[mid_layer]
                last_hs = out.hidden_states[-1]
                am = inputs.get("attention_mask", None)
                # Pooling expects the mask on the same device as the hidden states.
                if am is not None and am.device != mid_hs.device:
                    am = am.to(mid_hs.device)
                reps_mid = model._pooling(mid_hs, am)
                reps_last = model._pooling(last_hs, am)
                all_mid.append(reps_mid.detach().float().cpu())
                all_last.append(reps_last.detach().float().cpu())
                all_ids.extend([info["cand_name"] for info in dataset_info])
    finally:
        # Restore the caller's AOP setting even if encoding failed mid-way.
        if isinstance(aop_cfg, dict) and _orig is not None:
            aop_cfg["enabled"] = _orig
    if not all_mid:
        return np.array([]), np.array([]), []
    return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids
| # =========================== | |
| # Main | |
| # =========================== | |
def main():
    """Entry point: distributed early-exit retrieval evaluation.

    Initializes torch.distributed when launched under a launcher (RANK set),
    loads the MMEB model/processor on every rank (rank 0 first, with a
    barrier in between — presumably so shared caches are warm; confirm),
    optionally injects the AOP pruning config and loads the early-exit
    classifier, then for every dataset in the YAML config: encodes/caches
    candidate embeddings at both the mid and last layer, and runs the
    query-side early-exit ranking on rank 0 only.
    """
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    ee_cfg = get_env_ee_config()
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
    # Rank 0 loads first, then the others after the barrier.
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    if torch.distributed.is_initialized(): torch.distributed.barrier()
    if local_rank != 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    # Debug: inspect the model's scoring-relevant configuration.
    print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}")
    # Inject the AOP (token pruning) configuration onto the encoder.
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        model.set_inference_layers(qry_layers=None, tgt_layers=None)
        print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}")
    # Load the early-exit classifier (only for the "classifier" method).
    classifier = None
    if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
        classifier_path = ee_cfg['classifier_path']
        print_master(f"[EE] Loading Classifier from {classifier_path}...")
        # IMPORTANT: 27-dim feature input — must match the training code exactly.
        classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64)
        state_dict = None
        # Try several checkpoint layouts in order.
        # NOTE(review): torch.load deserializes pickle — only point this at
        # trusted checkpoint files.
        if os.path.isdir(classifier_path):
            # Layout 1: safetensors (HF Trainer default).
            safetensors_file = os.path.join(classifier_path, "model.safetensors")
            if os.path.exists(safetensors_file):
                from safetensors.torch import load_file
                state_dict = load_file(safetensors_file)
                print_master(f"[EE] Loaded from model.safetensors")
            else:
                # Layout 2: pytorch_model.bin.
                pt_file = os.path.join(classifier_path, "pytorch_model.bin")
                if os.path.exists(pt_file):
                    state_dict = torch.load(pt_file, map_location=training_args.device)
                    print_master(f"[EE] Loaded from pytorch_model.bin")
                else:
                    # Layout 3: standalone per-layer .pt file.
                    layer_idx = ee_cfg.get('layer', 12)
                    pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
                    if os.path.exists(pt_file):
                        state_dict = torch.load(pt_file, map_location=training_args.device)
                        print_master(f"[EE] Loaded from {os.path.basename(pt_file)}")
                    else:
                        raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}")
        elif os.path.isfile(classifier_path):
            # Layout 4: path points directly at a weights file.
            if classifier_path.endswith('.safetensors'):
                from safetensors.torch import load_file
                state_dict = load_file(classifier_path)
                print_master(f"[EE] Loaded from .safetensors file")
            else:
                state_dict = torch.load(classifier_path, map_location=training_args.device)
                print_master(f"[EE] Loaded from .pt file")
        else:
            raise FileNotFoundError(f"Classifier path not found: {classifier_path}")
        classifier.load_state_dict(state_dict)
        classifier.to(training_args.device)
        classifier.eval()
        print_master(f"[EE] Classifier loaded successfully. Threshold={ee_cfg['threshold']}")
    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)
    for dataset_name, task_config in dataset_configs.items():
        if dist.is_initialized(): dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")
        # ===== Pick the per-dataset exit threshold τ_ds =====
        base_tau = float(ee_cfg["threshold"])  # env var EE_THRESHOLD is the default
        ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau)
        # Copy ee_cfg for this dataset and override its threshold.
        ee_cfg_ds = dict(ee_cfg)
        ee_cfg_ds["threshold"] = float(ds_tau)
        print_master(
            f"[EE] Dataset '{dataset_name}' use threshold τ={ee_cfg_ds['threshold']:.6e} "
            f"(default={base_tau:.6e})"
        )
        # Resolve dataset-relative paths against the shared data root.
        if data_args.data_basedir:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_config.get(key):
                    task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(
            model_args=model_args, data_args=data_args, **task_config
        )
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
        mid_layer = int(ee_cfg_ds["layer"])
        # Candidate embedding caches: one file per (dataset, layer).
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast")
        force_regenerate = False
        if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
            print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...")
            eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
            eval_cand_loader = DataLoader(
                full_eval_cand_dataset,
                batch_size=training_args.per_device_eval_batch_size,
                collate_fn=eval_cand_collator,
                num_workers=training_args.dataloader_num_workers,
            )
            cand_mid, cand_last, cand_ids = encode_candidates_both_layers(
                model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer
            )
            # Only rank 0 writes the caches; others wait at the barrier below.
            if local_rank == 0:
                with open(cand_mid_path, "wb") as f:
                    pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
                with open(cand_last_path, "wb") as f:
                    pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
        if dist.is_initialized():
            dist.barrier()
        # Query-side early-exit ranking runs on rank 0 only.
        if local_rank == 0:
            with open(cand_mid_path, "rb") as f:
                cand_mid_dict = pickle.load(f)
            with open(cand_last_path, "rb") as f:
                cand_last_dict = pickle.load(f)
            rank_global = task_config.get("eval_type", "global") == "global"
            print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking")
            # Important: pass the per-dataset ee_cfg_ds, not the global ee_cfg.
            run_early_exit_queries(
                model, classifier, processor, model_args, data_args, training_args,
                full_eval_qry_dataset, cand_mid_dict, cand_last_dict,
                ee_cfg_ds, dataset_name, data_args.encode_output_path,
                global_ranking=rank_global,
            )
        if dist.is_initialized():
            dist.barrier()
    if dist.is_initialized(): dist.barrier()
# Script entry point.
if __name__ == '__main__':
    main()