####################################################################################################### #原始版本 # import datetime # import logging # import json # import random # import time # import numpy as np # import os # import pickle # import sys # import torch # import torch.distributed as dist # import torch.nn.functional as F # import yaml # import transformers # import math # from torch.utils.data import DataLoader # from tqdm import tqdm # from transformers import HfArgumentParser, AutoConfig, AutoTokenizer # from datasets import Dataset, concatenate_datasets # from datasets.distributed import split_dataset_by_node # from src.arguments import ModelArguments, DataArguments, TrainingArguments # from src.data.collator.eval_collator import MultimodalEvalDataCollator # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset # from src.eval_utils.metrics import RankingMetrics # from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel # from src.model.processor import get_backbone_name, load_processor, COLPALI # from src.utils import batch_to_device, print_rank, print_master # # 引入分类器 # from src.classifier_utils_V2 import EarlyExitClassifier # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') # logger = logging.getLogger(__name__) # # =========================== # # Helper Functions (AOP Config) # # =========================== # def _parse_bool(v: str, default=False): # if v is None: return default # v = v.strip().lower() # return v in {"1","true","yes","y","t","on"} # def _parse_float(v: str, default=None): # try: return float(v) if v is not None else default # except: return default # def _parse_int(v: str, default=None): # try: return int(v) if v is not None else default # except: return default # def get_env_aop_config(): # enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False) # apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() # 
layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None) # mode = os.environ.get("AOP_MODE", "delta").strip().lower() # # Parameters # delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10) # khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0) # keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0) # min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64) # use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True) # # Specific # prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True) # prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False) # keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None) # keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None) # selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() # attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() # if layer_idx is None and enabled: # enabled = False # return { # "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode, # "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep, # "use_bias": use_bias, "prune_vision": prune_vision, "prune_text": prune_text, # "keep_ratio_vision": keep_ratio_v, "keep_ratio_text": keep_ratio_t, # "selection": selection, "attn_agg": attn_agg, # "margin_mid": None # Simplified # } # def get_env_ee_config(): # ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"} # layer = int(os.environ.get("EE_LAYER", "12")) # method = os.environ.get("EE_METHOD", "classifier").strip().lower() # threshold = float(os.environ.get("EE_THRESHOLD", "0.8")) # classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "") # return dict(enabled=ee_enabled, layer=layer, method=method, threshold=threshold, classifier_path=classifier_path) # def pad_dataset_to_divisible(dataset, world_size): # num_samples = len(dataset) # if num_samples % world_size == 0: return dataset, num_samples # num_to_add = world_size 
- (num_samples % world_size) # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) # return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add # # =========================== # # Core Inference Function # # =========================== # def run_early_exit_queries( # model: MMEBModel, # classifier: EarlyExitClassifier, # processor, # model_args: ModelArguments, # data_args: DataArguments, # training_args: TrainingArguments, # qry_dataset: Dataset, # cand_mid_dict: dict, # cand_last_dict: dict, # ee_cfg: dict, # dataset_name: str, # out_dir: str, # global_ranking: bool = True, # ): # device = training_args.device # local_rank = dist.get_rank() if dist.is_initialized() else 0 # is_main = (not dist.is_initialized()) or (local_rank == 0) # # ========================== # # Profiling 配置 # # ========================== # profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in { # "1", "true", "yes", "on", "y", "t" # } # topk_emb = int(os.environ.get("EE_TOPK_EMB", "5")) # # 运行时间统计(按样本数加权) # timing_stats = { # "mid_time_sum": 0.0, # encoder 到 mid 的总时间 * 样本数 # "mid_num": 0, # 样本数 # "tail_time_sum": 0.0, # mid->last 续跑部分的总时间 * 样本数 # "tail_num": 0, # 续跑样本数 # } # # embedding + 相似度记录(仅 rank0 保存) # analysis_records = [] if (profile_enabled and is_main) else None # # 1. 
准备 Candidates # cand_ids = list(cand_mid_dict.keys()) # cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32) # cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32) # cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16) # cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16) # collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") # loader = DataLoader( # qry_dataset, # batch_size=training_args.per_device_eval_batch_size, # collate_fn=collator, # num_workers=training_args.dataloader_num_workers # ) # pred_dicts = [] # stats = {"exit": 0, "total": 0} # threshold = float(ee_cfg["threshold"]) # method = ee_cfg["method"] # target_layer_idx = int(ee_cfg["layer"]) # # 结果顺序 # results_dict = {} # global_sample_idx = 0 # # Local vs Global ranking # use_local = (not global_ranking) # if use_local: # print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)") # cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)} # else: # print_master(f"[INFO] Using GLOBAL ranking (full library)") # # --- AOP 配置初始化 --- # aop_cfg = getattr(model.encoder, "aop_prune_config", None) # _orig_enabled = None # side_enable = True # if isinstance(aop_cfg, dict) and aop_cfg: # _orig_enabled = aop_cfg.get("enabled", False) # apply_to = aop_cfg.get("apply_to", "qry") # side_enable = (apply_to == "both") or (apply_to == "qry") # model.eval() # if classifier: # classifier.eval() # classifier.to(device) # start_time = time.time() # for inputs, infos in tqdm( # loader, # desc=f"[EE+AOP] {dataset_name} (tau={threshold})", # disable=local_rank > 0, # ): # inputs = batch_to_device(inputs, device) # B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1 # batch_start_idx = global_sample_idx # global_sample_idx += B # stats["total"] += B # # --------------------------------------------------- # # 1. 
前半程: Run to Mid Layer (含 AOP 动态控制) # # --------------------------------------------------- # orig_cfg = None # if isinstance(aop_cfg, dict) and aop_cfg: # orig_cfg = dict(aop_cfg) # aop_layer = aop_cfg.get("layer_idx", None) # aop_on_mid = bool( # _orig_enabled # and side_enable # and (aop_layer is not None) # and (aop_layer < target_layer_idx) # ) # aop_cfg_mid = dict(aop_cfg) # aop_cfg_mid["enabled"] = aop_on_mid # setattr(model.encoder, "aop_prune_config", aop_cfg_mid) # # 计时:encoder 到 mid 层 # if profile_enabled: # torch.cuda.synchronize() # t0_mid = time.perf_counter() # with torch.no_grad(), torch.autocast( # device_type="cuda", dtype=torch.bfloat16, enabled=True # ): # out_mid = model.encoder( # **inputs, # return_dict=True, # output_hidden_states=False, # stop_at_layer=target_layer_idx, # compute_lm_head=False, # ) # if profile_enabled: # torch.cuda.synchronize() # t1_mid = time.perf_counter() # timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B # timing_stats["mid_num"] += B # # 恢复 AOP 配置 # if isinstance(orig_cfg, dict): # setattr(model.encoder, "aop_prune_config", orig_cfg) # # Hidden State & Mask # hs_mid = getattr(out_mid, "last_hidden_state", None) # if hs_mid is None: # hs_mid = out_mid.hidden_states[-1] # am_mid = getattr(out_mid, "attention_mask", None) # if am_mid is None: # am_mid = inputs.get("attention_mask", None) # reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16) # # 如果要记录最后一层 embedding,需要在 profiling 模式下额外跑一次全 forward # reps_last_full = None # if profile_enabled: # with torch.no_grad(), torch.autocast( # device_type="cuda", dtype=torch.bfloat16, enabled=True # ): # out_full = model.encoder( # **inputs, # return_dict=True, # output_hidden_states=False, # stop_at_layer=None, # compute_lm_head=False, # ) # hs_last_full = getattr(out_full, "last_hidden_state", None) # if hs_last_full is None: # hs_last_full = out_full.hidden_states[-1] # am_last_full = getattr(out_full, "attention_mask", None) # if am_last_full is 
None: # am_last_full = inputs.get("attention_mask", None) # reps_last_full = ( # model._pooling(hs_last_full, am_last_full) # .detach() # .to(dtype=torch.bfloat16) # ) # # --------------------------------------------------- # # 2. 特征工程 + gating # # --------------------------------------------------- # exit_mask = np.zeros(B, dtype=bool) # p_need_last_batch = None # 仅 profiling 或 debug 时保存 # if method == "classifier" and classifier is not None: # with torch.no_grad(): # # 【关键修复】特征提取:qry_mid × cand_mid(表征空间对齐) # cos_mid = reps_mid_t @ cand_mid_t.T # [B, N] # backbone_ptr = ( # model.module if hasattr(model, "module") else model # ) # temp = getattr(backbone_ptr, "temperature", 0.02) # scores_mid = cos_mid / temp # probs_mid = torch.softmax(scores_mid, dim=1) # [B, N] # diag_cos = cos_mid.max(dim=1)[0] # sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True) # s2_cos = ( # sorted_cos[:, 1] # if sorted_cos.size(1) > 1 # else sorted_cos[:, 0] # ) # margin_mid = diag_cos - s2_cos # margin_mean = margin_mid.mean() # margin_std = margin_mid.std(unbiased=False) + 1e-6 # z_margin_mid = (margin_mid - margin_mean) / margin_std # margin_median = margin_mid.median() # mad = (margin_mid - margin_median).abs().median() + 1e-6 # mad_margin_mid = (margin_mid - margin_median) / mad # p1_mid = probs_mid.max(dim=1)[0] # H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1) # gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1) # TOPK = min(16, probs_mid.size(1)) # topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1) # topk_mean = topk_vals.mean(dim=1) # topk_std = topk_vals.std(dim=1, unbiased=False) # topk_cv = topk_std / (topk_mean + 1e-6) # centered = topk_vals - topk_mean.unsqueeze(1) # var = (centered ** 2).mean(dim=1) + 1e-6 # m4 = (centered ** 4).mean(dim=1) # topk_kurt = m4 / (var ** 2) # topk_med = topk_vals.median(dim=1).values # row_mean_cos = cos_mid.mean(dim=1) # row_med_cos = cos_mid.median(dim=1).values # s1_over_mean = diag_cos - row_mean_cos # s1_over_med = 
diag_cos - row_med_cos # sorted_probs, _ = torch.sort( # probs_mid, dim=1, descending=True # ) # p1 = sorted_probs[:, 0] # p2 = ( # sorted_probs[:, 1] # if sorted_probs.size(1) > 1 # else sorted_probs[:, 0] # ) # shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum( # dim=1 # ) # shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1) # R = min(10, sorted_probs.size(1)) # x = torch.arange( # R, device=device, dtype=sorted_probs.dtype # ) # x_centered = x - x.mean() # denom = (x_centered ** 2).sum() # y = torch.log(sorted_probs[:, :R] + 1e-6) # slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom # row_mean_p = probs_mid.mean(dim=1) # row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6 # z1 = (p1_mid - row_mean_p) / row_std_p # center_p = probs_mid - row_mean_p.unsqueeze(1) # m3 = (center_p ** 3).mean(dim=1) # skew = m3 / (row_std_p ** 3 + 1e-6) # s1_over_sk = p1_mid - skew # TAIL_K = min(10, sorted_probs.size(1)) # tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1) # HEAD_K = min(5, sorted_probs.size(1)) # head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1) # mask_ratio = torch.zeros_like(diag_cos) # mask_len = torch.zeros_like(diag_cos) # mask_runs = torch.zeros_like(diag_cos) # scalar_inputs = torch.stack( # [ # diag_cos, # s2_cos, # margin_mid, # z_margin_mid, # mad_margin_mid, # p1_mid, # H_mid, # gini_mid, # topk_mean, # topk_std, # topk_cv, # topk_kurt, # topk_med, # s1_over_mean, # s1_over_med, # p1, # p2, # shape_H, # shape_gini, # slope, # z1, # s1_over_sk, # tail_mean, # head5_mean, # mask_ratio, # mask_len, # mask_runs, # ], # dim=1, # ) # modality_idx = torch.zeros( # B, dtype=torch.long, device=device # ) # if "pixel_values" in inputs and inputs["pixel_values"] is not None: # pv = inputs["pixel_values"] # if isinstance(pv, list): # for i, item in enumerate(pv): # if item is not None: # modality_idx[i] = 1 # elif isinstance(pv, torch.Tensor) and pv.numel() > 0: # modality_idx.fill_(1) # logits = classifier(scalar_inputs, modality_idx) # 
p_need_last = torch.sigmoid(logits) # [B,1] # p_need_last_batch = p_need_last.squeeze(1) # [B] # should_exit = p_need_last_batch < threshold # exit_mask = should_exit.cpu().numpy() # if stats["total"] <= B * 3 and is_main: # print_master( # f"[EE Debug] Batch {stats['total']//B}: " # f"p_need_last mean={p_need_last_batch.mean().item():.4f}, " # f"std={p_need_last_batch.std().item():.4f}, " # f"Exit Rate={exit_mask.mean():.2%}, " # f"Top3 Feats: diag_cos={diag_cos.mean():.3f}, " # f"margin={margin_mid.mean():.3f}, H={H_mid.mean():.3f}" # ) # stats["exit"] += exit_mask.sum() # # --------------------------------------------------- # # 3. 分支执行 # # --------------------------------------------------- # exit_indices = np.where(exit_mask)[0] # cont_indices = np.where(~exit_mask)[0] # # A. 早停样本:用 mid→last 检索 # if len(exit_indices) > 0: # reps_exit = reps_mid_t[exit_indices] # if use_local: # for i, idx in enumerate(exit_indices): # cand_local = infos[idx].get("cand_names", []) # if not cand_local: # cids = [] # else: # rows = [ # cand_id2row.get(str(cid), -1) # for cid in cand_local # ] # rows = [r for r in rows if r >= 0] # if len(rows) == 0: # cids = [] # else: # # 【关键修复】早停检索:qry_mid × cand_mid(表征空间对齐) # cmat_t = cand_mid_t[rows] # scores_vec = (reps_exit[i : i + 1] @ cmat_t.T)[0] # order_local = ( # torch.argsort( # scores_vec, # dim=0, # descending=True, # ) # .cpu() # .numpy() # ) # cids = [str(cand_local[o]) for o in order_local] # label = ( # infos[idx].get("label_name") # or infos[idx].get("label") # or infos[idx].get("rel_docids") # ) # if not isinstance(label, list): # label = [label] # rel_scores = infos[idx].get("rel_scores", None) # global_idx = batch_start_idx + idx # results_dict[global_idx] = { # "prediction": cids, # "label": label, # "rel_scores": rel_scores, # } # else: # # 【关键修复】早停检索:qry_mid × cand_mid(表征空间对齐) # scores_full = reps_exit @ cand_mid_t.T # _, topk_inds = torch.topk( # scores_full, k=min(200, len(cand_ids)), dim=1 # ) # topk_inds = 
topk_inds.cpu().numpy() # for i, idx in enumerate(exit_indices): # cids = [cand_ids[k] for k in topk_inds[i]] # label = ( # infos[idx].get("label_name") # or infos[idx].get("label") # or infos[idx].get("rel_docids") # ) # if not isinstance(label, list): # label = [label] # rel_scores = infos[idx].get("rel_scores", None) # global_idx = batch_start_idx + idx # results_dict[global_idx] = { # "prediction": cids, # "label": label, # "rel_scores": rel_scores, # } # # B. 续跑样本:从 mid 继续到 last # if len(cont_indices) > 0: # interm = getattr(out_mid, "intermediate_state", None) # hs = interm["hidden_states"].detach() # am = interm["attention_mask"].detach() # pos = interm["position_ids"].detach() # vm = interm.get("vision_mask", None) # tm = interm.get("text_mask", None) # next_layer = int(interm["next_layer_idx"]) # resume_state_subset = { # "hidden_states": hs[cont_indices], # "attention_mask": am[cont_indices], # "position_ids": pos[:, cont_indices, :], # "vision_mask": vm[cont_indices] if vm is not None else None, # "text_mask": tm[cont_indices] if tm is not None else None, # "next_layer_idx": next_layer, # } # if isinstance(aop_cfg, dict) and aop_cfg: # aop_resume = dict(aop_cfg) # aop_resume["enabled"] = bool(_orig_enabled and side_enable) # setattr(model.encoder, "aop_prune_config", aop_resume) # # 计时:mid -> last # if profile_enabled: # torch.cuda.synchronize() # t0_tail = time.perf_counter() # with torch.no_grad(), torch.autocast( # device_type="cuda", dtype=torch.bfloat16, enabled=True # ): # out_last = model.encoder( # return_dict=True, # output_hidden_states=False, # stop_at_layer=None, # resume_state=resume_state_subset, # compute_lm_head=False, # ) # if profile_enabled: # torch.cuda.synchronize() # t1_tail = time.perf_counter() # timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len( # cont_indices # ) # timing_stats["tail_num"] += len(cont_indices) # hs_last = out_last.last_hidden_state # if hs_last is None: # hs_last = out_last.hidden_states[-1] # am_last = 
getattr(out_last, "attention_mask", None) # if am_last is None: # am_last = resume_state_subset["attention_mask"] # reps_last_t = ( # model._pooling(hs_last, am_last) # .detach() # .to(dtype=torch.bfloat16) # ) # if use_local: # for i, idx_global in enumerate(cont_indices): # cand_local = infos[idx_global].get("cand_names", []) # if not cand_local: # cids = [] # else: # rows = [ # cand_id2row.get(str(cid), -1) # for cid in cand_local # ] # rows = [r for r in rows if r >= 0] # if len(rows) == 0: # cids = [] # else: # cmat_last_t = cand_last_t[rows] # scores_vec_t = ( # reps_last_t[i : i + 1] @ cmat_last_t.T # )[0] # order_local = ( # torch.argsort( # scores_vec_t, # dim=0, # descending=True, # ) # .cpu() # .numpy() # ) # cids = [str(cand_local[o]) for o in order_local] # label = ( # infos[idx_global].get("label_name") # or infos[idx_global].get("label") # or infos[idx_global].get("rel_docids") # ) # if not isinstance(label, list): # label = [label] # rel_scores = infos[idx_global].get("rel_scores", None) # global_idx = batch_start_idx + idx_global # results_dict[global_idx] = { # "prediction": cids, # "label": label, # "rel_scores": rel_scores, # } # else: # scores_last = reps_last_t @ cand_last_t.T # _, topk_inds = torch.topk( # scores_last, k=min(200, len(cand_ids)), dim=1 # ) # topk_inds = topk_inds.cpu().numpy() # for i, idx_global in enumerate(cont_indices): # cids = [cand_ids[k] for k in topk_inds[i]] # label = ( # infos[idx_global].get("label_name") # or infos[idx_global].get("label") # or infos[idx_global].get("rel_docids") # ) # if not isinstance(label, list): # label = [label] # rel_scores = infos[idx_global].get("rel_scores", None) # global_idx = batch_start_idx + idx_global # results_dict[global_idx] = { # "prediction": cids, # "label": label, # "rel_scores": rel_scores, # } # # --------------------------------------------------- # # 4. 
Profiling: 保存 per-query 的 topK embedding & 相似度 # # --------------------------------------------------- # if profile_enabled and is_main: # K = min(topk_emb, cand_mid_t.size(0)) # # 转到 float32 + CPU 便于写盘 # q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D] # q_last_cpu = ( # reps_last_full.detach().float().cpu() # if reps_last_full is not None # else None # ) # [B, D] # cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D] # cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D] # # 【关键修复】mid2mid 相似度(表征空间对齐) # scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc] # topk_mid_vals, topk_mid_inds = torch.topk( # scores_mid_full, k=K, dim=1 # ) # # last2last 相似度(如果有) # if q_last_cpu is not None: # scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc] # topk_last_vals, topk_last_inds = torch.topk( # scores_last_full, k=K, dim=1 # ) # else: # topk_last_vals = None # topk_last_inds = None # for i in range(B): # qid = batch_start_idx + i # rec = { # "qid": int(qid), # "early_exit": bool(exit_mask[i]), # } # if p_need_last_batch is not None: # rec["p_need_last"] = float(p_need_last_batch[i].item()) # # 【关键修复】mid2mid TopK(表征空间对齐) # mid_inds = topk_mid_inds[i].tolist() # mid_scores = topk_mid_vals[i].tolist() # rec["mid_topk_scores"] = mid_scores # rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds] # rec["mid_q_emb"] = q_mid_cpu[i].tolist() # rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist() # # last2last TopK(如果有) # if topk_last_inds is not None: # last_inds = topk_last_inds[i].tolist() # last_scores = topk_last_vals[i].tolist() # rec["last_topk_scores"] = last_scores # rec["last_topk_cand_ids"] = [ # cand_ids[j] for j in last_inds # ] # rec["last_q_emb"] = ( # q_last_cpu[i].tolist() if q_last_cpu is not None else None # ) # rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist() # else: # rec["last_topk_scores"] = None # rec["last_topk_cand_ids"] = None # rec["last_q_emb"] = None # rec["last_cand_embs"] = None # analysis_records.append(rec) 
# # ===================================================== # # 5. 收集 & 保存结果 # # ===================================================== # for idx in sorted(results_dict.keys()): # pred_dicts.append(results_dict[idx]) # print_master( # f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} " # f"({stats['exit']/stats['total']:.2%})" # ) # metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] # score = RankingMetrics(metrics_to_report).evaluate(pred_dicts) # if is_main: # os.makedirs(out_dir, exist_ok=True) # with open( # os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), # "w", # ) as f: # json.dump(score, f, indent=4) # # 保存 profiling 信息 # if profile_enabled: # prof_dir = os.path.join(out_dir, "profiling") # os.makedirs(prof_dir, exist_ok=True) # # 时间统计:转换成均值(秒/样本) # mid_avg = ( # timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"]) # ) # tail_avg = ( # timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"]) # ) # timing_out = { # "mid_time_sum": timing_stats["mid_time_sum"], # "mid_num": timing_stats["mid_num"], # "tail_time_sum": timing_stats["tail_time_sum"], # "tail_num": timing_stats["tail_num"], # "avg_mid_time_per_query_sec": mid_avg, # "avg_tail_time_per_cont_query_sec": tail_avg, # "num_exit": int(stats["exit"]), # "num_total": int(stats["total"]), # } # with open( # os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w" # ) as f: # json.dump(timing_out, f, indent=2) # # embedding + 相似度记录(JSONL) # embed_path = os.path.join( # prof_dir, f"{dataset_name}_embeds.jsonl" # ) # with open(embed_path, "w") as f: # for rec in analysis_records: # f.write(json.dumps(rec) + "\n") # print_master( # f"[PROFILE] Saved timing to {prof_dir}, " # f"embeddings to {embed_path}" # ) # elapsed = time.time() - start_time # return score, elapsed # # =========================== # # Helper Functions (Pre-Computation) # # =========================== # def encode_candidates_both_layers( # model: MMEBModel, loader: 
DataLoader, training_args: TrainingArguments, # model_args: ModelArguments, full_dataset: Dataset, mid_layer: int, # ): # local_rank = dist.get_rank() if dist.is_initialized() else 0 # model.eval() # all_mid, all_last, all_ids = [], [], [] # with torch.no_grad(): # for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0): # inputs = batch_to_device(inputs, training_args.device) # # 强制关闭 Cand 侧的 AOP # aop_cfg = getattr(model.encoder, "aop_prune_config", None) # _orig = None # if isinstance(aop_cfg, dict): # _orig = aop_cfg.get("enabled", False) # aop_cfg["enabled"] = False # setattr(model.encoder, "aop_prune_config", aop_cfg) # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): # out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None) # if isinstance(aop_cfg, dict) and _orig is not None: # aop_cfg["enabled"] = _orig # Restore # mid_hs = out.hidden_states[mid_layer] # last_hs = out.hidden_states[-1] # am = inputs.get("attention_mask", None) # if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device) # reps_mid = model._pooling(mid_hs, am) # reps_last = model._pooling(last_hs, am) # all_mid.append(reps_mid.detach().float().cpu()) # all_last.append(reps_last.detach().float().cpu()) # all_ids.extend([info["cand_name"] for info in dataset_info]) # if not all_mid: return np.array([]), np.array([]), [] # return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids # # =========================== # # Main # # =========================== # def main(): # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) # local_rank = dist.get_rank() if dist.is_initialized() else 0 # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) # model_args, data_args, training_args = 
parser.parse_args_into_dataclasses() # ee_cfg = get_env_ee_config() # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) # if not getattr(model_args, "model_backbone", None): # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) # setattr(model_args, 'model_backbone', model_backbone) # if local_rank == 0: # processor = load_processor(model_args, data_args) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # if torch.distributed.is_initialized(): torch.distributed.barrier() # if local_rank != 0: # processor = load_processor(model_args, data_args) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # model.eval() # model = model.to(training_args.device, dtype=torch.bfloat16) # # 调试:检查模型配置 # print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}") # # AOP 配置注入 # aop_cfg = get_env_aop_config() # if aop_cfg["enabled"]: # setattr(model.encoder, "aop_prune_config", aop_cfg) # model.set_inference_layers(qry_layers=None, tgt_layers=None) # print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}") # # 加载分类器 # classifier = None # if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]: # classifier_path = ee_cfg['classifier_path'] # print_master(f"[EE] Loading Classifier from {classifier_path}...") # # 【关键】使用27维特征,与训练代码完全一致 # classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64) # state_dict = None # # 尝试多种加载方式 # if os.path.isdir(classifier_path): # # 方式1: safetensors 格式(Trainer 默认) # safetensors_file = os.path.join(classifier_path, "model.safetensors") # if os.path.exists(safetensors_file): # from safetensors.torch import load_file # state_dict = load_file(safetensors_file) # print_master(f"[EE] Loaded from model.safetensors") # else: # # 方式2: pytorch_model.bin # pt_file = os.path.join(classifier_path, "pytorch_model.bin") # if 
os.path.exists(pt_file): # state_dict = torch.load(pt_file, map_location=training_args.device) # print_master(f"[EE] Loaded from pytorch_model.bin") # else: # # 方式3: 独立的 .pt 文件 # layer_idx = ee_cfg.get('layer', 12) # pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt") # if os.path.exists(pt_file): # state_dict = torch.load(pt_file, map_location=training_args.device) # print_master(f"[EE] Loaded from {os.path.basename(pt_file)}") # else: # raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}") # elif os.path.isfile(classifier_path): # # 方式4: 直接指定文件 # if classifier_path.endswith('.safetensors'): # from safetensors.torch import load_file # state_dict = load_file(classifier_path) # print_master(f"[EE] Loaded from .safetensors file") # else: # state_dict = torch.load(classifier_path, map_location=training_args.device) # print_master(f"[EE] Loaded from .pt file") # else: # raise FileNotFoundError(f"Classifier path not found: {classifier_path}") # classifier.load_state_dict(state_dict) # classifier.to(training_args.device) # classifier.eval() # print_master(f"[EE] Classifier loaded successfully. 
Threshold={ee_cfg['threshold']}") # with open(data_args.dataset_config, 'r') as yaml_file: # dataset_configs = yaml.safe_load(yaml_file) # for dataset_name, task_config in dataset_configs.items(): # if dist.is_initialized(): dist.barrier() # print_master(f"\n--- Evaluating {dataset_name} ---") # if data_args.data_basedir: # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: # if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) # mid_layer = int(ee_cfg["layer"]) # cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}") # cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast") # # 【关键修复】强制重新生成candidates以确保normalize设置一致 # force_regenerate = False # 已经重新生成过,设为False避免每次都重新计算 # # 预计算 Candidates # if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)): # print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...") # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") # eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) # cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer) # if local_rank == 0: # with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f) # with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f) # if dist.is_initialized(): dist.barrier() # if local_rank == 0: # with 
open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f) # with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f) # # 【关键修复】从task_config读取eval_type,决定global/local ranking # rank_global = task_config.get("eval_type", "global") == "global" # print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking") # run_early_exit_queries( # model, classifier, processor, model_args, data_args, training_args, # full_eval_qry_dataset, cand_mid_dict, cand_last_dict, ee_cfg, dataset_name, data_args.encode_output_path, # global_ranking=rank_global # ) # if dist.is_initialized(): dist.barrier() # if __name__ == '__main__': # main() # ################################################################################### # # 添加早停率上下界 # 调用格式: # export EE_EXIT_LOWER=0.05 # export EE_EXIT_UPPER=0.5 # import datetime # import logging # import json # import random # import time # import numpy as np # import os # import pickle # import sys # import torch # import torch.distributed as dist # import torch.nn.functional as F # import yaml # import transformers # import math # from torch.utils.data import DataLoader # from tqdm import tqdm # from transformers import HfArgumentParser, AutoConfig, AutoTokenizer # from datasets import Dataset, concatenate_datasets # from datasets.distributed import split_dataset_by_node # from src.arguments import ModelArguments, DataArguments, TrainingArguments # from src.data.collator.eval_collator import MultimodalEvalDataCollator # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset # from src.eval_utils.metrics import RankingMetrics # from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel # from src.model.processor import get_backbone_name, load_processor, COLPALI # from src.utils import batch_to_device, print_rank, print_master # # 引入分类器 # from src.classifier_utils_V2 import EarlyExitClassifier # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] 
%(levelname)s [%(name)s:%(lineno)s] %(message)s') # logger = logging.getLogger(__name__) # # =========================== # # Helper Functions (AOP Config) # # =========================== # def _parse_bool(v: str, default=False): # if v is None: return default # v = v.strip().lower() # return v in {"1","true","yes","y","t","on"} # def _parse_float(v: str, default=None): # try: return float(v) if v is not None else default # except: return default # def _parse_int(v: str, default=None): # try: return int(v) if v is not None else default # except: return default # def get_env_aop_config(): # enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False) # apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() # layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None) # mode = os.environ.get("AOP_MODE", "delta").strip().lower() # # Parameters # delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10) # khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0) # keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0) # min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64) # use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True) # # Specific # prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True) # prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False) # keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None) # keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None) # selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() # attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() # if layer_idx is None and enabled: # enabled = False # return { # "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode, # "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep, # "use_bias": use_bias, "prune_vision": prune_vision, "prune_text": prune_text, # "keep_ratio_vision": keep_ratio_v, "keep_ratio_text": keep_ratio_t, # "selection": 
selection, "attn_agg": attn_agg, # "margin_mid": None # Simplified # } # def get_env_ee_config(): # ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"} # layer = int(os.environ.get("EE_LAYER", "12")) # method = os.environ.get("EE_METHOD", "classifier").strip().lower() # threshold = float(os.environ.get("EE_THRESHOLD", "0.8")) # classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "") # # ===== 新增:早停率上下界(0~1) ===== # exit_lower = float(os.environ.get("EE_EXIT_LOWER", "0.0")) # 下界,默认 0 # exit_upper = float(os.environ.get("EE_EXIT_UPPER", "1.0")) # 上界,默认 1 # # 合法化:保证 0 <= lower <= upper <= 1 # exit_lower = max(0.0, min(exit_lower, 1.0)) # exit_upper = max(exit_lower, min(exit_upper, 1.0)) # return dict( # enabled=ee_enabled, # layer=layer, # method=method, # threshold=threshold, # classifier_path=classifier_path, # exit_lower=exit_lower, # exit_upper=exit_upper, # ) # def pad_dataset_to_divisible(dataset, world_size): # num_samples = len(dataset) # if num_samples % world_size == 0: return dataset, num_samples # num_to_add = world_size - (num_samples % world_size) # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) # return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add # # =========================== # # Core Inference Function # # =========================== # def run_early_exit_queries( # model: MMEBModel, # classifier: EarlyExitClassifier, # processor, # model_args: ModelArguments, # data_args: DataArguments, # training_args: TrainingArguments, # qry_dataset: Dataset, # cand_mid_dict: dict, # cand_last_dict: dict, # ee_cfg: dict, # dataset_name: str, # out_dir: str, # global_ranking: bool = True, # ): # device = training_args.device # local_rank = dist.get_rank() if dist.is_initialized() else 0 # is_main = (not dist.is_initialized()) or (local_rank == 0) # # ========================== # # 全局早停率上下界(按 bench) # # ========================== # total_samples = 
len(qry_dataset) # exit_lower = float(ee_cfg.get("exit_lower", 0.0)) # exit_upper = float(ee_cfg.get("exit_upper", 1.0)) # # 是否真正启用约束:如果是默认 [0,1] 就视为关闭,保持兼容 # use_exit_bounds = not (abs(exit_lower - 0.0) < 1e-6 and abs(exit_upper - 1.0) < 1e-6) # # 映射成「最少/最多早停的样本数」 # min_exit = int(math.ceil(exit_lower * total_samples)) # max_exit = int(math.floor(exit_upper * total_samples)) # # 安全裁剪 # min_exit = max(0, min(min_exit, total_samples)) # max_exit = max(min_exit, min(max_exit, total_samples)) # if is_main: # print_master( # f"[EE] Global exit-rate bounds for dataset '{dataset_name}': " # f"lower={exit_lower:.3f} ({min_exit}/{total_samples}), " # f"upper={exit_upper:.3f} ({max_exit}/{total_samples}), " # f"enabled={use_exit_bounds}" # ) # # ========================== # # Profiling 配置 # # ========================== # profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in { # "1", "true", "yes", "on", "y", "t" # } # topk_emb = int(os.environ.get("EE_TOPK_EMB", "5")) # # 运行时间统计(按样本数加权) # timing_stats = { # "mid_time_sum": 0.0, # encoder 到 mid 的总时间 * 样本数 # "mid_num": 0, # 样本数 # "tail_time_sum": 0.0, # mid->last 续跑部分的总时间 * 样本数 # "tail_num": 0, # 续跑样本数 # } # # embedding + 相似度记录(仅 rank0 保存) # analysis_records = [] if (profile_enabled and is_main) else None # # 1. 
准备 Candidates # cand_ids = list(cand_mid_dict.keys()) # cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32) # cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32) # # 新增:作为检索打分使用的 FP32 Numpy 向量 # cand_mid_np = cand_mid # [Nc, D], float32 # cand_last_np = cand_last # [Nc, D], float32 # cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16) # cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16) # collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") # loader = DataLoader( # qry_dataset, # batch_size=training_args.per_device_eval_batch_size, # collate_fn=collator, # num_workers=training_args.dataloader_num_workers # ) # pred_dicts = [] # stats = {"exit": 0, "total": 0} # threshold = float(ee_cfg["threshold"]) # method = ee_cfg["method"] # target_layer_idx = int(ee_cfg["layer"]) # # 结果顺序 # results_dict = {} # global_sample_idx = 0 # # Local vs Global ranking # use_local = (not global_ranking) # if use_local: # print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)") # cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)} # else: # print_master(f"[INFO] Using GLOBAL ranking (full library)") # # --- AOP 配置初始化 --- # aop_cfg = getattr(model.encoder, "aop_prune_config", None) # _orig_enabled = None # side_enable = True # if isinstance(aop_cfg, dict) and aop_cfg: # _orig_enabled = aop_cfg.get("enabled", False) # apply_to = aop_cfg.get("apply_to", "qry") # side_enable = (apply_to == "both") or (apply_to == "qry") # model.eval() # if classifier: # classifier.eval() # classifier.to(device) # start_time = time.time() # for inputs, infos in tqdm( # loader, # desc=f"[EE+AOP] {dataset_name} (tau={threshold})", # disable=local_rank > 0, # ): # inputs = batch_to_device(inputs, device) # B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1 # batch_start_idx = global_sample_idx # global_sample_idx += B # stats["total"] += B 
# # --------------------------------------------------- # # 1. 前半程: Run to Mid Layer (含 AOP 动态控制) # # --------------------------------------------------- # orig_cfg = None # if isinstance(aop_cfg, dict) and aop_cfg: # orig_cfg = dict(aop_cfg) # aop_layer = aop_cfg.get("layer_idx", None) # aop_on_mid = bool( # _orig_enabled # and side_enable # and (aop_layer is not None) # and (aop_layer < target_layer_idx) # ) # aop_cfg_mid = dict(aop_cfg) # aop_cfg_mid["enabled"] = aop_on_mid # setattr(model.encoder, "aop_prune_config", aop_cfg_mid) # # 计时:encoder 到 mid 层 # if profile_enabled: # torch.cuda.synchronize() # t0_mid = time.perf_counter() # with torch.no_grad(), torch.autocast( # device_type="cuda", dtype=torch.bfloat16, enabled=True # ): # out_mid = model.encoder( # **inputs, # return_dict=True, # output_hidden_states=False, # stop_at_layer=target_layer_idx, # compute_lm_head=False, # ) # if profile_enabled: # torch.cuda.synchronize() # t1_mid = time.perf_counter() # timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B # timing_stats["mid_num"] += B # # 恢复 AOP 配置 # if isinstance(orig_cfg, dict): # setattr(model.encoder, "aop_prune_config", orig_cfg) # # Hidden State & Mask # hs_mid = getattr(out_mid, "last_hidden_state", None) # if hs_mid is None: # hs_mid = out_mid.hidden_states[-1] # am_mid = getattr(out_mid, "attention_mask", None) # if am_mid is None: # am_mid = inputs.get("attention_mask", None) # reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16) # # 如果要记录最后一层 embedding,需要在 profiling 模式下额外跑一次全 forward # reps_last_full = None # if profile_enabled: # with torch.no_grad(), torch.autocast( # device_type="cuda", dtype=torch.bfloat16, enabled=True # ): # out_full = model.encoder( # **inputs, # return_dict=True, # output_hidden_states=False, # stop_at_layer=None, # compute_lm_head=False, # ) # hs_last_full = getattr(out_full, "last_hidden_state", None) # if hs_last_full is None: # hs_last_full = out_full.hidden_states[-1] # am_last_full = 
getattr(out_full, "attention_mask", None) # if am_last_full is None: # am_last_full = inputs.get("attention_mask", None) # reps_last_full = ( # model._pooling(hs_last_full, am_last_full) # .detach() # .to(dtype=torch.bfloat16) # ) # # --------------------------------------------------- # # 2. 特征工程 + gating # # --------------------------------------------------- # exit_mask = np.zeros(B, dtype=bool) # p_need_last_batch = None # profiling / debug # orig_exit_mask = None # 原始按 threshold 判的结果(不带全局约束) # if method == "classifier" and classifier is not None: # with torch.no_grad(): # # 【关键修复】特征提取:qry_mid × cand_mid(表征空间对齐) # cos_mid = reps_mid_t @ cand_mid_t.T # [B, N] # backbone_ptr = ( # model.module if hasattr(model, "module") else model # ) # temp = getattr(backbone_ptr, "temperature", 0.02) # scores_mid = cos_mid / temp # probs_mid = torch.softmax(scores_mid, dim=1) # [B, N] # # ==== 以下是你原来的 27 维特征工程代码,原样保留 ==== # diag_cos = cos_mid.max(dim=1)[0] # sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True) # s2_cos = ( # sorted_cos[:, 1] # if sorted_cos.size(1) > 1 # else sorted_cos[:, 0] # ) # margin_mid = diag_cos - s2_cos # margin_mean = margin_mid.mean() # margin_std = margin_mid.std(unbiased=False) + 1e-6 # z_margin_mid = (margin_mid - margin_mean) / margin_std # margin_median = margin_mid.median() # mad = (margin_mid - margin_median).abs().median() + 1e-6 # mad_margin_mid = (margin_mid - margin_median) / mad # p1_mid = probs_mid.max(dim=1)[0] # H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1) # gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1) # TOPK = min(16, probs_mid.size(1)) # topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1) # topk_mean = topk_vals.mean(dim=1) # topk_std = topk_vals.std(dim=1, unbiased=False) # topk_cv = topk_std / (topk_mean + 1e-6) # centered = topk_vals - topk_mean.unsqueeze(1) # var = (centered ** 2).mean(dim=1) + 1e-6 # m4 = (centered ** 4).mean(dim=1) # topk_kurt = m4 / (var ** 2) # topk_med = 
topk_vals.median(dim=1).values # row_mean_cos = cos_mid.mean(dim=1) # row_med_cos = cos_mid.median(dim=1).values # s1_over_mean = diag_cos - row_mean_cos # s1_over_med = diag_cos - row_med_cos # sorted_probs, _ = torch.sort( # probs_mid, dim=1, descending=True # ) # p1 = sorted_probs[:, 0] # p2 = ( # sorted_probs[:, 1] # if sorted_probs.size(1) > 1 # else sorted_probs[:, 0] # ) # shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum( # dim=1 # ) # shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1) # R = min(10, sorted_probs.size(1)) # x = torch.arange( # R, device=device, dtype=sorted_probs.dtype # ) # x_centered = x - x.mean() # denom = (x_centered ** 2).sum() # y = torch.log(sorted_probs[:, :R] + 1e-6) # slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom # row_mean_p = probs_mid.mean(dim=1) # row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6 # z1 = (p1_mid - row_mean_p) / row_std_p # center_p = probs_mid - row_mean_p.unsqueeze(1) # m3 = (center_p ** 3).mean(dim=1) # skew = m3 / (row_std_p ** 3 + 1e-6) # s1_over_sk = p1_mid - skew # TAIL_K = min(10, sorted_probs.size(1)) # tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1) # HEAD_K = min(5, sorted_probs.size(1)) # head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1) # mask_ratio = torch.zeros_like(diag_cos) # mask_len = torch.zeros_like(diag_cos) # mask_runs = torch.zeros_like(diag_cos) # scalar_inputs = torch.stack( # [ # diag_cos, # s2_cos, # margin_mid, # z_margin_mid, # mad_margin_mid, # p1_mid, # H_mid, # gini_mid, # topk_mean, # topk_std, # topk_cv, # topk_kurt, # topk_med, # s1_over_mean, # s1_over_med, # p1, # p2, # shape_H, # shape_gini, # slope, # z1, # s1_over_sk, # tail_mean, # head5_mean, # mask_ratio, # mask_len, # mask_runs, # ], # dim=1, # ) # modality_idx = torch.zeros( # B, dtype=torch.long, device=device # ) # if "pixel_values" in inputs and inputs["pixel_values"] is not None: # pv = inputs["pixel_values"] # if isinstance(pv, list): # for i, item in enumerate(pv): # if item is 
not None: # modality_idx[i] = 1 # elif isinstance(pv, torch.Tensor) and pv.numel() > 0: # modality_idx.fill_(1) # logits = classifier(scalar_inputs, modality_idx) # p_need_last = torch.sigmoid(logits) # [B,1] # p_need_last_batch = p_need_last.squeeze(1) # [B] # # 原始按 threshold 的判定(仅作为参考) # should_exit = p_need_last_batch < threshold # orig_exit_mask = should_exit.cpu().numpy() # # ====== 应用全局早停率上下界(如果启用) ====== # if method == "classifier" and classifier is not None and orig_exit_mask is not None: # if not use_exit_bounds: # # 不启用上下界:行为与旧版本完全一致 # exit_mask = orig_exit_mask # else: # # 当前之前已经早停的数量 # E_prev = stats["exit"] # # 当前 batch 前已经处理的样本数(全局 index) # i_prev = batch_start_idx # B_cur = B # N_total = total_samples # # 最多还能早停多少个(不超过上界) # max_exit_remaining = max_exit - E_prev # if max_exit_remaining <= 0: # allowed_max = 0 # else: # allowed_max = min(B_cur, max_exit_remaining) # # 本 batch 之后还剩多少个样本没处理 # remaining_after = N_total - (i_prev + B_cur) # # 为了将来仍有机会达到下界,本 batch 至少需要早停多少个 # required_min = min( # B_cur, # max(0, min_exit - E_prev - remaining_after) # ) # # 分类器在当前 batch 原始建议早停个数(按 threshold) # n0 = int(orig_exit_mask.sum()) # if allowed_max <= 0: # n_exit = 0 # else: # # 在 [required_min, allowed_max] 区间内,尽量贴近原始 n0 # n_exit = max(required_min, min(n0, allowed_max)) # # 最终 exit_mask:在当前 batch 中选 p_need_last 最小的 n_exit 个样本早停 # exit_mask = np.zeros(B_cur, dtype=bool) # if n_exit > 0: # p_np = p_need_last_batch.detach().cpu().numpy() # [B] # order = np.argsort(p_np) # 升序(越小越容易早停) # chosen = order[:n_exit] # exit_mask[chosen] = True # # Debug:只在前几个 batch 打印 # if stats["total"] <= B * 3 and is_main: # raw_rate = orig_exit_mask.mean() # bounded_rate = exit_mask.mean() # print_master( # f"[EE Debug] Batch {stats['total']//B}: " # f"p_need_last mean={p_need_last_batch.mean().item():.4f}, " # f"std={p_need_last_batch.std().item():.4f}, " # f"raw_exit_rate={raw_rate:.2%}, " # f"bounded_exit_rate={bounded_rate:.2%}, " # f"required_min={required_min}, 
allowed_max={allowed_max}, " # f"n0={n0}, n_exit={n_exit}" # ) # # 非 classifier 或 classifier 不存在:保持 exit_mask=全 False(即无早停) # stats["exit"] += int(exit_mask.sum()) # # --------------------------------------------------- # # 3. 分支执行 # # --------------------------------------------------- # exit_indices = np.where(exit_mask)[0] # cont_indices = np.where(~exit_mask)[0] # # A. 早停样本:用 mid→last 检索 # if len(exit_indices) > 0: # reps_exit = reps_mid_t[exit_indices] # if use_local: # for i, idx in enumerate(exit_indices): # cand_local = infos[idx].get("cand_names", []) # if not cand_local: # cids = [] # else: # rows = [ # cand_id2row.get(str(cid), -1) # for cid in cand_local # ] # rows = [r for r in rows if r >= 0] # if len(rows) == 0: # cids = [] # else: # # CPU FP32:单个 query mid × 局部 cand_mid # cmat_np = cand_mid_np[rows] # [Nc_local, D] # qry_np = reps_exit[i].detach().float().cpu().numpy() # [D] # scores_vec = np.dot(qry_np, cmat_np.T) # [Nc_local] # top_k = min(200, len(rows)) # order_local = np.argsort(-scores_vec)[:top_k] # cids = [str(cand_local[o]) for o in order_local] # label = ( # infos[idx].get("label_name") # or infos[idx].get("label") # or infos[idx].get("rel_docids") # ) # if not isinstance(label, list): # label = [label] # rel_scores = infos[idx].get("rel_scores", None) # global_idx = batch_start_idx + idx # results_dict[global_idx] = { # "prediction": cids, # "label": label, # "rel_scores": rel_scores, # } # else: # # 使用 CPU FP32 Numpy:reps_exit_np × cand_mid_np.T # reps_exit_np = reps_exit.detach().float().cpu().numpy() # [B_exit, D] # scores_full = np.dot(reps_exit_np, cand_mid_np.T) # [B_exit, Nc] # top_k = min(200, len(cand_ids)) # # 全排序后截到 top_k # topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k] # [B_exit, top_k] # for i, idx in enumerate(exit_indices): # cids = [cand_ids[k] for k in topk_inds[i]] # label = ( # infos[idx].get("label_name") # or infos[idx].get("label") # or infos[idx].get("rel_docids") # ) # if not isinstance(label, list): # 
label = [label] # rel_scores = infos[idx].get("rel_scores", None) # global_idx = batch_start_idx + idx # results_dict[global_idx] = { # "prediction": cids, # "label": label, # "rel_scores": rel_scores, # } # # B. 续跑样本:从 mid 继续到 last # if len(cont_indices) > 0: # interm = getattr(out_mid, "intermediate_state", None) # hs = interm["hidden_states"].detach() # am = interm["attention_mask"].detach() # pos = interm["position_ids"].detach() # vm = interm.get("vision_mask", None) # tm = interm.get("text_mask", None) # next_layer = int(interm["next_layer_idx"]) # resume_state_subset = { # "hidden_states": hs[cont_indices], # "attention_mask": am[cont_indices], # "position_ids": pos[:, cont_indices, :], # "vision_mask": vm[cont_indices] if vm is not None else None, # "text_mask": tm[cont_indices] if tm is not None else None, # "next_layer_idx": next_layer, # } # if isinstance(aop_cfg, dict) and aop_cfg: # aop_resume = dict(aop_cfg) # aop_resume["enabled"] = bool(_orig_enabled and side_enable) # setattr(model.encoder, "aop_prune_config", aop_resume) # # 计时:mid -> last # if profile_enabled: # torch.cuda.synchronize() # t0_tail = time.perf_counter() # with torch.no_grad(), torch.autocast( # device_type="cuda", dtype=torch.bfloat16, enabled=True # ): # out_last = model.encoder( # return_dict=True, # output_hidden_states=False, # stop_at_layer=None, # resume_state=resume_state_subset, # compute_lm_head=False, # ) # if profile_enabled: # torch.cuda.synchronize() # t1_tail = time.perf_counter() # timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len( # cont_indices # ) # timing_stats["tail_num"] += len(cont_indices) # hs_last = out_last.last_hidden_state # if hs_last is None: # hs_last = out_last.hidden_states[-1] # am_last = getattr(out_last, "attention_mask", None) # if am_last is None: # am_last = resume_state_subset["attention_mask"] # reps_last_t = ( # model._pooling(hs_last, am_last) # .detach() # .to(dtype=torch.bfloat16) # ) # if use_local: # for i, idx_global in 
enumerate(cont_indices): # cand_local = infos[idx_global].get("cand_names", []) # if not cand_local: # cids = [] # else: # rows = [ # cand_id2row.get(str(cid), -1) # for cid in cand_local # ] # rows = [r for r in rows if r >= 0] # if len(rows) == 0: # cids = [] # else: # # CPU FP32:单个 query last × 局部 cand_last # cmat_last_np = cand_last_np[rows] # [Nc_local, D] # qry_last_np = reps_last_t[i].detach().float().cpu().numpy() # scores_vec = np.dot(qry_last_np, cmat_last_np.T) # [Nc_local] # top_k = min(200, len(rows)) # order_local = np.argsort(-scores_vec)[:top_k] # cids = [str(cand_local[o]) for o in order_local] # label = ( # infos[idx_global].get("label_name") # or infos[idx_global].get("label") # or infos[idx_global].get("rel_docids") # ) # if not isinstance(label, list): # label = [label] # rel_scores = infos[idx_global].get("rel_scores", None) # global_idx = batch_start_idx + idx_global # results_dict[global_idx] = { # "prediction": cids, # "label": label, # "rel_scores": rel_scores, # } # else: # # CPU FP32:所有续跑 query last × 全量 cand_last # reps_last_np = reps_last_t.detach().float().cpu().numpy() # [B_cont, D] # scores_last = np.dot(reps_last_np, cand_last_np.T) # [B_cont, Nc] # top_k = min(200, len(cand_ids)) # topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k] # [B_cont, top_k] # for i, idx_global in enumerate(cont_indices): # cids = [cand_ids[k] for k in topk_inds[i]] # label = ( # infos[idx_global].get("label_name") # or infos[idx_global].get("label") # or infos[idx_global].get("rel_docids") # ) # if not isinstance(label, list): # label = [label] # rel_scores = infos[idx_global].get("rel_scores", None) # global_idx = batch_start_idx + idx_global # results_dict[global_idx] = { # "prediction": cids, # "label": label, # "rel_scores": rel_scores, # } # # --------------------------------------------------- # # 4. 
Profiling: 保存 per-query 的 topK embedding & 相似度 # # --------------------------------------------------- # if profile_enabled and is_main: # K = min(topk_emb, cand_mid_t.size(0)) # # 转到 float32 + CPU 便于写盘 # q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D] # q_last_cpu = ( # reps_last_full.detach().float().cpu() # if reps_last_full is not None # else None # ) # [B, D] # cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D] # cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D] # # 【关键修复】mid2mid 相似度(表征空间对齐) # scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc] # topk_mid_vals, topk_mid_inds = torch.topk( # scores_mid_full, k=K, dim=1 # ) # # last2last 相似度(如果有) # if q_last_cpu is not None: # scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc] # topk_last_vals, topk_last_inds = torch.topk( # scores_last_full, k=K, dim=1 # ) # else: # topk_last_vals = None # topk_last_inds = None # for i in range(B): # qid = batch_start_idx + i # rec = { # "qid": int(qid), # "early_exit": bool(exit_mask[i]), # } # if p_need_last_batch is not None: # rec["p_need_last"] = float(p_need_last_batch[i].item()) # # 【关键修复】mid2mid TopK(表征空间对齐) # mid_inds = topk_mid_inds[i].tolist() # mid_scores = topk_mid_vals[i].tolist() # rec["mid_topk_scores"] = mid_scores # rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds] # rec["mid_q_emb"] = q_mid_cpu[i].tolist() # rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist() # # last2last TopK(如果有) # if topk_last_inds is not None: # last_inds = topk_last_inds[i].tolist() # last_scores = topk_last_vals[i].tolist() # rec["last_topk_scores"] = last_scores # rec["last_topk_cand_ids"] = [ # cand_ids[j] for j in last_inds # ] # rec["last_q_emb"] = ( # q_last_cpu[i].tolist() if q_last_cpu is not None else None # ) # rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist() # else: # rec["last_topk_scores"] = None # rec["last_topk_cand_ids"] = None # rec["last_q_emb"] = None # rec["last_cand_embs"] = None # analysis_records.append(rec) 
# # ===================================================== # # 5. 收集 & 保存结果 # # ===================================================== # for idx in sorted(results_dict.keys()): # pred_dicts.append(results_dict[idx]) # print_master( # f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} " # f"({stats['exit']/stats['total']:.2%})" # ) # metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] # score = RankingMetrics(metrics_to_report).evaluate(pred_dicts) # if is_main: # os.makedirs(out_dir, exist_ok=True) # with open( # os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), # "w", # ) as f: # json.dump(score, f, indent=4) # # 保存 profiling 信息 # if profile_enabled: # prof_dir = os.path.join(out_dir, "profiling") # os.makedirs(prof_dir, exist_ok=True) # # 时间统计:转换成均值(秒/样本) # mid_avg = ( # timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"]) # ) # tail_avg = ( # timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"]) # ) # timing_out = { # "mid_time_sum": timing_stats["mid_time_sum"], # "mid_num": timing_stats["mid_num"], # "tail_time_sum": timing_stats["tail_time_sum"], # "tail_num": timing_stats["tail_num"], # "avg_mid_time_per_query_sec": mid_avg, # "avg_tail_time_per_cont_query_sec": tail_avg, # "num_exit": int(stats["exit"]), # "num_total": int(stats["total"]), # } # with open( # os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w" # ) as f: # json.dump(timing_out, f, indent=2) # # embedding + 相似度记录(JSONL) # embed_path = os.path.join( # prof_dir, f"{dataset_name}_embeds.jsonl" # ) # with open(embed_path, "w") as f: # for rec in analysis_records: # f.write(json.dumps(rec) + "\n") # print_master( # f"[PROFILE] Saved timing to {prof_dir}, " # f"embeddings to {embed_path}" # ) # elapsed = time.time() - start_time # return score, elapsed # # =========================== # # Helper Functions (Pre-Computation) # # =========================== # def encode_candidates_both_layers( # model: MMEBModel, loader: 
DataLoader, training_args: TrainingArguments, # model_args: ModelArguments, full_dataset: Dataset, mid_layer: int, # ): # local_rank = dist.get_rank() if dist.is_initialized() else 0 # model.eval() # all_mid, all_last, all_ids = [], [], [] # with torch.no_grad(): # for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0): # inputs = batch_to_device(inputs, training_args.device) # # 强制关闭 Cand 侧的 AOP # aop_cfg = getattr(model.encoder, "aop_prune_config", None) # _orig = None # if isinstance(aop_cfg, dict): # _orig = aop_cfg.get("enabled", False) # aop_cfg["enabled"] = False # setattr(model.encoder, "aop_prune_config", aop_cfg) # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): # out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None) # if isinstance(aop_cfg, dict) and _orig is not None: # aop_cfg["enabled"] = _orig # Restore # mid_hs = out.hidden_states[mid_layer] # last_hs = out.hidden_states[-1] # am = inputs.get("attention_mask", None) # if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device) # reps_mid = model._pooling(mid_hs, am) # reps_last = model._pooling(last_hs, am) # all_mid.append(reps_mid.detach().float().cpu()) # all_last.append(reps_last.detach().float().cpu()) # all_ids.extend([info["cand_name"] for info in dataset_info]) # if not all_mid: return np.array([]), np.array([]), [] # return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids # # =========================== # # Main # # =========================== # def main(): # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) # local_rank = dist.get_rank() if dist.is_initialized() else 0 # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) # model_args, data_args, training_args = 
parser.parse_args_into_dataclasses() # ee_cfg = get_env_ee_config() # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) # if not getattr(model_args, "model_backbone", None): # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) # setattr(model_args, 'model_backbone', model_backbone) # if local_rank == 0: # processor = load_processor(model_args, data_args) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # if torch.distributed.is_initialized(): torch.distributed.barrier() # if local_rank != 0: # processor = load_processor(model_args, data_args) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # model.eval() # model = model.to(training_args.device, dtype=torch.bfloat16) # # 调试:检查模型配置 # print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}") # # AOP 配置注入 # aop_cfg = get_env_aop_config() # if aop_cfg["enabled"]: # setattr(model.encoder, "aop_prune_config", aop_cfg) # model.set_inference_layers(qry_layers=None, tgt_layers=None) # print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}") # # 加载分类器 # classifier = None # if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]: # classifier_path = ee_cfg['classifier_path'] # print_master(f"[EE] Loading Classifier from {classifier_path}...") # # 【关键】使用27维特征,与训练代码完全一致 # classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64) # state_dict = None # # 尝试多种加载方式 # if os.path.isdir(classifier_path): # # 方式1: safetensors 格式(Trainer 默认) # safetensors_file = os.path.join(classifier_path, "model.safetensors") # if os.path.exists(safetensors_file): # from safetensors.torch import load_file # state_dict = load_file(safetensors_file) # print_master(f"[EE] Loaded from model.safetensors") # else: # # 方式2: pytorch_model.bin # pt_file = os.path.join(classifier_path, "pytorch_model.bin") # if 
os.path.exists(pt_file): # state_dict = torch.load(pt_file, map_location=training_args.device) # print_master(f"[EE] Loaded from pytorch_model.bin") # else: # # 方式3: 独立的 .pt 文件 # layer_idx = ee_cfg.get('layer', 12) # pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt") # if os.path.exists(pt_file): # state_dict = torch.load(pt_file, map_location=training_args.device) # print_master(f"[EE] Loaded from {os.path.basename(pt_file)}") # else: # raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}") # elif os.path.isfile(classifier_path): # # 方式4: 直接指定文件 # if classifier_path.endswith('.safetensors'): # from safetensors.torch import load_file # state_dict = load_file(classifier_path) # print_master(f"[EE] Loaded from .safetensors file") # else: # state_dict = torch.load(classifier_path, map_location=training_args.device) # print_master(f"[EE] Loaded from .pt file") # else: # raise FileNotFoundError(f"Classifier path not found: {classifier_path}") # classifier.load_state_dict(state_dict) # classifier.to(training_args.device) # classifier.eval() # print_master(f"[EE] Classifier loaded successfully. 
Threshold={ee_cfg['threshold']}") # with open(data_args.dataset_config, 'r') as yaml_file: # dataset_configs = yaml.safe_load(yaml_file) # for dataset_name, task_config in dataset_configs.items(): # if dist.is_initialized(): dist.barrier() # print_master(f"\n--- Evaluating {dataset_name} ---") # if data_args.data_basedir: # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: # if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) # mid_layer = int(ee_cfg["layer"]) # cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}") # cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast") # # 【关键修复】强制重新生成candidates以确保normalize设置一致 # force_regenerate = False # 已经重新生成过,设为False避免每次都重新计算 # # 预计算 Candidates # if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)): # print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...") # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") # eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) # cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer) # if local_rank == 0: # with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f) # with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f) # if dist.is_initialized(): dist.barrier() # if local_rank == 0: # with 
open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f) # with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f) # # 【关键修复】从task_config读取eval_type,决定global/local ranking # rank_global = task_config.get("eval_type", "global") == "global" # print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking") # run_early_exit_queries( # model, classifier, processor, model_args, data_args, training_args, # full_eval_qry_dataset, cand_mid_dict, cand_last_dict, ee_cfg, dataset_name, data_args.encode_output_path, # global_ranking=rank_global # ) # if dist.is_initialized(): dist.barrier() # if __name__ == '__main__': # main() ####################################################################################################### #解决CPU&GPU计算精度问题 import datetime import logging import json import random import time import numpy as np import os import pickle import sys import torch import torch.distributed as dist import torch.nn.functional as F import yaml import transformers import math from torch.utils.data import DataLoader from tqdm import tqdm from transformers import HfArgumentParser, AutoConfig, AutoTokenizer from datasets import Dataset, concatenate_datasets from datasets.distributed import split_dataset_by_node from src.arguments import ModelArguments, DataArguments, TrainingArguments from src.data.collator.eval_collator import MultimodalEvalDataCollator from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset from src.eval_utils.metrics import RankingMetrics from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel from src.model.processor import get_backbone_name, load_processor, COLPALI from src.utils import batch_to_device, print_rank, print_master # 引入分类器 from src.classifier_utils_V2 import EarlyExitClassifier logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') logger = logging.getLogger(__name__) # 
# ===========================
# Per-dataset thresholds (example values)
# Keys must match the dataset_name entries in the dataset_config YAML exactly.
# Datasets absent from this table fall back to EE_THRESHOLD as the default.
# ===========================
PER_DATASET_THRESHOLDS = {
    # Examples (fill in from your own offline threshold sweep):
    "CIRR": 0,
    "EDIS": 5e-16,
    "FashionIQ": 0,
    "MSCOCO_i2t": 1e-16,
    "MSCOCO_t2i": 1e-5,
    "NIGHTS": 1,
    "OVEN": 1e-11,
    "VisDial": 1e-11,
    "VisualNews_i2t": 5e-28,
    "VisualNews_t2i": 1e-16,
    "WebQA": 5e-11,
    "Wiki-SS-NQ": 1e-16,
}


# ===========================
# Helper Functions (AOP Config)
# ===========================
def _parse_bool(v: str, default=False):
    """Parse a truthy env-var string ("1"/"true"/"yes"/...); None -> default."""
    if v is None: return default
    v = v.strip().lower()
    return v in {"1","true","yes","y","t","on"}


def _parse_float(v: str, default=None):
    """Parse a float from an env-var string; None or unparsable -> default."""
    try: return float(v) if v is not None else default
    except: return default


def _parse_int(v: str, default=None):
    """Parse an int from an env-var string; None or unparsable -> default."""
    try: return int(v) if v is not None else default
    except: return default


def get_env_aop_config():
    """Build the AOP (attention-oriented pruning) config dict from environment variables.

    Returns a dict consumed by the patched encoder via its ``aop_prune_config``
    attribute. If AOP_LAYER is unset, AOP is force-disabled even when
    AOP_ENABLED is truthy.
    """
    enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
    apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()
    layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
    mode = os.environ.get("AOP_MODE", "delta").strip().lower()
    # Parameters
    delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
    khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
    keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
    min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
    use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
    # Specific
    prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
    prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
    keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
    keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
    selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
    attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()
    # A pruning layer index is mandatory; without it, disable AOP entirely.
    if layer_idx is None and enabled:
        enabled = False
    return {
        "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode,
        "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep,
        "use_bias": use_bias, "prune_vision": prune_vision, "prune_text": prune_text,
        "keep_ratio_vision": keep_ratio_v, "keep_ratio_text": keep_ratio_t,
        "selection": selection, "attn_agg": attn_agg,
        "margin_mid": None  # Simplified
    }


def get_env_ee_config():
    """Build the early-exit (EE) config dict from environment variables."""
    ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"}
    layer = int(os.environ.get("EE_LAYER", "12"))
    method = os.environ.get("EE_METHOD", "classifier").strip().lower()
    threshold = float(os.environ.get("EE_THRESHOLD", "0.8"))
    classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "")
    return dict(enabled=ee_enabled, layer=layer, method=method, threshold=threshold, classifier_path=classifier_path)


def pad_dataset_to_divisible(dataset, world_size):
    """Pad ``dataset`` (by cycling its own samples) until its length is divisible by ``world_size``.

    Returns (padded_dataset, padded_length). If already divisible, returns the
    dataset unchanged with its original length.
    """
    num_samples = len(dataset)
    if num_samples % world_size == 0: return dataset, num_samples
    num_to_add = world_size - (num_samples % world_size)
    padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
    return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add


# ===========================
# Core Inference Function
# ===========================
def run_early_exit_queries(
    model: MMEBModel,
    classifier: EarlyExitClassifier,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """Evaluate queries with a two-stage early-exit pipeline.

    Each query is first encoded up to the mid layer (``ee_cfg["layer"]``). A
    gating classifier then decides, per sample, whether the mid-layer embedding
    is confident enough to retrieve directly (early exit, scored against
    mid-layer candidate embeddings) or whether encoding must resume from the
    mid layer to the last layer (scored against last-layer candidate
    embeddings). Retrieval scoring is done in FP32 on CPU to avoid
    CPU-vs-GPU precision discrepancies.

    Args:
        model: wrapped encoder exposing ``encoder``, ``_pooling`` and
            (optionally) an ``aop_prune_config`` attribute on the encoder.
        classifier: early-exit gating classifier (may be None when
            ``ee_cfg["method"]`` is not "classifier").
        qry_dataset: query-side eval dataset.
        cand_mid_dict / cand_last_dict: candidate-id -> embedding mappings for
            the mid and last layers, respectively.
        ee_cfg: dict with "layer", "method", "threshold", ... (see
            ``get_env_ee_config``).
        dataset_name: used for logging and output file names.
        out_dir: directory for score / profiling outputs.
        global_ranking: True -> rank against the full candidate library;
            False -> rank only each query's local candidate set
            (``infos[i]["cand_names"]``).

    Returns:
        (score, elapsed): ranking-metric dict and wall-clock seconds.
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)

    # ==========================
    # Profiling configuration
    # ==========================
    profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {
        "1", "true", "yes", "on", "y", "t"
    }
    topk_emb = int(os.environ.get("EE_TOPK_EMB", "5"))

    # Runtime statistics, weighted by sample count.
    timing_stats = {
        "mid_time_sum": 0.0,   # total (encoder-to-mid time * samples)
        "mid_num": 0,          # number of samples
        "tail_time_sum": 0.0,  # total (mid-to-last continuation time * samples)
        "tail_num": 0,         # number of continued samples
    }
    # Embedding + similarity records (saved on rank 0 only).
    analysis_records = [] if (profile_enabled and is_main) else None

    # 1. Prepare candidates.
    cand_ids = list(cand_mid_dict.keys())
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
    # FP32 numpy matrices used for the actual retrieval scoring (CPU).
    cand_mid_np = cand_mid    # [Nc, D], float32
    cand_last_np = cand_last  # [Nc, D], float32
    # bfloat16 GPU copies used only for classifier features / profiling.
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)

    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator,
        num_workers=training_args.dataloader_num_workers
    )

    pred_dicts = []
    stats = {"exit": 0, "total": 0}
    threshold = float(ee_cfg["threshold"])
    method = ee_cfg["method"]
    target_layer_idx = int(ee_cfg["layer"])

    # Results keyed by global sample index so output order is deterministic.
    results_dict = {}
    global_sample_idx = 0

    # Local vs global ranking.
    use_local = (not global_ranking)
    if use_local:
        print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)")
        cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
    else:
        print_master(f"[INFO] Using GLOBAL ranking (full library)")

    # --- AOP configuration snapshot ---
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")

    model.eval()
    if classifier:
        classifier.eval()
        classifier.to(device)

    start_time = time.time()
    for inputs, infos in tqdm(
        loader,
        desc=f"[EE+AOP] {dataset_name} (tau={threshold})",
        disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)
        B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
        batch_start_idx = global_sample_idx
        global_sample_idx += B
        stats["total"] += B

        # ---------------------------------------------------
        # 1. First stage: run to the mid layer (with dynamic AOP control)
        # ---------------------------------------------------
        orig_cfg = None
        if isinstance(aop_cfg, dict) and aop_cfg:
            orig_cfg = dict(aop_cfg)
            aop_layer = aop_cfg.get("layer_idx", None)
            # Only prune during the first stage if the pruning layer lies
            # strictly before the early-exit layer.
            aop_on_mid = bool(
                _orig_enabled and side_enable and (aop_layer is not None) and (aop_layer < target_layer_idx)
            )
            aop_cfg_mid = dict(aop_cfg)
            aop_cfg_mid["enabled"] = aop_on_mid
            setattr(model.encoder, "aop_prune_config", aop_cfg_mid)

        # Timing: encoder up to the mid layer.
        if profile_enabled:
            torch.cuda.synchronize()
            t0_mid = time.perf_counter()

        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=True
        ):
            out_mid = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=target_layer_idx,
                compute_lm_head=False,
            )

        if profile_enabled:
            torch.cuda.synchronize()
            t1_mid = time.perf_counter()
            timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B
            timing_stats["mid_num"] += B

        # Restore the AOP config.
        if isinstance(orig_cfg, dict):
            setattr(model.encoder, "aop_prune_config", orig_cfg)

        # Hidden state & mask at the mid layer.
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            hs_mid = out_mid.hidden_states[-1]
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None:
            am_mid = inputs.get("attention_mask", None)
        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16)

        # When profiling, additionally run one full forward pass so last-layer
        # embeddings can be recorded for every query.
        reps_last_full = None
        if profile_enabled:
            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16, enabled=True
            ):
                out_full = model.encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    compute_lm_head=False,
                )
            hs_last_full = getattr(out_full, "last_hidden_state", None)
            if hs_last_full is None:
                hs_last_full = out_full.hidden_states[-1]
            am_last_full = getattr(out_full, "attention_mask", None)
            if am_last_full is None:
                am_last_full = inputs.get("attention_mask", None)
            reps_last_full = (
                model._pooling(hs_last_full, am_last_full)
                .detach()
                .to(dtype=torch.bfloat16)
            )

        # ---------------------------------------------------
        # 2. Feature engineering + gating
        # ---------------------------------------------------
        exit_mask = np.zeros(B, dtype=bool)
        p_need_last_batch = None  # kept only for profiling / debug output

        if method == "classifier" and classifier is not None:
            with torch.no_grad():
                # Key fix: features are qry_mid x cand_mid so the
                # representation spaces are aligned.
                cos_mid = reps_mid_t @ cand_mid_t.T  # [B, N]
                backbone_ptr = (
                    model.module if hasattr(model, "module") else model
                )
                temp = getattr(backbone_ptr, "temperature", 0.02)
                scores_mid = cos_mid / temp
                probs_mid = torch.softmax(scores_mid, dim=1)  # [B, N]

                # Top-1 similarity and margin over the runner-up.
                diag_cos = cos_mid.max(dim=1)[0]
                sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
                s2_cos = (
                    sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
                )
                margin_mid = diag_cos - s2_cos

                # Batch-normalized margin (z-score and robust MAD variants).
                margin_mean = margin_mid.mean()
                margin_std = margin_mid.std(unbiased=False) + 1e-6
                z_margin_mid = (margin_mid - margin_mean) / margin_std
                margin_median = margin_mid.median()
                mad = (margin_mid - margin_median).abs().median() + 1e-6
                mad_margin_mid = (margin_mid - margin_median) / mad

                # Distribution-shape features over the softmax scores.
                p1_mid = probs_mid.max(dim=1)[0]
                H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
                gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
                TOPK = min(16, probs_mid.size(1))
                topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
                topk_mean = topk_vals.mean(dim=1)
                topk_std = topk_vals.std(dim=1, unbiased=False)
                topk_cv = topk_std / (topk_mean + 1e-6)
                centered = topk_vals - topk_mean.unsqueeze(1)
                var = (centered ** 2).mean(dim=1) + 1e-6
                m4 = (centered ** 4).mean(dim=1)
                topk_kurt = m4 / (var ** 2)
                topk_med = topk_vals.median(dim=1).values
                row_mean_cos = cos_mid.mean(dim=1)
                row_med_cos = cos_mid.median(dim=1).values
                s1_over_mean = diag_cos - row_mean_cos
                s1_over_med = diag_cos - row_med_cos
                sorted_probs, _ = torch.sort(
                    probs_mid, dim=1, descending=True
                )
                p1 = sorted_probs[:, 0]
                p2 = (
                    sorted_probs[:, 1]
                    if sorted_probs.size(1) > 1
                    else sorted_probs[:, 0]
                )
                shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(
                    dim=1
                )
                shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
                # Log-linear decay slope over the top-R sorted probabilities
                # (least-squares slope against rank index).
                R = min(10, sorted_probs.size(1))
                x = torch.arange(
                    R, device=device, dtype=sorted_probs.dtype
                )
                x_centered = x - x.mean()
                denom = (x_centered ** 2).sum()
                y = torch.log(sorted_probs[:, :R] + 1e-6)
                slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
                row_mean_p = probs_mid.mean(dim=1)
                row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
                z1 = (p1_mid - row_mean_p) / row_std_p
                center_p = probs_mid - row_mean_p.unsqueeze(1)
                m3 = (center_p ** 3).mean(dim=1)
                skew = m3 / (row_std_p ** 3 + 1e-6)
                s1_over_sk = p1_mid - skew
                TAIL_K = min(10, sorted_probs.size(1))
                tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
                HEAD_K = min(5, sorted_probs.size(1))
                head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
                # Mask features are zero placeholders here, kept so the
                # 27-dim feature layout matches the classifier's training.
                mask_ratio = torch.zeros_like(diag_cos)
                mask_len = torch.zeros_like(diag_cos)
                mask_runs = torch.zeros_like(diag_cos)

                # 27 scalar features per sample, in training order.
                scalar_inputs = torch.stack(
                    [
                        diag_cos,
                        s2_cos,
                        margin_mid,
                        z_margin_mid,
                        mad_margin_mid,
                        p1_mid,
                        H_mid,
                        gini_mid,
                        topk_mean,
                        topk_std,
                        topk_cv,
                        topk_kurt,
                        topk_med,
                        s1_over_mean,
                        s1_over_med,
                        p1,
                        p2,
                        shape_H,
                        shape_gini,
                        slope,
                        z1,
                        s1_over_sk,
                        tail_mean,
                        head5_mean,
                        mask_ratio,
                        mask_len,
                        mask_runs,
                    ],
                    dim=1,
                )

                # Modality flag: 1 when the sample carries pixel inputs.
                modality_idx = torch.zeros(
                    B, dtype=torch.long, device=device
                )
                if "pixel_values" in inputs and inputs["pixel_values"] is not None:
                    pv = inputs["pixel_values"]
                    if isinstance(pv, list):
                        for i, item in enumerate(pv):
                            if item is not None:
                                modality_idx[i] = 1
                    elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                        modality_idx.fill_(1)

                logits = classifier(scalar_inputs, modality_idx)
                p_need_last = torch.sigmoid(logits)  # [B, 1]
                p_need_last_batch = p_need_last.squeeze(1)  # [B]
                # Exit when the predicted need for the last layer is below
                # the threshold.
                should_exit = p_need_last_batch < threshold
                exit_mask = should_exit.cpu().numpy()

            # Debug output for the first few batches only.
            if stats["total"] <= B * 3 and is_main:
                print_master(
                    f"[EE Debug] Batch {stats['total']//B}: "
                    f"p_need_last mean={p_need_last_batch.mean().item():.4f}, "
                    f"std={p_need_last_batch.std().item():.4f}, "
                    f"Exit Rate={exit_mask.mean():.2%}, "
                    f"Top3 Feats: diag_cos={diag_cos.mean():.3f}, "
                    f"margin={margin_mid.mean():.3f}, H={H_mid.mean():.3f}"
                )

        stats["exit"] += exit_mask.sum()

        # ---------------------------------------------------
        # 3. Branch execution
        # ---------------------------------------------------
        exit_indices = np.where(exit_mask)[0]
        cont_indices = np.where(~exit_mask)[0]

        # A. Early-exit samples: retrieve with mid-layer embeddings.
        if len(exit_indices) > 0:
            reps_exit = reps_mid_t[exit_indices]
            if use_local:
                for i, idx in enumerate(exit_indices):
                    cand_local = infos[idx].get("cand_names", [])
                    if not cand_local:
                        cids = []
                    else:
                        rows = [
                            cand_id2row.get(str(cid), -1) for cid in cand_local
                        ]
                        rows = [r for r in rows if r >= 0]
                        if len(rows) == 0:
                            cids = []
                        else:
                            # CPU FP32: single query mid x local cand_mid.
                            cmat_np = cand_mid_np[rows]  # [Nc_local, D]
                            qry_np = reps_exit[i].detach().float().cpu().numpy()  # [D]
                            scores_vec = np.dot(qry_np, cmat_np.T)  # [Nc_local]
                            top_k = min(200, len(rows))
                            order_local = np.argsort(-scores_vec)[:top_k]
                            cids = [str(cand_local[o]) for o in order_local]
                    label = (
                        infos[idx].get("label_name")
                        or infos[idx].get("label")
                        or infos[idx].get("rel_docids")
                    )
                    if not isinstance(label, list):
                        label = [label]
                    rel_scores = infos[idx].get("rel_scores", None)
                    global_idx = batch_start_idx + idx
                    results_dict[global_idx] = {
                        "prediction": cids,
                        "label": label,
                        "rel_scores": rel_scores,
                    }
            else:
                # CPU FP32 numpy: reps_exit_np x cand_mid_np.T.
                reps_exit_np = reps_exit.detach().float().cpu().numpy()  # [B_exit, D]
                scores_full = np.dot(reps_exit_np, cand_mid_np.T)  # [B_exit, Nc]
                top_k = min(200, len(cand_ids))
                # Full sort, then truncate to top_k.
                topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k]  # [B_exit, top_k]
                for i, idx in enumerate(exit_indices):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    label = (
                        infos[idx].get("label_name")
                        or infos[idx].get("label")
                        or infos[idx].get("rel_docids")
                    )
                    if not isinstance(label, list):
                        label = [label]
                    rel_scores = infos[idx].get("rel_scores", None)
                    global_idx = batch_start_idx + idx
                    results_dict[global_idx] = {
                        "prediction": cids,
                        "label": label,
                        "rel_scores": rel_scores,
                    }

        # B. Continued samples: resume from mid layer to the last layer.
        if len(cont_indices) > 0:
            # NOTE(review): assumes the patched encoder attaches an
            # ``intermediate_state`` dict when stop_at_layer is used — verify
            # against model_cut_layer_AOP_add_text_cut.
            interm = getattr(out_mid, "intermediate_state", None)
            hs = interm["hidden_states"].detach()
            am = interm["attention_mask"].detach()
            pos = interm["position_ids"].detach()
            vm = interm.get("vision_mask", None)
            tm = interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])
            # Slice the saved state down to the samples that continue.
            resume_state_subset = {
                "hidden_states": hs[cont_indices],
                "attention_mask": am[cont_indices],
                "position_ids": pos[:, cont_indices, :],
                "vision_mask": vm[cont_indices] if vm is not None else None,
                "text_mask": tm[cont_indices] if tm is not None else None,
                "next_layer_idx": next_layer,
            }

            # Re-enable AOP for the continuation leg if it was on originally.
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_resume["enabled"] = bool(_orig_enabled and side_enable)
                setattr(model.encoder, "aop_prune_config", aop_resume)

            # Timing: mid -> last.
            if profile_enabled:
                torch.cuda.synchronize()
                t0_tail = time.perf_counter()

            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16, enabled=True
            ):
                out_last = model.encoder(
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    resume_state=resume_state_subset,
                    compute_lm_head=False,
                )

            if profile_enabled:
                torch.cuda.synchronize()
                t1_tail = time.perf_counter()
                timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(
                    cont_indices
                )
                timing_stats["tail_num"] += len(cont_indices)

            hs_last = out_last.last_hidden_state
            if hs_last is None:
                hs_last = out_last.hidden_states[-1]
            am_last = getattr(out_last, "attention_mask", None)
            if am_last is None:
                am_last = resume_state_subset["attention_mask"]
            reps_last_t = (
                model._pooling(hs_last, am_last)
                .detach()
                .to(dtype=torch.bfloat16)
            )

            if use_local:
                for i, idx_global in enumerate(cont_indices):
                    cand_local = infos[idx_global].get("cand_names", [])
                    if not cand_local:
                        cids = []
                    else:
                        rows = [
                            cand_id2row.get(str(cid), -1) for cid in cand_local
                        ]
                        rows = [r for r in rows if r >= 0]
                        if len(rows) == 0:
                            cids = []
                        else:
                            # CPU FP32: single query last x local cand_last.
                            cmat_last_np = cand_last_np[rows]  # [Nc_local, D]
                            qry_last_np = reps_last_t[i].detach().float().cpu().numpy()
                            scores_vec = np.dot(qry_last_np, cmat_last_np.T)  # [Nc_local]
                            top_k = min(200, len(rows))
                            order_local = np.argsort(-scores_vec)[:top_k]
                            cids = [str(cand_local[o]) for o in order_local]
                    label = (
                        infos[idx_global].get("label_name")
                        or infos[idx_global].get("label")
                        or infos[idx_global].get("rel_docids")
                    )
                    if not isinstance(label, list):
                        label = [label]
                    rel_scores = infos[idx_global].get("rel_scores", None)
                    global_idx = batch_start_idx + idx_global
                    results_dict[global_idx] = {
                        "prediction": cids,
                        "label": label,
                        "rel_scores": rel_scores,
                    }
            else:
                # CPU FP32: all continued queries (last) x full cand_last.
                reps_last_np = reps_last_t.detach().float().cpu().numpy()  # [B_cont, D]
                scores_last = np.dot(reps_last_np, cand_last_np.T)  # [B_cont, Nc]
                top_k = min(200, len(cand_ids))
                topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k]  # [B_cont, top_k]
                for i, idx_global in enumerate(cont_indices):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    label = (
                        infos[idx_global].get("label_name")
                        or infos[idx_global].get("label")
                        or infos[idx_global].get("rel_docids")
                    )
                    if not isinstance(label, list):
                        label = [label]
                    rel_scores = infos[idx_global].get("rel_scores", None)
                    global_idx = batch_start_idx + idx_global
                    results_dict[global_idx] = {
                        "prediction": cids,
                        "label": label,
                        "rel_scores": rel_scores,
                    }

        # ---------------------------------------------------
        # 4. Profiling: save per-query topK embeddings & similarities
        # ---------------------------------------------------
        if profile_enabled and is_main:
            K = min(topk_emb, cand_mid_t.size(0))
            # Convert to float32 + CPU for serialization.
            q_mid_cpu = reps_mid_t.detach().float().cpu()  # [B, D]
            q_last_cpu = (
                reps_last_full.detach().float().cpu()
                if reps_last_full is not None
                else None
            )  # [B, D]
            cand_mid_cpu = cand_mid_t.detach().float().cpu()  # [Nc, D]
            cand_last_cpu = cand_last_t.detach().float().cpu()  # [Nc, D]

            # mid-vs-mid similarity (aligned representation spaces).
            scores_mid_full = q_mid_cpu @ cand_mid_cpu.T  # [B, Nc]
            topk_mid_vals, topk_mid_inds = torch.topk(
                scores_mid_full, k=K, dim=1
            )

            # last-vs-last similarity (when a full pass was run).
            if q_last_cpu is not None:
                scores_last_full = q_last_cpu @ cand_last_cpu.T  # [B, Nc]
                topk_last_vals, topk_last_inds = torch.topk(
                    scores_last_full, k=K, dim=1
                )
            else:
                topk_last_vals = None
                topk_last_inds = None

            for i in range(B):
                qid = batch_start_idx + i
                rec = {
                    "qid": int(qid),
                    "early_exit": bool(exit_mask[i]),
                }
                if p_need_last_batch is not None:
                    rec["p_need_last"] = float(p_need_last_batch[i].item())

                # mid-vs-mid TopK.
                mid_inds = topk_mid_inds[i].tolist()
                mid_scores = topk_mid_vals[i].tolist()
                rec["mid_topk_scores"] = mid_scores
                rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds]
                rec["mid_q_emb"] = q_mid_cpu[i].tolist()
                rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist()

                # last-vs-last TopK (when available).
                if topk_last_inds is not None:
                    last_inds = topk_last_inds[i].tolist()
                    last_scores = topk_last_vals[i].tolist()
                    rec["last_topk_scores"] = last_scores
                    rec["last_topk_cand_ids"] = [
                        cand_ids[j] for j in last_inds
                    ]
                    rec["last_q_emb"] = (
                        q_last_cpu[i].tolist() if q_last_cpu is not None else None
                    )
                    rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist()
                else:
                    rec["last_topk_scores"] = None
                    rec["last_topk_cand_ids"] = None
                    rec["last_q_emb"] = None
                    rec["last_cand_embs"] = None

                analysis_records.append(rec)

    # =====================================================
    # 5. Collect & save results
    # =====================================================
    for idx in sorted(results_dict.keys()):
        pred_dicts.append(results_dict[idx])

    print_master(
        f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} "
        f"({stats['exit']/stats['total']:.2%})"
    )

    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)

    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        with open(
            os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"),
            "w",
        ) as f:
            json.dump(score, f, indent=4)

        # Save profiling information.
        if profile_enabled:
            prof_dir = os.path.join(out_dir, "profiling")
            os.makedirs(prof_dir, exist_ok=True)

            # Convert timing sums to per-sample averages (seconds/sample).
            mid_avg = (
                timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"])
            )
            tail_avg = (
                timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"])
            )
            timing_out = {
                "mid_time_sum": timing_stats["mid_time_sum"],
                "mid_num": timing_stats["mid_num"],
                "tail_time_sum": timing_stats["tail_time_sum"],
                "tail_num": timing_stats["tail_num"],
                "avg_mid_time_per_query_sec": mid_avg,
                "avg_tail_time_per_cont_query_sec": tail_avg,
                "num_exit": int(stats["exit"]),
                "num_total": int(stats["total"]),
            }
            with open(
                os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w"
            ) as f:
                json.dump(timing_out, f, indent=2)

            # Embedding + similarity records (JSONL).
            embed_path = os.path.join(
                prof_dir, f"{dataset_name}_embeds.jsonl"
            )
            with open(embed_path, "w") as f:
                for rec in analysis_records:
                    f.write(json.dumps(rec) + "\n")

            print_master(
                f"[PROFILE] Saved timing to {prof_dir}, "
                f"embeddings to {embed_path}"
            )

    elapsed = time.time() - start_time
    return score, elapsed


# ===========================
# Helper Functions (Pre-Computation)
# ===========================
def encode_candidates_both_layers(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    mid_layer: int,
):
    """Encode all candidates once, capturing BOTH mid-layer and last-layer embeddings.

    AOP pruning is force-disabled on the candidate side for the duration of
    each forward pass (and restored afterwards) so candidate embeddings are
    computed on the unpruned encoder.

    Returns:
        (mid_embs, last_embs, ids): two float32 numpy arrays of shape [N, D]
        and the parallel list of candidate names; empty arrays/list when the
        loader yields nothing.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid, all_last, all_ids = [], [], []
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Force-disable AOP on the candidate side.
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            _orig = None
            if isinstance(aop_cfg, dict):
                _orig = aop_cfg.get("enabled", False)
                aop_cfg["enabled"] = False
                setattr(model.encoder, "aop_prune_config", aop_cfg)
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None)
            if isinstance(aop_cfg, dict) and _orig is not None:
                aop_cfg["enabled"] = _orig  # Restore
            mid_hs = out.hidden_states[mid_layer]
            last_hs = out.hidden_states[-1]
            am = inputs.get("attention_mask", None)
            if am is not None and am.device != mid_hs.device:
                am = am.to(mid_hs.device)
            reps_mid = model._pooling(mid_hs, am)
            reps_last = model._pooling(last_hs, am)
            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])
    if not all_mid:
        return np.array([]), np.array([]), []
    return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids


# ===========================
# Main
# ===========================
def main():
    """Entry point: load model + early-exit classifier, then evaluate every dataset in the config.

    For each dataset: pick its per-dataset threshold, (re)build candidate
    embeddings for the mid and last layers if the caches are missing, then run
    the early-exit query pipeline (rank 0 only).
    """
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    ee_cfg = get_env_ee_config()

    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)

    # Load the model on rank 0 first (warms shared caches), then other ranks.
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    if local_rank != 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)

    # Debug: inspect model configuration.
    print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}")

    # Inject the AOP config into the encoder.
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        model.set_inference_layers(qry_layers=None, tgt_layers=None)
        print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}")

    # Load the early-exit classifier.
    classifier = None
    if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
        classifier_path = ee_cfg['classifier_path']
        print_master(f"[EE] Loading Classifier from {classifier_path}...")
        # 27-dim features: must match the training code exactly.
        classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64)
        state_dict = None
        # Try several weight layouts in order.
        if os.path.isdir(classifier_path):
            # Option 1: safetensors format (Trainer default).
            safetensors_file = os.path.join(classifier_path, "model.safetensors")
            if os.path.exists(safetensors_file):
                from safetensors.torch import load_file
                state_dict = load_file(safetensors_file)
                print_master(f"[EE] Loaded from model.safetensors")
            else:
                # Option 2: pytorch_model.bin.
                pt_file = os.path.join(classifier_path, "pytorch_model.bin")
                if os.path.exists(pt_file):
                    state_dict = torch.load(pt_file, map_location=training_args.device)
                    print_master(f"[EE] Loaded from pytorch_model.bin")
                else:
                    # Option 3: standalone .pt file named by layer.
                    layer_idx = ee_cfg.get('layer', 12)
                    pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
                    if os.path.exists(pt_file):
                        state_dict = torch.load(pt_file, map_location=training_args.device)
                        print_master(f"[EE] Loaded from {os.path.basename(pt_file)}")
                    else:
                        raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}")
        elif os.path.isfile(classifier_path):
            # Option 4: an explicit weight file path.
            if classifier_path.endswith('.safetensors'):
                from safetensors.torch import load_file
                state_dict = load_file(classifier_path)
                print_master(f"[EE] Loaded from .safetensors file")
            else:
                state_dict = torch.load(classifier_path, map_location=training_args.device)
                print_master(f"[EE] Loaded from .pt file")
        else:
            raise FileNotFoundError(f"Classifier path not found: {classifier_path}")
        classifier.load_state_dict(state_dict)
        classifier.to(training_args.device)
        classifier.eval()
        print_master(f"[EE] Classifier loaded successfully. Threshold={ee_cfg['threshold']}")

    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)

    for dataset_name, task_config in dataset_configs.items():
        if dist.is_initialized(): dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")

        # ===== Select the dataset-specific threshold tau_ds =====
        base_tau = float(ee_cfg["threshold"])  # EE_THRESHOLD env var is the default
        ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau)
        # Copy ee_cfg for this dataset and override the threshold.
        ee_cfg_ds = dict(ee_cfg)
        ee_cfg_ds["threshold"] = float(ds_tau)
        print_master(
            f"[EE] Dataset '{dataset_name}' use threshold τ={ee_cfg_ds['threshold']:.6e} "
            f"(default={base_tau:.6e})"
        )

        # Resolve dataset paths relative to the data base directory.
        if data_args.data_basedir:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_config.get(key):
                    task_config[key] = os.path.join(data_args.data_basedir, task_config[key])

        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(
            model_args=model_args, data_args=data_args, **task_config
        )
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)

        mid_layer = int(ee_cfg_ds["layer"])
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast")

        # Set True to force rebuilding candidate caches (e.g. after a change
        # to the model's normalize setting).
        force_regenerate = False

        # Pre-compute candidate embeddings when caches are missing.
        if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
            print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...")
            eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
            eval_cand_loader = DataLoader(
                full_eval_cand_dataset,
                batch_size=training_args.per_device_eval_batch_size,
                collate_fn=eval_cand_collator,
                num_workers=training_args.dataloader_num_workers,
            )
            cand_mid, cand_last, cand_ids = encode_candidates_both_layers(
                model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer
            )
            if local_rank == 0:
                with open(cand_mid_path, "wb") as f:
                    pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
                with open(cand_last_path, "wb") as f:
                    pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
        if dist.is_initialized(): dist.barrier()

        # Query evaluation runs on rank 0 only.
        if local_rank == 0:
            with open(cand_mid_path, "rb") as f:
                cand_mid_dict = pickle.load(f)
            with open(cand_last_path, "rb") as f:
                cand_last_dict = pickle.load(f)

            # eval_type from task_config decides global vs local ranking.
            rank_global = task_config.get("eval_type", "global") == "global"
            print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking")

            # Note: pass the per-dataset ee_cfg_ds, not the global ee_cfg.
            run_early_exit_queries(
                model, classifier, processor, model_args, data_args, training_args,
                full_eval_qry_dataset, cand_mid_dict, cand_last_dict, ee_cfg_ds,
                dataset_name, data_args.encode_output_path,
                global_ranking=rank_global,
            )
        if dist.is_initialized(): dist.barrier()
    if dist.is_initialized(): dist.barrier()


if __name__ == '__main__':
    main()