# code_SAS_VLM2Vec / eval_test_time_with_classifier_V2.py
# MgGladys's picture
# Add files using upload-large-folder tool
# 2a40e7a verified
#######################################################################################################
# 原始版本 (original version — kept fully commented out below for reference)
# import datetime
# import logging
# import json
# import random
# import time
# import numpy as np
# import os
# import pickle
# import sys
# import torch
# import torch.distributed as dist
# import torch.nn.functional as F
# import yaml
# import transformers
# import math
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
# from datasets import Dataset, concatenate_datasets
# from datasets.distributed import split_dataset_by_node
# from src.arguments import ModelArguments, DataArguments, TrainingArguments
# from src.data.collator.eval_collator import MultimodalEvalDataCollator
# from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
# from src.eval_utils.metrics import RankingMetrics
# from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel
# from src.model.processor import get_backbone_name, load_processor, COLPALI
# from src.utils import batch_to_device, print_rank, print_master
# # 引入分类器
# from src.classifier_utils_V2 import EarlyExitClassifier
# logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
# logger = logging.getLogger(__name__)
# # ===========================
# # Helper Functions (AOP Config)
# # ===========================
# def _parse_bool(v: str, default=False):
# if v is None: return default
# v = v.strip().lower()
# return v in {"1","true","yes","y","t","on"}
# def _parse_float(v: str, default=None):
# try: return float(v) if v is not None else default
# except: return default
# def _parse_int(v: str, default=None):
# try: return int(v) if v is not None else default
# except: return default
# def get_env_aop_config():
# enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
# apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()
# layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
# mode = os.environ.get("AOP_MODE", "delta").strip().lower()
# # Parameters
# delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
# khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
# keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
# min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
# use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
# # Specific
# prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
# prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
# keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
# keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
# selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
# attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()
# if layer_idx is None and enabled:
# enabled = False
# return {
# "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode,
# "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep,
# "use_bias": use_bias, "prune_vision": prune_vision, "prune_text": prune_text,
# "keep_ratio_vision": keep_ratio_v, "keep_ratio_text": keep_ratio_t,
# "selection": selection, "attn_agg": attn_agg,
# "margin_mid": None # Simplified
# }
# def get_env_ee_config():
# ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"}
# layer = int(os.environ.get("EE_LAYER", "12"))
# method = os.environ.get("EE_METHOD", "classifier").strip().lower()
# threshold = float(os.environ.get("EE_THRESHOLD", "0.8"))
# classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "")
# return dict(enabled=ee_enabled, layer=layer, method=method, threshold=threshold, classifier_path=classifier_path)
# def pad_dataset_to_divisible(dataset, world_size):
# num_samples = len(dataset)
# if num_samples % world_size == 0: return dataset, num_samples
# num_to_add = world_size - (num_samples % world_size)
# padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
# return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add
# # ===========================
# # Core Inference Function
# # ===========================
# def run_early_exit_queries(
# model: MMEBModel,
# classifier: EarlyExitClassifier,
# processor,
# model_args: ModelArguments,
# data_args: DataArguments,
# training_args: TrainingArguments,
# qry_dataset: Dataset,
# cand_mid_dict: dict,
# cand_last_dict: dict,
# ee_cfg: dict,
# dataset_name: str,
# out_dir: str,
# global_ranking: bool = True,
# ):
# device = training_args.device
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# is_main = (not dist.is_initialized()) or (local_rank == 0)
# # ==========================
# # Profiling 配置
# # ==========================
# profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {
# "1", "true", "yes", "on", "y", "t"
# }
# topk_emb = int(os.environ.get("EE_TOPK_EMB", "5"))
# # 运行时间统计(按样本数加权)
# timing_stats = {
# "mid_time_sum": 0.0, # encoder 到 mid 的总时间 * 样本数
# "mid_num": 0, # 样本数
# "tail_time_sum": 0.0, # mid->last 续跑部分的总时间 * 样本数
# "tail_num": 0, # 续跑样本数
# }
# # embedding + 相似度记录(仅 rank0 保存)
# analysis_records = [] if (profile_enabled and is_main) else None
# # 1. 准备 Candidates
# cand_ids = list(cand_mid_dict.keys())
# cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
# cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
# cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
# cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
# collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
# loader = DataLoader(
# qry_dataset,
# batch_size=training_args.per_device_eval_batch_size,
# collate_fn=collator,
# num_workers=training_args.dataloader_num_workers
# )
# pred_dicts = []
# stats = {"exit": 0, "total": 0}
# threshold = float(ee_cfg["threshold"])
# method = ee_cfg["method"]
# target_layer_idx = int(ee_cfg["layer"])
# # 结果顺序
# results_dict = {}
# global_sample_idx = 0
# # Local vs Global ranking
# use_local = (not global_ranking)
# if use_local:
# print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)")
# cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
# else:
# print_master(f"[INFO] Using GLOBAL ranking (full library)")
# # --- AOP 配置初始化 ---
# aop_cfg = getattr(model.encoder, "aop_prune_config", None)
# _orig_enabled = None
# side_enable = True
# if isinstance(aop_cfg, dict) and aop_cfg:
# _orig_enabled = aop_cfg.get("enabled", False)
# apply_to = aop_cfg.get("apply_to", "qry")
# side_enable = (apply_to == "both") or (apply_to == "qry")
# model.eval()
# if classifier:
# classifier.eval()
# classifier.to(device)
# start_time = time.time()
# for inputs, infos in tqdm(
# loader,
# desc=f"[EE+AOP] {dataset_name} (tau={threshold})",
# disable=local_rank > 0,
# ):
# inputs = batch_to_device(inputs, device)
# B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
# batch_start_idx = global_sample_idx
# global_sample_idx += B
# stats["total"] += B
# # ---------------------------------------------------
# # 1. 前半程: Run to Mid Layer (含 AOP 动态控制)
# # ---------------------------------------------------
# orig_cfg = None
# if isinstance(aop_cfg, dict) and aop_cfg:
# orig_cfg = dict(aop_cfg)
# aop_layer = aop_cfg.get("layer_idx", None)
# aop_on_mid = bool(
# _orig_enabled
# and side_enable
# and (aop_layer is not None)
# and (aop_layer < target_layer_idx)
# )
# aop_cfg_mid = dict(aop_cfg)
# aop_cfg_mid["enabled"] = aop_on_mid
# setattr(model.encoder, "aop_prune_config", aop_cfg_mid)
# # 计时:encoder 到 mid 层
# if profile_enabled:
# torch.cuda.synchronize()
# t0_mid = time.perf_counter()
# with torch.no_grad(), torch.autocast(
# device_type="cuda", dtype=torch.bfloat16, enabled=True
# ):
# out_mid = model.encoder(
# **inputs,
# return_dict=True,
# output_hidden_states=False,
# stop_at_layer=target_layer_idx,
# compute_lm_head=False,
# )
# if profile_enabled:
# torch.cuda.synchronize()
# t1_mid = time.perf_counter()
# timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B
# timing_stats["mid_num"] += B
# # 恢复 AOP 配置
# if isinstance(orig_cfg, dict):
# setattr(model.encoder, "aop_prune_config", orig_cfg)
# # Hidden State & Mask
# hs_mid = getattr(out_mid, "last_hidden_state", None)
# if hs_mid is None:
# hs_mid = out_mid.hidden_states[-1]
# am_mid = getattr(out_mid, "attention_mask", None)
# if am_mid is None:
# am_mid = inputs.get("attention_mask", None)
# reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16)
# # 如果要记录最后一层 embedding,需要在 profiling 模式下额外跑一次全 forward
# reps_last_full = None
# if profile_enabled:
# with torch.no_grad(), torch.autocast(
# device_type="cuda", dtype=torch.bfloat16, enabled=True
# ):
# out_full = model.encoder(
# **inputs,
# return_dict=True,
# output_hidden_states=False,
# stop_at_layer=None,
# compute_lm_head=False,
# )
# hs_last_full = getattr(out_full, "last_hidden_state", None)
# if hs_last_full is None:
# hs_last_full = out_full.hidden_states[-1]
# am_last_full = getattr(out_full, "attention_mask", None)
# if am_last_full is None:
# am_last_full = inputs.get("attention_mask", None)
# reps_last_full = (
# model._pooling(hs_last_full, am_last_full)
# .detach()
# .to(dtype=torch.bfloat16)
# )
# # ---------------------------------------------------
# # 2. 特征工程 + gating
# # ---------------------------------------------------
# exit_mask = np.zeros(B, dtype=bool)
# p_need_last_batch = None # 仅 profiling 或 debug 时保存
# if method == "classifier" and classifier is not None:
# with torch.no_grad():
# # 【关键修复】特征提取:qry_mid × cand_mid(表征空间对齐)
# cos_mid = reps_mid_t @ cand_mid_t.T # [B, N]
# backbone_ptr = (
# model.module if hasattr(model, "module") else model
# )
# temp = getattr(backbone_ptr, "temperature", 0.02)
# scores_mid = cos_mid / temp
# probs_mid = torch.softmax(scores_mid, dim=1) # [B, N]
# diag_cos = cos_mid.max(dim=1)[0]
# sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
# s2_cos = (
# sorted_cos[:, 1]
# if sorted_cos.size(1) > 1
# else sorted_cos[:, 0]
# )
# margin_mid = diag_cos - s2_cos
# margin_mean = margin_mid.mean()
# margin_std = margin_mid.std(unbiased=False) + 1e-6
# z_margin_mid = (margin_mid - margin_mean) / margin_std
# margin_median = margin_mid.median()
# mad = (margin_mid - margin_median).abs().median() + 1e-6
# mad_margin_mid = (margin_mid - margin_median) / mad
# p1_mid = probs_mid.max(dim=1)[0]
# H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
# gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
# TOPK = min(16, probs_mid.size(1))
# topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
# topk_mean = topk_vals.mean(dim=1)
# topk_std = topk_vals.std(dim=1, unbiased=False)
# topk_cv = topk_std / (topk_mean + 1e-6)
# centered = topk_vals - topk_mean.unsqueeze(1)
# var = (centered ** 2).mean(dim=1) + 1e-6
# m4 = (centered ** 4).mean(dim=1)
# topk_kurt = m4 / (var ** 2)
# topk_med = topk_vals.median(dim=1).values
# row_mean_cos = cos_mid.mean(dim=1)
# row_med_cos = cos_mid.median(dim=1).values
# s1_over_mean = diag_cos - row_mean_cos
# s1_over_med = diag_cos - row_med_cos
# sorted_probs, _ = torch.sort(
# probs_mid, dim=1, descending=True
# )
# p1 = sorted_probs[:, 0]
# p2 = (
# sorted_probs[:, 1]
# if sorted_probs.size(1) > 1
# else sorted_probs[:, 0]
# )
# shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(
# dim=1
# )
# shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
# R = min(10, sorted_probs.size(1))
# x = torch.arange(
# R, device=device, dtype=sorted_probs.dtype
# )
# x_centered = x - x.mean()
# denom = (x_centered ** 2).sum()
# y = torch.log(sorted_probs[:, :R] + 1e-6)
# slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
# row_mean_p = probs_mid.mean(dim=1)
# row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
# z1 = (p1_mid - row_mean_p) / row_std_p
# center_p = probs_mid - row_mean_p.unsqueeze(1)
# m3 = (center_p ** 3).mean(dim=1)
# skew = m3 / (row_std_p ** 3 + 1e-6)
# s1_over_sk = p1_mid - skew
# TAIL_K = min(10, sorted_probs.size(1))
# tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
# HEAD_K = min(5, sorted_probs.size(1))
# head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
# mask_ratio = torch.zeros_like(diag_cos)
# mask_len = torch.zeros_like(diag_cos)
# mask_runs = torch.zeros_like(diag_cos)
# scalar_inputs = torch.stack(
# [
# diag_cos,
# s2_cos,
# margin_mid,
# z_margin_mid,
# mad_margin_mid,
# p1_mid,
# H_mid,
# gini_mid,
# topk_mean,
# topk_std,
# topk_cv,
# topk_kurt,
# topk_med,
# s1_over_mean,
# s1_over_med,
# p1,
# p2,
# shape_H,
# shape_gini,
# slope,
# z1,
# s1_over_sk,
# tail_mean,
# head5_mean,
# mask_ratio,
# mask_len,
# mask_runs,
# ],
# dim=1,
# )
# modality_idx = torch.zeros(
# B, dtype=torch.long, device=device
# )
# if "pixel_values" in inputs and inputs["pixel_values"] is not None:
# pv = inputs["pixel_values"]
# if isinstance(pv, list):
# for i, item in enumerate(pv):
# if item is not None:
# modality_idx[i] = 1
# elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
# modality_idx.fill_(1)
# logits = classifier(scalar_inputs, modality_idx)
# p_need_last = torch.sigmoid(logits) # [B,1]
# p_need_last_batch = p_need_last.squeeze(1) # [B]
# should_exit = p_need_last_batch < threshold
# exit_mask = should_exit.cpu().numpy()
# if stats["total"] <= B * 3 and is_main:
# print_master(
# f"[EE Debug] Batch {stats['total']//B}: "
# f"p_need_last mean={p_need_last_batch.mean().item():.4f}, "
# f"std={p_need_last_batch.std().item():.4f}, "
# f"Exit Rate={exit_mask.mean():.2%}, "
# f"Top3 Feats: diag_cos={diag_cos.mean():.3f}, "
# f"margin={margin_mid.mean():.3f}, H={H_mid.mean():.3f}"
# )
# stats["exit"] += exit_mask.sum()
# # ---------------------------------------------------
# # 3. 分支执行
# # ---------------------------------------------------
# exit_indices = np.where(exit_mask)[0]
# cont_indices = np.where(~exit_mask)[0]
# # A. 早停样本:用 mid→last 检索
# if len(exit_indices) > 0:
# reps_exit = reps_mid_t[exit_indices]
# if use_local:
# for i, idx in enumerate(exit_indices):
# cand_local = infos[idx].get("cand_names", [])
# if not cand_local:
# cids = []
# else:
# rows = [
# cand_id2row.get(str(cid), -1)
# for cid in cand_local
# ]
# rows = [r for r in rows if r >= 0]
# if len(rows) == 0:
# cids = []
# else:
# # 【关键修复】早停检索:qry_mid × cand_mid(表征空间对齐)
# cmat_t = cand_mid_t[rows]
# scores_vec = (reps_exit[i : i + 1] @ cmat_t.T)[0]
# order_local = (
# torch.argsort(
# scores_vec,
# dim=0,
# descending=True,
# )
# .cpu()
# .numpy()
# )
# cids = [str(cand_local[o]) for o in order_local]
# label = (
# infos[idx].get("label_name")
# or infos[idx].get("label")
# or infos[idx].get("rel_docids")
# )
# if not isinstance(label, list):
# label = [label]
# rel_scores = infos[idx].get("rel_scores", None)
# global_idx = batch_start_idx + idx
# results_dict[global_idx] = {
# "prediction": cids,
# "label": label,
# "rel_scores": rel_scores,
# }
# else:
# # 【关键修复】早停检索:qry_mid × cand_mid(表征空间对齐)
# scores_full = reps_exit @ cand_mid_t.T
# _, topk_inds = torch.topk(
# scores_full, k=min(200, len(cand_ids)), dim=1
# )
# topk_inds = topk_inds.cpu().numpy()
# for i, idx in enumerate(exit_indices):
# cids = [cand_ids[k] for k in topk_inds[i]]
# label = (
# infos[idx].get("label_name")
# or infos[idx].get("label")
# or infos[idx].get("rel_docids")
# )
# if not isinstance(label, list):
# label = [label]
# rel_scores = infos[idx].get("rel_scores", None)
# global_idx = batch_start_idx + idx
# results_dict[global_idx] = {
# "prediction": cids,
# "label": label,
# "rel_scores": rel_scores,
# }
# # B. 续跑样本:从 mid 继续到 last
# if len(cont_indices) > 0:
# interm = getattr(out_mid, "intermediate_state", None)
# hs = interm["hidden_states"].detach()
# am = interm["attention_mask"].detach()
# pos = interm["position_ids"].detach()
# vm = interm.get("vision_mask", None)
# tm = interm.get("text_mask", None)
# next_layer = int(interm["next_layer_idx"])
# resume_state_subset = {
# "hidden_states": hs[cont_indices],
# "attention_mask": am[cont_indices],
# "position_ids": pos[:, cont_indices, :],
# "vision_mask": vm[cont_indices] if vm is not None else None,
# "text_mask": tm[cont_indices] if tm is not None else None,
# "next_layer_idx": next_layer,
# }
# if isinstance(aop_cfg, dict) and aop_cfg:
# aop_resume = dict(aop_cfg)
# aop_resume["enabled"] = bool(_orig_enabled and side_enable)
# setattr(model.encoder, "aop_prune_config", aop_resume)
# # 计时:mid -> last
# if profile_enabled:
# torch.cuda.synchronize()
# t0_tail = time.perf_counter()
# with torch.no_grad(), torch.autocast(
# device_type="cuda", dtype=torch.bfloat16, enabled=True
# ):
# out_last = model.encoder(
# return_dict=True,
# output_hidden_states=False,
# stop_at_layer=None,
# resume_state=resume_state_subset,
# compute_lm_head=False,
# )
# if profile_enabled:
# torch.cuda.synchronize()
# t1_tail = time.perf_counter()
# timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(
# cont_indices
# )
# timing_stats["tail_num"] += len(cont_indices)
# hs_last = out_last.last_hidden_state
# if hs_last is None:
# hs_last = out_last.hidden_states[-1]
# am_last = getattr(out_last, "attention_mask", None)
# if am_last is None:
# am_last = resume_state_subset["attention_mask"]
# reps_last_t = (
# model._pooling(hs_last, am_last)
# .detach()
# .to(dtype=torch.bfloat16)
# )
# if use_local:
# for i, idx_global in enumerate(cont_indices):
# cand_local = infos[idx_global].get("cand_names", [])
# if not cand_local:
# cids = []
# else:
# rows = [
# cand_id2row.get(str(cid), -1)
# for cid in cand_local
# ]
# rows = [r for r in rows if r >= 0]
# if len(rows) == 0:
# cids = []
# else:
# cmat_last_t = cand_last_t[rows]
# scores_vec_t = (
# reps_last_t[i : i + 1] @ cmat_last_t.T
# )[0]
# order_local = (
# torch.argsort(
# scores_vec_t,
# dim=0,
# descending=True,
# )
# .cpu()
# .numpy()
# )
# cids = [str(cand_local[o]) for o in order_local]
# label = (
# infos[idx_global].get("label_name")
# or infos[idx_global].get("label")
# or infos[idx_global].get("rel_docids")
# )
# if not isinstance(label, list):
# label = [label]
# rel_scores = infos[idx_global].get("rel_scores", None)
# global_idx = batch_start_idx + idx_global
# results_dict[global_idx] = {
# "prediction": cids,
# "label": label,
# "rel_scores": rel_scores,
# }
# else:
# scores_last = reps_last_t @ cand_last_t.T
# _, topk_inds = torch.topk(
# scores_last, k=min(200, len(cand_ids)), dim=1
# )
# topk_inds = topk_inds.cpu().numpy()
# for i, idx_global in enumerate(cont_indices):
# cids = [cand_ids[k] for k in topk_inds[i]]
# label = (
# infos[idx_global].get("label_name")
# or infos[idx_global].get("label")
# or infos[idx_global].get("rel_docids")
# )
# if not isinstance(label, list):
# label = [label]
# rel_scores = infos[idx_global].get("rel_scores", None)
# global_idx = batch_start_idx + idx_global
# results_dict[global_idx] = {
# "prediction": cids,
# "label": label,
# "rel_scores": rel_scores,
# }
# # ---------------------------------------------------
# # 4. Profiling: 保存 per-query 的 topK embedding & 相似度
# # ---------------------------------------------------
# if profile_enabled and is_main:
# K = min(topk_emb, cand_mid_t.size(0))
# # 转到 float32 + CPU 便于写盘
# q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D]
# q_last_cpu = (
# reps_last_full.detach().float().cpu()
# if reps_last_full is not None
# else None
# ) # [B, D]
# cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D]
# cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D]
# # 【关键修复】mid2mid 相似度(表征空间对齐)
# scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc]
# topk_mid_vals, topk_mid_inds = torch.topk(
# scores_mid_full, k=K, dim=1
# )
# # last2last 相似度(如果有)
# if q_last_cpu is not None:
# scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc]
# topk_last_vals, topk_last_inds = torch.topk(
# scores_last_full, k=K, dim=1
# )
# else:
# topk_last_vals = None
# topk_last_inds = None
# for i in range(B):
# qid = batch_start_idx + i
# rec = {
# "qid": int(qid),
# "early_exit": bool(exit_mask[i]),
# }
# if p_need_last_batch is not None:
# rec["p_need_last"] = float(p_need_last_batch[i].item())
# # 【关键修复】mid2mid TopK(表征空间对齐)
# mid_inds = topk_mid_inds[i].tolist()
# mid_scores = topk_mid_vals[i].tolist()
# rec["mid_topk_scores"] = mid_scores
# rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds]
# rec["mid_q_emb"] = q_mid_cpu[i].tolist()
# rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist()
# # last2last TopK(如果有)
# if topk_last_inds is not None:
# last_inds = topk_last_inds[i].tolist()
# last_scores = topk_last_vals[i].tolist()
# rec["last_topk_scores"] = last_scores
# rec["last_topk_cand_ids"] = [
# cand_ids[j] for j in last_inds
# ]
# rec["last_q_emb"] = (
# q_last_cpu[i].tolist() if q_last_cpu is not None else None
# )
# rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist()
# else:
# rec["last_topk_scores"] = None
# rec["last_topk_cand_ids"] = None
# rec["last_q_emb"] = None
# rec["last_cand_embs"] = None
# analysis_records.append(rec)
# # =====================================================
# # 5. 收集 & 保存结果
# # =====================================================
# for idx in sorted(results_dict.keys()):
# pred_dicts.append(results_dict[idx])
# print_master(
# f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} "
# f"({stats['exit']/stats['total']:.2%})"
# )
# metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
# score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
# if is_main:
# os.makedirs(out_dir, exist_ok=True)
# with open(
# os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"),
# "w",
# ) as f:
# json.dump(score, f, indent=4)
# # 保存 profiling 信息
# if profile_enabled:
# prof_dir = os.path.join(out_dir, "profiling")
# os.makedirs(prof_dir, exist_ok=True)
# # 时间统计:转换成均值(秒/样本)
# mid_avg = (
# timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"])
# )
# tail_avg = (
# timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"])
# )
# timing_out = {
# "mid_time_sum": timing_stats["mid_time_sum"],
# "mid_num": timing_stats["mid_num"],
# "tail_time_sum": timing_stats["tail_time_sum"],
# "tail_num": timing_stats["tail_num"],
# "avg_mid_time_per_query_sec": mid_avg,
# "avg_tail_time_per_cont_query_sec": tail_avg,
# "num_exit": int(stats["exit"]),
# "num_total": int(stats["total"]),
# }
# with open(
# os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w"
# ) as f:
# json.dump(timing_out, f, indent=2)
# # embedding + 相似度记录(JSONL)
# embed_path = os.path.join(
# prof_dir, f"{dataset_name}_embeds.jsonl"
# )
# with open(embed_path, "w") as f:
# for rec in analysis_records:
# f.write(json.dumps(rec) + "\n")
# print_master(
# f"[PROFILE] Saved timing to {prof_dir}, "
# f"embeddings to {embed_path}"
# )
# elapsed = time.time() - start_time
# return score, elapsed
# # ===========================
# # Helper Functions (Pre-Computation)
# # ===========================
# def encode_candidates_both_layers(
# model: MMEBModel, loader: DataLoader, training_args: TrainingArguments,
# model_args: ModelArguments, full_dataset: Dataset, mid_layer: int,
# ):
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# model.eval()
# all_mid, all_last, all_ids = [], [], []
# with torch.no_grad():
# for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
# inputs = batch_to_device(inputs, training_args.device)
# # 强制关闭 Cand 侧的 AOP
# aop_cfg = getattr(model.encoder, "aop_prune_config", None)
# _orig = None
# if isinstance(aop_cfg, dict):
# _orig = aop_cfg.get("enabled", False)
# aop_cfg["enabled"] = False
# setattr(model.encoder, "aop_prune_config", aop_cfg)
# with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
# out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None)
# if isinstance(aop_cfg, dict) and _orig is not None:
# aop_cfg["enabled"] = _orig # Restore
# mid_hs = out.hidden_states[mid_layer]
# last_hs = out.hidden_states[-1]
# am = inputs.get("attention_mask", None)
# if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device)
# reps_mid = model._pooling(mid_hs, am)
# reps_last = model._pooling(last_hs, am)
# all_mid.append(reps_mid.detach().float().cpu())
# all_last.append(reps_last.detach().float().cpu())
# all_ids.extend([info["cand_name"] for info in dataset_info])
# if not all_mid: return np.array([]), np.array([]), []
# return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids
# # ===========================
# # Main
# # ===========================
# def main():
# if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
# dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ee_cfg = get_env_ee_config()
# hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
# if not getattr(model_args, "model_backbone", None):
# model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
# setattr(model_args, 'model_backbone', model_backbone)
# if local_rank == 0:
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# if torch.distributed.is_initialized(): torch.distributed.barrier()
# if local_rank != 0:
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# model.eval()
# model = model.to(training_args.device, dtype=torch.bfloat16)
# # 调试:检查模型配置
# print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}")
# # AOP 配置注入
# aop_cfg = get_env_aop_config()
# if aop_cfg["enabled"]:
# setattr(model.encoder, "aop_prune_config", aop_cfg)
# model.set_inference_layers(qry_layers=None, tgt_layers=None)
# print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}")
# # 加载分类器
# classifier = None
# if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
# classifier_path = ee_cfg['classifier_path']
# print_master(f"[EE] Loading Classifier from {classifier_path}...")
# # 【关键】使用27维特征,与训练代码完全一致
# classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64)
# state_dict = None
# # 尝试多种加载方式
# if os.path.isdir(classifier_path):
# # 方式1: safetensors 格式(Trainer 默认)
# safetensors_file = os.path.join(classifier_path, "model.safetensors")
# if os.path.exists(safetensors_file):
# from safetensors.torch import load_file
# state_dict = load_file(safetensors_file)
# print_master(f"[EE] Loaded from model.safetensors")
# else:
# # 方式2: pytorch_model.bin
# pt_file = os.path.join(classifier_path, "pytorch_model.bin")
# if os.path.exists(pt_file):
# state_dict = torch.load(pt_file, map_location=training_args.device)
# print_master(f"[EE] Loaded from pytorch_model.bin")
# else:
# # 方式3: 独立的 .pt 文件
# layer_idx = ee_cfg.get('layer', 12)
# pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
# if os.path.exists(pt_file):
# state_dict = torch.load(pt_file, map_location=training_args.device)
# print_master(f"[EE] Loaded from {os.path.basename(pt_file)}")
# else:
# raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}")
# elif os.path.isfile(classifier_path):
# # 方式4: 直接指定文件
# if classifier_path.endswith('.safetensors'):
# from safetensors.torch import load_file
# state_dict = load_file(classifier_path)
# print_master(f"[EE] Loaded from .safetensors file")
# else:
# state_dict = torch.load(classifier_path, map_location=training_args.device)
# print_master(f"[EE] Loaded from .pt file")
# else:
# raise FileNotFoundError(f"Classifier path not found: {classifier_path}")
# classifier.load_state_dict(state_dict)
# classifier.to(training_args.device)
# classifier.eval()
# print_master(f"[EE] Classifier loaded successfully. Threshold={ee_cfg['threshold']}")
# with open(data_args.dataset_config, 'r') as yaml_file:
# dataset_configs = yaml.safe_load(yaml_file)
# for dataset_name, task_config in dataset_configs.items():
# if dist.is_initialized(): dist.barrier()
# print_master(f"\n--- Evaluating {dataset_name} ---")
# if data_args.data_basedir:
# for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
# if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
# full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
# full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
# mid_layer = int(ee_cfg["layer"])
# cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}")
# cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast")
# # 【关键修复】可通过 force_regenerate 强制重新生成 candidates,以确保 normalize 设置一致(当前默认关闭,见下一行)
# force_regenerate = False # 已经重新生成过,设为False避免每次都重新计算
# # 预计算 Candidates
# if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
# print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...")
# eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
# eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer)
# if local_rank == 0:
# with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
# with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
# if dist.is_initialized(): dist.barrier()
# if local_rank == 0:
# with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f)
# with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f)
# # 【关键修复】从task_config读取eval_type,决定global/local ranking
# rank_global = task_config.get("eval_type", "global") == "global"
# print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking")
# run_early_exit_queries(
# model, classifier, processor, model_args, data_args, training_args,
# full_eval_qry_dataset, cand_mid_dict, cand_last_dict, ee_cfg, dataset_name, data_args.encode_output_path,
# global_ranking=rank_global
# )
# if dist.is_initialized(): dist.barrier()
# if __name__ == '__main__':
# main()
# ###################################################################################
# # 添加早停率上下界 (variant of the version above that adds lower/upper bounds on the early-exit rate)
# 调用格式 (set the exit-rate bounds via environment variables before launching):
# export EE_EXIT_LOWER=0.05
# export EE_EXIT_UPPER=0.5
# import datetime
# import logging
# import json
# import random
# import time
# import numpy as np
# import os
# import pickle
# import sys
# import torch
# import torch.distributed as dist
# import torch.nn.functional as F
# import yaml
# import transformers
# import math
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
# from datasets import Dataset, concatenate_datasets
# from datasets.distributed import split_dataset_by_node
# from src.arguments import ModelArguments, DataArguments, TrainingArguments
# from src.data.collator.eval_collator import MultimodalEvalDataCollator
# from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
# from src.eval_utils.metrics import RankingMetrics
# from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel
# from src.model.processor import get_backbone_name, load_processor, COLPALI
# from src.utils import batch_to_device, print_rank, print_master
# # 引入分类器
# from src.classifier_utils_V2 import EarlyExitClassifier
# logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
# logger = logging.getLogger(__name__)
# # ===========================
# # Helper Functions (AOP Config)
# # ===========================
# def _parse_bool(v: str, default=False):
# if v is None: return default
# v = v.strip().lower()
# return v in {"1","true","yes","y","t","on"}
# def _parse_float(v: str, default=None):
# try: return float(v) if v is not None else default
# except: return default
# def _parse_int(v: str, default=None):
# try: return int(v) if v is not None else default
# except: return default
# def get_env_aop_config():
# enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
# apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()
# layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
# mode = os.environ.get("AOP_MODE", "delta").strip().lower()
# # Parameters
# delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
# khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
# keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
# min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
# use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
# # Specific
# prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
# prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
# keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
# keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
# selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
# attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()
# if layer_idx is None and enabled:
# enabled = False
# return {
# "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode,
# "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep,
# "use_bias": use_bias, "prune_vision": prune_vision, "prune_text": prune_text,
# "keep_ratio_vision": keep_ratio_v, "keep_ratio_text": keep_ratio_t,
# "selection": selection, "attn_agg": attn_agg,
# "margin_mid": None # Simplified
# }
# def get_env_ee_config():
# ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"}
# layer = int(os.environ.get("EE_LAYER", "12"))
# method = os.environ.get("EE_METHOD", "classifier").strip().lower()
# threshold = float(os.environ.get("EE_THRESHOLD", "0.8"))
# classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "")
# # ===== 新增:早停率上下界(0~1) =====
# exit_lower = float(os.environ.get("EE_EXIT_LOWER", "0.0")) # 下界,默认 0
# exit_upper = float(os.environ.get("EE_EXIT_UPPER", "1.0")) # 上界,默认 1
# # 合法化:保证 0 <= lower <= upper <= 1
# exit_lower = max(0.0, min(exit_lower, 1.0))
# exit_upper = max(exit_lower, min(exit_upper, 1.0))
# return dict(
# enabled=ee_enabled,
# layer=layer,
# method=method,
# threshold=threshold,
# classifier_path=classifier_path,
# exit_lower=exit_lower,
# exit_upper=exit_upper,
# )
# def pad_dataset_to_divisible(dataset, world_size):
# num_samples = len(dataset)
# if num_samples % world_size == 0: return dataset, num_samples
# num_to_add = world_size - (num_samples % world_size)
# padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
# return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add
# # ===========================
# # Core Inference Function
# # ===========================
# def run_early_exit_queries(
# model: MMEBModel,
# classifier: EarlyExitClassifier,
# processor,
# model_args: ModelArguments,
# data_args: DataArguments,
# training_args: TrainingArguments,
# qry_dataset: Dataset,
# cand_mid_dict: dict,
# cand_last_dict: dict,
# ee_cfg: dict,
# dataset_name: str,
# out_dir: str,
# global_ranking: bool = True,
# ):
# device = training_args.device
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# is_main = (not dist.is_initialized()) or (local_rank == 0)
# # ==========================
# # 全局早停率上下界(按 bench)
# # ==========================
# total_samples = len(qry_dataset)
# exit_lower = float(ee_cfg.get("exit_lower", 0.0))
# exit_upper = float(ee_cfg.get("exit_upper", 1.0))
# # 是否真正启用约束:如果是默认 [0,1] 就视为关闭,保持兼容
# use_exit_bounds = not (abs(exit_lower - 0.0) < 1e-6 and abs(exit_upper - 1.0) < 1e-6)
# # 映射成「最少/最多早停的样本数」
# min_exit = int(math.ceil(exit_lower * total_samples))
# max_exit = int(math.floor(exit_upper * total_samples))
# # 安全裁剪
# min_exit = max(0, min(min_exit, total_samples))
# max_exit = max(min_exit, min(max_exit, total_samples))
# if is_main:
# print_master(
# f"[EE] Global exit-rate bounds for dataset '{dataset_name}': "
# f"lower={exit_lower:.3f} ({min_exit}/{total_samples}), "
# f"upper={exit_upper:.3f} ({max_exit}/{total_samples}), "
# f"enabled={use_exit_bounds}"
# )
# # ==========================
# # Profiling 配置
# # ==========================
# profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {
# "1", "true", "yes", "on", "y", "t"
# }
# topk_emb = int(os.environ.get("EE_TOPK_EMB", "5"))
# # 运行时间统计(按样本数加权)
# timing_stats = {
# "mid_time_sum": 0.0, # encoder 到 mid 的总时间 * 样本数
# "mid_num": 0, # 样本数
# "tail_time_sum": 0.0, # mid->last 续跑部分的总时间 * 样本数
# "tail_num": 0, # 续跑样本数
# }
# # embedding + 相似度记录(仅 rank0 保存)
# analysis_records = [] if (profile_enabled and is_main) else None
# # 1. 准备 Candidates
# cand_ids = list(cand_mid_dict.keys())
# cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
# cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
# # 新增:作为检索打分使用的 FP32 Numpy 向量
# cand_mid_np = cand_mid # [Nc, D], float32
# cand_last_np = cand_last # [Nc, D], float32
# cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
# cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
# collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
# loader = DataLoader(
# qry_dataset,
# batch_size=training_args.per_device_eval_batch_size,
# collate_fn=collator,
# num_workers=training_args.dataloader_num_workers
# )
# pred_dicts = []
# stats = {"exit": 0, "total": 0}
# threshold = float(ee_cfg["threshold"])
# method = ee_cfg["method"]
# target_layer_idx = int(ee_cfg["layer"])
# # 结果顺序
# results_dict = {}
# global_sample_idx = 0
# # Local vs Global ranking
# use_local = (not global_ranking)
# if use_local:
# print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)")
# cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
# else:
# print_master(f"[INFO] Using GLOBAL ranking (full library)")
# # --- AOP 配置初始化 ---
# aop_cfg = getattr(model.encoder, "aop_prune_config", None)
# _orig_enabled = None
# side_enable = True
# if isinstance(aop_cfg, dict) and aop_cfg:
# _orig_enabled = aop_cfg.get("enabled", False)
# apply_to = aop_cfg.get("apply_to", "qry")
# side_enable = (apply_to == "both") or (apply_to == "qry")
# model.eval()
# if classifier:
# classifier.eval()
# classifier.to(device)
# start_time = time.time()
# for inputs, infos in tqdm(
# loader,
# desc=f"[EE+AOP] {dataset_name} (tau={threshold})",
# disable=local_rank > 0,
# ):
# inputs = batch_to_device(inputs, device)
# B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
# batch_start_idx = global_sample_idx
# global_sample_idx += B
# stats["total"] += B
# # ---------------------------------------------------
# # 1. 前半程: Run to Mid Layer (含 AOP 动态控制)
# # ---------------------------------------------------
# orig_cfg = None
# if isinstance(aop_cfg, dict) and aop_cfg:
# orig_cfg = dict(aop_cfg)
# aop_layer = aop_cfg.get("layer_idx", None)
# aop_on_mid = bool(
# _orig_enabled
# and side_enable
# and (aop_layer is not None)
# and (aop_layer < target_layer_idx)
# )
# aop_cfg_mid = dict(aop_cfg)
# aop_cfg_mid["enabled"] = aop_on_mid
# setattr(model.encoder, "aop_prune_config", aop_cfg_mid)
# # 计时:encoder 到 mid 层
# if profile_enabled:
# torch.cuda.synchronize()
# t0_mid = time.perf_counter()
# with torch.no_grad(), torch.autocast(
# device_type="cuda", dtype=torch.bfloat16, enabled=True
# ):
# out_mid = model.encoder(
# **inputs,
# return_dict=True,
# output_hidden_states=False,
# stop_at_layer=target_layer_idx,
# compute_lm_head=False,
# )
# if profile_enabled:
# torch.cuda.synchronize()
# t1_mid = time.perf_counter()
# timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B
# timing_stats["mid_num"] += B
# # 恢复 AOP 配置
# if isinstance(orig_cfg, dict):
# setattr(model.encoder, "aop_prune_config", orig_cfg)
# # Hidden State & Mask
# hs_mid = getattr(out_mid, "last_hidden_state", None)
# if hs_mid is None:
# hs_mid = out_mid.hidden_states[-1]
# am_mid = getattr(out_mid, "attention_mask", None)
# if am_mid is None:
# am_mid = inputs.get("attention_mask", None)
# reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16)
# # 如果要记录最后一层 embedding,需要在 profiling 模式下额外跑一次全 forward
# reps_last_full = None
# if profile_enabled:
# with torch.no_grad(), torch.autocast(
# device_type="cuda", dtype=torch.bfloat16, enabled=True
# ):
# out_full = model.encoder(
# **inputs,
# return_dict=True,
# output_hidden_states=False,
# stop_at_layer=None,
# compute_lm_head=False,
# )
# hs_last_full = getattr(out_full, "last_hidden_state", None)
# if hs_last_full is None:
# hs_last_full = out_full.hidden_states[-1]
# am_last_full = getattr(out_full, "attention_mask", None)
# if am_last_full is None:
# am_last_full = inputs.get("attention_mask", None)
# reps_last_full = (
# model._pooling(hs_last_full, am_last_full)
# .detach()
# .to(dtype=torch.bfloat16)
# )
# # ---------------------------------------------------
# # 2. 特征工程 + gating
# # ---------------------------------------------------
# exit_mask = np.zeros(B, dtype=bool)
# p_need_last_batch = None # profiling / debug
# orig_exit_mask = None # 原始按 threshold 判的结果(不带全局约束)
# if method == "classifier" and classifier is not None:
# with torch.no_grad():
# # 【关键修复】特征提取:qry_mid × cand_mid(表征空间对齐)
# cos_mid = reps_mid_t @ cand_mid_t.T # [B, N]
# backbone_ptr = (
# model.module if hasattr(model, "module") else model
# )
# temp = getattr(backbone_ptr, "temperature", 0.02)
# scores_mid = cos_mid / temp
# probs_mid = torch.softmax(scores_mid, dim=1) # [B, N]
# # ==== 以下是你原来的 27 维特征工程代码,原样保留 ====
# diag_cos = cos_mid.max(dim=1)[0]
# sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
# s2_cos = (
# sorted_cos[:, 1]
# if sorted_cos.size(1) > 1
# else sorted_cos[:, 0]
# )
# margin_mid = diag_cos - s2_cos
# margin_mean = margin_mid.mean()
# margin_std = margin_mid.std(unbiased=False) + 1e-6
# z_margin_mid = (margin_mid - margin_mean) / margin_std
# margin_median = margin_mid.median()
# mad = (margin_mid - margin_median).abs().median() + 1e-6
# mad_margin_mid = (margin_mid - margin_median) / mad
# p1_mid = probs_mid.max(dim=1)[0]
# H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
# gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
# TOPK = min(16, probs_mid.size(1))
# topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
# topk_mean = topk_vals.mean(dim=1)
# topk_std = topk_vals.std(dim=1, unbiased=False)
# topk_cv = topk_std / (topk_mean + 1e-6)
# centered = topk_vals - topk_mean.unsqueeze(1)
# var = (centered ** 2).mean(dim=1) + 1e-6
# m4 = (centered ** 4).mean(dim=1)
# topk_kurt = m4 / (var ** 2)
# topk_med = topk_vals.median(dim=1).values
# row_mean_cos = cos_mid.mean(dim=1)
# row_med_cos = cos_mid.median(dim=1).values
# s1_over_mean = diag_cos - row_mean_cos
# s1_over_med = diag_cos - row_med_cos
# sorted_probs, _ = torch.sort(
# probs_mid, dim=1, descending=True
# )
# p1 = sorted_probs[:, 0]
# p2 = (
# sorted_probs[:, 1]
# if sorted_probs.size(1) > 1
# else sorted_probs[:, 0]
# )
# shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(
# dim=1
# )
# shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
# R = min(10, sorted_probs.size(1))
# x = torch.arange(
# R, device=device, dtype=sorted_probs.dtype
# )
# x_centered = x - x.mean()
# denom = (x_centered ** 2).sum()
# y = torch.log(sorted_probs[:, :R] + 1e-6)
# slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
# row_mean_p = probs_mid.mean(dim=1)
# row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
# z1 = (p1_mid - row_mean_p) / row_std_p
# center_p = probs_mid - row_mean_p.unsqueeze(1)
# m3 = (center_p ** 3).mean(dim=1)
# skew = m3 / (row_std_p ** 3 + 1e-6)
# s1_over_sk = p1_mid - skew
# TAIL_K = min(10, sorted_probs.size(1))
# tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
# HEAD_K = min(5, sorted_probs.size(1))
# head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
# mask_ratio = torch.zeros_like(diag_cos)
# mask_len = torch.zeros_like(diag_cos)
# mask_runs = torch.zeros_like(diag_cos)
# scalar_inputs = torch.stack(
# [
# diag_cos,
# s2_cos,
# margin_mid,
# z_margin_mid,
# mad_margin_mid,
# p1_mid,
# H_mid,
# gini_mid,
# topk_mean,
# topk_std,
# topk_cv,
# topk_kurt,
# topk_med,
# s1_over_mean,
# s1_over_med,
# p1,
# p2,
# shape_H,
# shape_gini,
# slope,
# z1,
# s1_over_sk,
# tail_mean,
# head5_mean,
# mask_ratio,
# mask_len,
# mask_runs,
# ],
# dim=1,
# )
# modality_idx = torch.zeros(
# B, dtype=torch.long, device=device
# )
# if "pixel_values" in inputs and inputs["pixel_values"] is not None:
# pv = inputs["pixel_values"]
# if isinstance(pv, list):
# for i, item in enumerate(pv):
# if item is not None:
# modality_idx[i] = 1
# elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
# modality_idx.fill_(1)
# logits = classifier(scalar_inputs, modality_idx)
# p_need_last = torch.sigmoid(logits) # [B,1]
# p_need_last_batch = p_need_last.squeeze(1) # [B]
# # 原始按 threshold 的判定(仅作为参考)
# should_exit = p_need_last_batch < threshold
# orig_exit_mask = should_exit.cpu().numpy()
# # ====== 应用全局早停率上下界(如果启用) ======
# if method == "classifier" and classifier is not None and orig_exit_mask is not None:
# if not use_exit_bounds:
# # 不启用上下界:行为与旧版本完全一致
# exit_mask = orig_exit_mask
# else:
# # 当前之前已经早停的数量
# E_prev = stats["exit"]
# # 当前 batch 前已经处理的样本数(全局 index)
# i_prev = batch_start_idx
# B_cur = B
# N_total = total_samples
# # 最多还能早停多少个(不超过上界)
# max_exit_remaining = max_exit - E_prev
# if max_exit_remaining <= 0:
# allowed_max = 0
# else:
# allowed_max = min(B_cur, max_exit_remaining)
# # 本 batch 之后还剩多少个样本没处理
# remaining_after = N_total - (i_prev + B_cur)
# # 为了将来仍有机会达到下界,本 batch 至少需要早停多少个
# required_min = min(
# B_cur,
# max(0, min_exit - E_prev - remaining_after)
# )
# # 分类器在当前 batch 原始建议早停个数(按 threshold)
# n0 = int(orig_exit_mask.sum())
# if allowed_max <= 0:
# n_exit = 0
# else:
# # 在 [required_min, allowed_max] 区间内,尽量贴近原始 n0
# n_exit = max(required_min, min(n0, allowed_max))
# # 最终 exit_mask:在当前 batch 中选 p_need_last 最小的 n_exit 个样本早停
# exit_mask = np.zeros(B_cur, dtype=bool)
# if n_exit > 0:
# p_np = p_need_last_batch.detach().cpu().numpy() # [B]
# order = np.argsort(p_np) # 升序(越小越容易早停)
# chosen = order[:n_exit]
# exit_mask[chosen] = True
# # Debug:只在前几个 batch 打印
# if stats["total"] <= B * 3 and is_main:
# raw_rate = orig_exit_mask.mean()
# bounded_rate = exit_mask.mean()
# print_master(
# f"[EE Debug] Batch {stats['total']//B}: "
# f"p_need_last mean={p_need_last_batch.mean().item():.4f}, "
# f"std={p_need_last_batch.std().item():.4f}, "
# f"raw_exit_rate={raw_rate:.2%}, "
# f"bounded_exit_rate={bounded_rate:.2%}, "
# f"required_min={required_min}, allowed_max={allowed_max}, "
# f"n0={n0}, n_exit={n_exit}"
# )
# # 非 classifier 或 classifier 不存在:保持 exit_mask=全 False(即无早停)
# stats["exit"] += int(exit_mask.sum())
# # ---------------------------------------------------
# # 3. 分支执行
# # ---------------------------------------------------
# exit_indices = np.where(exit_mask)[0]
# cont_indices = np.where(~exit_mask)[0]
# # A. 早停样本:用 mid→last 检索
# if len(exit_indices) > 0:
# reps_exit = reps_mid_t[exit_indices]
# if use_local:
# for i, idx in enumerate(exit_indices):
# cand_local = infos[idx].get("cand_names", [])
# if not cand_local:
# cids = []
# else:
# rows = [
# cand_id2row.get(str(cid), -1)
# for cid in cand_local
# ]
# rows = [r for r in rows if r >= 0]
# if len(rows) == 0:
# cids = []
# else:
# # CPU FP32:单个 query mid × 局部 cand_mid
# cmat_np = cand_mid_np[rows] # [Nc_local, D]
# qry_np = reps_exit[i].detach().float().cpu().numpy() # [D]
# scores_vec = np.dot(qry_np, cmat_np.T) # [Nc_local]
# top_k = min(200, len(rows))
# order_local = np.argsort(-scores_vec)[:top_k]
# cids = [str(cand_local[o]) for o in order_local]
# label = (
# infos[idx].get("label_name")
# or infos[idx].get("label")
# or infos[idx].get("rel_docids")
# )
# if not isinstance(label, list):
# label = [label]
# rel_scores = infos[idx].get("rel_scores", None)
# global_idx = batch_start_idx + idx
# results_dict[global_idx] = {
# "prediction": cids,
# "label": label,
# "rel_scores": rel_scores,
# }
# else:
# # 使用 CPU FP32 Numpy:reps_exit_np × cand_mid_np.T
# reps_exit_np = reps_exit.detach().float().cpu().numpy() # [B_exit, D]
# scores_full = np.dot(reps_exit_np, cand_mid_np.T) # [B_exit, Nc]
# top_k = min(200, len(cand_ids))
# # 全排序后截到 top_k
# topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k] # [B_exit, top_k]
# for i, idx in enumerate(exit_indices):
# cids = [cand_ids[k] for k in topk_inds[i]]
# label = (
# infos[idx].get("label_name")
# or infos[idx].get("label")
# or infos[idx].get("rel_docids")
# )
# if not isinstance(label, list):
# label = [label]
# rel_scores = infos[idx].get("rel_scores", None)
# global_idx = batch_start_idx + idx
# results_dict[global_idx] = {
# "prediction": cids,
# "label": label,
# "rel_scores": rel_scores,
# }
# # B. 续跑样本:从 mid 继续到 last
# if len(cont_indices) > 0:
# interm = getattr(out_mid, "intermediate_state", None)
# hs = interm["hidden_states"].detach()
# am = interm["attention_mask"].detach()
# pos = interm["position_ids"].detach()
# vm = interm.get("vision_mask", None)
# tm = interm.get("text_mask", None)
# next_layer = int(interm["next_layer_idx"])
# resume_state_subset = {
# "hidden_states": hs[cont_indices],
# "attention_mask": am[cont_indices],
# "position_ids": pos[:, cont_indices, :],
# "vision_mask": vm[cont_indices] if vm is not None else None,
# "text_mask": tm[cont_indices] if tm is not None else None,
# "next_layer_idx": next_layer,
# }
# if isinstance(aop_cfg, dict) and aop_cfg:
# aop_resume = dict(aop_cfg)
# aop_resume["enabled"] = bool(_orig_enabled and side_enable)
# setattr(model.encoder, "aop_prune_config", aop_resume)
# # 计时:mid -> last
# if profile_enabled:
# torch.cuda.synchronize()
# t0_tail = time.perf_counter()
# with torch.no_grad(), torch.autocast(
# device_type="cuda", dtype=torch.bfloat16, enabled=True
# ):
# out_last = model.encoder(
# return_dict=True,
# output_hidden_states=False,
# stop_at_layer=None,
# resume_state=resume_state_subset,
# compute_lm_head=False,
# )
# if profile_enabled:
# torch.cuda.synchronize()
# t1_tail = time.perf_counter()
# timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(
# cont_indices
# )
# timing_stats["tail_num"] += len(cont_indices)
# hs_last = out_last.last_hidden_state
# if hs_last is None:
# hs_last = out_last.hidden_states[-1]
# am_last = getattr(out_last, "attention_mask", None)
# if am_last is None:
# am_last = resume_state_subset["attention_mask"]
# reps_last_t = (
# model._pooling(hs_last, am_last)
# .detach()
# .to(dtype=torch.bfloat16)
# )
# if use_local:
# for i, idx_global in enumerate(cont_indices):
# cand_local = infos[idx_global].get("cand_names", [])
# if not cand_local:
# cids = []
# else:
# rows = [
# cand_id2row.get(str(cid), -1)
# for cid in cand_local
# ]
# rows = [r for r in rows if r >= 0]
# if len(rows) == 0:
# cids = []
# else:
# # CPU FP32:单个 query last × 局部 cand_last
# cmat_last_np = cand_last_np[rows] # [Nc_local, D]
# qry_last_np = reps_last_t[i].detach().float().cpu().numpy()
# scores_vec = np.dot(qry_last_np, cmat_last_np.T) # [Nc_local]
# top_k = min(200, len(rows))
# order_local = np.argsort(-scores_vec)[:top_k]
# cids = [str(cand_local[o]) for o in order_local]
# label = (
# infos[idx_global].get("label_name")
# or infos[idx_global].get("label")
# or infos[idx_global].get("rel_docids")
# )
# if not isinstance(label, list):
# label = [label]
# rel_scores = infos[idx_global].get("rel_scores", None)
# global_idx = batch_start_idx + idx_global
# results_dict[global_idx] = {
# "prediction": cids,
# "label": label,
# "rel_scores": rel_scores,
# }
# else:
# # CPU FP32:所有续跑 query last × 全量 cand_last
# reps_last_np = reps_last_t.detach().float().cpu().numpy() # [B_cont, D]
# scores_last = np.dot(reps_last_np, cand_last_np.T) # [B_cont, Nc]
# top_k = min(200, len(cand_ids))
# topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k] # [B_cont, top_k]
# for i, idx_global in enumerate(cont_indices):
# cids = [cand_ids[k] for k in topk_inds[i]]
# label = (
# infos[idx_global].get("label_name")
# or infos[idx_global].get("label")
# or infos[idx_global].get("rel_docids")
# )
# if not isinstance(label, list):
# label = [label]
# rel_scores = infos[idx_global].get("rel_scores", None)
# global_idx = batch_start_idx + idx_global
# results_dict[global_idx] = {
# "prediction": cids,
# "label": label,
# "rel_scores": rel_scores,
# }
# # ---------------------------------------------------
# # 4. Profiling: 保存 per-query 的 topK embedding & 相似度
# # ---------------------------------------------------
# if profile_enabled and is_main:
# K = min(topk_emb, cand_mid_t.size(0))
# # 转到 float32 + CPU 便于写盘
# q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D]
# q_last_cpu = (
# reps_last_full.detach().float().cpu()
# if reps_last_full is not None
# else None
# ) # [B, D]
# cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D]
# cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D]
# # 【关键修复】mid2mid 相似度(表征空间对齐)
# scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc]
# topk_mid_vals, topk_mid_inds = torch.topk(
# scores_mid_full, k=K, dim=1
# )
# # last2last 相似度(如果有)
# if q_last_cpu is not None:
# scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc]
# topk_last_vals, topk_last_inds = torch.topk(
# scores_last_full, k=K, dim=1
# )
# else:
# topk_last_vals = None
# topk_last_inds = None
# for i in range(B):
# qid = batch_start_idx + i
# rec = {
# "qid": int(qid),
# "early_exit": bool(exit_mask[i]),
# }
# if p_need_last_batch is not None:
# rec["p_need_last"] = float(p_need_last_batch[i].item())
# # 【关键修复】mid2mid TopK(表征空间对齐)
# mid_inds = topk_mid_inds[i].tolist()
# mid_scores = topk_mid_vals[i].tolist()
# rec["mid_topk_scores"] = mid_scores
# rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds]
# rec["mid_q_emb"] = q_mid_cpu[i].tolist()
# rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist()
# # last2last TopK(如果有)
# if topk_last_inds is not None:
# last_inds = topk_last_inds[i].tolist()
# last_scores = topk_last_vals[i].tolist()
# rec["last_topk_scores"] = last_scores
# rec["last_topk_cand_ids"] = [
# cand_ids[j] for j in last_inds
# ]
# rec["last_q_emb"] = (
# q_last_cpu[i].tolist() if q_last_cpu is not None else None
# )
# rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist()
# else:
# rec["last_topk_scores"] = None
# rec["last_topk_cand_ids"] = None
# rec["last_q_emb"] = None
# rec["last_cand_embs"] = None
# analysis_records.append(rec)
# # =====================================================
# # 5. 收集 & 保存结果
# # =====================================================
# for idx in sorted(results_dict.keys()):
# pred_dicts.append(results_dict[idx])
# print_master(
# f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} "
# f"({stats['exit']/stats['total']:.2%})"
# )
# metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
# score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
# if is_main:
# os.makedirs(out_dir, exist_ok=True)
# with open(
# os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"),
# "w",
# ) as f:
# json.dump(score, f, indent=4)
# # 保存 profiling 信息
# if profile_enabled:
# prof_dir = os.path.join(out_dir, "profiling")
# os.makedirs(prof_dir, exist_ok=True)
# # 时间统计:转换成均值(秒/样本)
# mid_avg = (
# timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"])
# )
# tail_avg = (
# timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"])
# )
# timing_out = {
# "mid_time_sum": timing_stats["mid_time_sum"],
# "mid_num": timing_stats["mid_num"],
# "tail_time_sum": timing_stats["tail_time_sum"],
# "tail_num": timing_stats["tail_num"],
# "avg_mid_time_per_query_sec": mid_avg,
# "avg_tail_time_per_cont_query_sec": tail_avg,
# "num_exit": int(stats["exit"]),
# "num_total": int(stats["total"]),
# }
# with open(
# os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w"
# ) as f:
# json.dump(timing_out, f, indent=2)
# # embedding + 相似度记录(JSONL)
# embed_path = os.path.join(
# prof_dir, f"{dataset_name}_embeds.jsonl"
# )
# with open(embed_path, "w") as f:
# for rec in analysis_records:
# f.write(json.dumps(rec) + "\n")
# print_master(
# f"[PROFILE] Saved timing to {prof_dir}, "
# f"embeddings to {embed_path}"
# )
# elapsed = time.time() - start_time
# return score, elapsed
# # ===========================
# # Helper Functions (Pre-Computation)
# # ===========================
# def encode_candidates_both_layers(
# model: MMEBModel, loader: DataLoader, training_args: TrainingArguments,
# model_args: ModelArguments, full_dataset: Dataset, mid_layer: int,
# ):
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# model.eval()
# all_mid, all_last, all_ids = [], [], []
# with torch.no_grad():
# for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
# inputs = batch_to_device(inputs, training_args.device)
# # 强制关闭 Cand 侧的 AOP
# aop_cfg = getattr(model.encoder, "aop_prune_config", None)
# _orig = None
# if isinstance(aop_cfg, dict):
# _orig = aop_cfg.get("enabled", False)
# aop_cfg["enabled"] = False
# setattr(model.encoder, "aop_prune_config", aop_cfg)
# with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
# out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None)
# if isinstance(aop_cfg, dict) and _orig is not None:
# aop_cfg["enabled"] = _orig # Restore
# mid_hs = out.hidden_states[mid_layer]
# last_hs = out.hidden_states[-1]
# am = inputs.get("attention_mask", None)
# if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device)
# reps_mid = model._pooling(mid_hs, am)
# reps_last = model._pooling(last_hs, am)
# all_mid.append(reps_mid.detach().float().cpu())
# all_last.append(reps_last.detach().float().cpu())
# all_ids.extend([info["cand_name"] for info in dataset_info])
# if not all_mid: return np.array([]), np.array([]), []
# return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids
# # ===========================
# # Main
# # ===========================
# def main():
# if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
# dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ee_cfg = get_env_ee_config()
# hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
# if not getattr(model_args, "model_backbone", None):
# model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
# setattr(model_args, 'model_backbone', model_backbone)
# if local_rank == 0:
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# if torch.distributed.is_initialized(): torch.distributed.barrier()
# if local_rank != 0:
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# model.eval()
# model = model.to(training_args.device, dtype=torch.bfloat16)
# # 调试:检查模型配置
# print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}")
# # AOP 配置注入
# aop_cfg = get_env_aop_config()
# if aop_cfg["enabled"]:
# setattr(model.encoder, "aop_prune_config", aop_cfg)
# model.set_inference_layers(qry_layers=None, tgt_layers=None)
# print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}")
# # 加载分类器
# classifier = None
# if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
# classifier_path = ee_cfg['classifier_path']
# print_master(f"[EE] Loading Classifier from {classifier_path}...")
# # 【关键】使用27维特征,与训练代码完全一致
# classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64)
# state_dict = None
# # 尝试多种加载方式
# if os.path.isdir(classifier_path):
# # 方式1: safetensors 格式(Trainer 默认)
# safetensors_file = os.path.join(classifier_path, "model.safetensors")
# if os.path.exists(safetensors_file):
# from safetensors.torch import load_file
# state_dict = load_file(safetensors_file)
# print_master(f"[EE] Loaded from model.safetensors")
# else:
# # 方式2: pytorch_model.bin
# pt_file = os.path.join(classifier_path, "pytorch_model.bin")
# if os.path.exists(pt_file):
# state_dict = torch.load(pt_file, map_location=training_args.device)
# print_master(f"[EE] Loaded from pytorch_model.bin")
# else:
# # 方式3: 独立的 .pt 文件
# layer_idx = ee_cfg.get('layer', 12)
# pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
# if os.path.exists(pt_file):
# state_dict = torch.load(pt_file, map_location=training_args.device)
# print_master(f"[EE] Loaded from {os.path.basename(pt_file)}")
# else:
# raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}")
# elif os.path.isfile(classifier_path):
# # 方式4: 直接指定文件
# if classifier_path.endswith('.safetensors'):
# from safetensors.torch import load_file
# state_dict = load_file(classifier_path)
# print_master(f"[EE] Loaded from .safetensors file")
# else:
# state_dict = torch.load(classifier_path, map_location=training_args.device)
# print_master(f"[EE] Loaded from .pt file")
# else:
# raise FileNotFoundError(f"Classifier path not found: {classifier_path}")
# classifier.load_state_dict(state_dict)
# classifier.to(training_args.device)
# classifier.eval()
# print_master(f"[EE] Classifier loaded successfully. Threshold={ee_cfg['threshold']}")
# with open(data_args.dataset_config, 'r') as yaml_file:
# dataset_configs = yaml.safe_load(yaml_file)
# for dataset_name, task_config in dataset_configs.items():
# if dist.is_initialized(): dist.barrier()
# print_master(f"\n--- Evaluating {dataset_name} ---")
# if data_args.data_basedir:
# for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
# if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
# full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
# full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
# mid_layer = int(ee_cfg["layer"])
# cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}")
# cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast")
# # 【关键修复】强制重新生成candidates以确保normalize设置一致
# force_regenerate = False # 已经重新生成过,设为False避免每次都重新计算
# # 预计算 Candidates
# if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
# print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...")
# eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
# eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer)
# if local_rank == 0:
# with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
# with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
# if dist.is_initialized(): dist.barrier()
# if local_rank == 0:
# with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f)
# with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f)
# # 【关键修复】从task_config读取eval_type,决定global/local ranking
# rank_global = task_config.get("eval_type", "global") == "global"
# print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking")
# run_early_exit_queries(
# model, classifier, processor, model_args, data_args, training_args,
# full_eval_qry_dataset, cand_mid_dict, cand_last_dict, ee_cfg, dataset_name, data_args.encode_output_path,
# global_ranking=rank_global
# )
# if dist.is_initialized(): dist.barrier()
# if __name__ == '__main__':
# main()
#######################################################################################################
# Version that fixes the CPU/GPU numerical-precision mismatch (retrieval scoring in FP32 on CPU)
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
# Import the early-exit classifier
from src.classifier_utils_V2 import EarlyExitClassifier
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
# ===========================
# Per-dataset thresholds (example values)
# Each key must exactly match a dataset_name in the dataset_config YAML.
# Datasets not listed here fall back to EE_THRESHOLD as the default threshold.
# ===========================
# Per-dataset early-exit thresholds tau: a query exits at the mid layer when
# p_need_last < tau, so smaller values make early exit rarer for that dataset.
PER_DATASET_THRESHOLDS = {
    # Example values (fill in from your own offline threshold sweep):
    "CIRR": 0,
    "EDIS": 5e-16,
    "FashionIQ": 0,
    "MSCOCO_i2t": 1e-16,
    "MSCOCO_t2i": 1e-5,
    "NIGHTS": 1,
    "OVEN": 1e-11,
    "VisDial": 1e-11,
    "VisualNews_i2t": 5e-28,
    "VisualNews_t2i": 1e-16,
    "WebQA": 5e-11,
    "Wiki-SS-NQ": 1e-16,
}
# ===========================
# Helper Functions (AOP Config)
# ===========================
def _parse_bool(v: str, default=False):
if v is None: return default
v = v.strip().lower()
return v in {"1","true","yes","y","t","on"}
def _parse_float(v: str, default=None):
try: return float(v) if v is not None else default
except: return default
def _parse_int(v: str, default=None):
try: return int(v) if v is not None else default
except: return default
def get_env_aop_config():
    """Assemble the AOP pruning configuration from environment variables.

    Pruning is force-disabled when no AOP_LAYER is given, regardless of
    AOP_ENABLED, since a layer index is required to apply pruning at all.
    """
    env = os.environ
    layer = _parse_int(env.get("AOP_LAYER"), None)
    # A missing layer index makes pruning meaningless -> disable.
    on = _parse_bool(env.get("AOP_ENABLED"), False) and layer is not None
    return {
        "enabled": on,
        "apply_to": env.get("AOP_APPLY", "qry").strip().lower(),
        "layer_idx": layer,
        "mode": env.get("AOP_MODE", "delta").strip().lower(),
        # Core pruning parameters
        "delta": _parse_float(env.get("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env.get("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env.get("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env.get("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env.get("AOP_USE_BIAS"), True),
        # Modality-specific switches and ratios
        "prune_vision": _parse_bool(env.get("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env.get("AOP_PRUNE_TEXT"), False),
        "keep_ratio_vision": _parse_float(env.get("AOP_KEEP_RATIO_VISION"), None),
        "keep_ratio_text": _parse_float(env.get("AOP_KEEP_RATIO_TEXT"), None),
        "selection": env.get("AOP_SELECTION", "aop").strip().lower(),
        "attn_agg": env.get("AOP_ATTENTION_AGG", "mean").strip().lower(),
        "margin_mid": None  # Simplified
    }
def get_env_ee_config():
    """Read the early-exit (EE) runtime configuration from environment variables."""
    truthy = {"1", "true", "yes", "on", "y", "t"}
    flag = os.environ.get("EE_ENABLED", "0").strip().lower()
    return dict(
        enabled=flag in truthy,
        layer=int(os.environ.get("EE_LAYER", "12")),
        method=os.environ.get("EE_METHOD", "classifier").strip().lower(),
        threshold=float(os.environ.get("EE_THRESHOLD", "0.8")),
        classifier_path=os.environ.get("EE_CLASSIFIER_PATH", ""),
    )
def pad_dataset_to_divisible(dataset, world_size):
    """Pad *dataset* by repeating rows so its length is a multiple of *world_size*.

    Returns (padded_dataset, padded_length). The dataset is returned
    unchanged when its length is already divisible.
    """
    n = len(dataset)
    remainder = n % world_size
    if remainder == 0:
        return dataset, n
    extra = world_size - remainder
    # Repeat rows cyclically from the start to fill the remainder.
    pad_rows = dataset.select([i % n for i in range(extra)])
    return concatenate_datasets([dataset, pad_rows]), n + extra
# ===========================
# Core Inference Function
# ===========================
def run_early_exit_queries(
    model: MMEBModel,
    classifier: EarlyExitClassifier,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """Encode queries with early-exit gating and score them against candidates.

    Each query batch is first encoded up to the mid layer
    (``ee_cfg['layer']``). When ``ee_cfg['method'] == 'classifier'``, a
    per-sample classifier predicts whether the last-layer embedding is
    needed; samples with ``p_need_last < ee_cfg['threshold']`` exit early
    and are ranked against the mid-layer candidate embeddings, while the
    rest resume the encoder forward pass from the saved intermediate state
    to the last layer and are ranked against last-layer candidate
    embeddings. Ranking is done in FP32 on CPU via numpy for numerical
    consistency. Scores are written to
    ``{out_dir}/{dataset_name}_score_earlyexit.json``; optional profiling
    artifacts (timing + per-query top-K embeddings) go under
    ``{out_dir}/profiling`` when env var EE_PROFILE is truthy.

    NOTE(review): relies on a custom encoder API (``stop_at_layer``,
    ``resume_state``, ``intermediate_state``) provided by the
    cut-layer MMEBModel variant — confirm against that encoder.

    Returns:
        (score, elapsed): the RankingMetrics result dict and wall-clock seconds.
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)
    # ==========================
    # Profiling configuration
    # ==========================
    profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {
        "1", "true", "yes", "on", "y", "t"
    }
    topk_emb = int(os.environ.get("EE_TOPK_EMB", "5"))
    # Runtime statistics (weighted by sample count)
    timing_stats = {
        "mid_time_sum": 0.0,  # total encoder time up to the mid layer * samples
        "mid_num": 0,  # number of samples
        "tail_time_sum": 0.0,  # total mid->last continuation time * samples
        "tail_num": 0,  # number of continued samples
    }
    # Embedding + similarity records (kept on rank 0 only)
    analysis_records = [] if (profile_enabled and is_main) else None
    # 1. Prepare candidates
    cand_ids = list(cand_mid_dict.keys())
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
    # New: FP32 numpy matrices used for the actual retrieval scoring (CPU)
    cand_mid_np = cand_mid  # [Nc, D], float32
    cand_last_np = cand_last  # [Nc, D], float32
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator,
        num_workers=training_args.dataloader_num_workers
    )
    pred_dicts = []
    stats = {"exit": 0, "total": 0}
    threshold = float(ee_cfg["threshold"])
    method = ee_cfg["method"]
    target_layer_idx = int(ee_cfg["layer"])
    # Results keyed by global sample index so output order is deterministic
    results_dict = {}
    global_sample_idx = 0
    # Local vs Global ranking
    use_local = (not global_ranking)
    if use_local:
        print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)")
        cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
    else:
        print_master(f"[INFO] Using GLOBAL ranking (full library)")
    # --- AOP configuration initialization ---
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")
    model.eval()
    if classifier:
        classifier.eval()
        classifier.to(device)
    start_time = time.time()
    for inputs, infos in tqdm(
        loader,
        desc=f"[EE+AOP] {dataset_name} (tau={threshold})",
        disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)
        B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
        batch_start_idx = global_sample_idx
        global_sample_idx += B
        stats["total"] += B
        # ---------------------------------------------------
        # 1. First half: run to the mid layer (with dynamic AOP control)
        # ---------------------------------------------------
        orig_cfg = None
        if isinstance(aop_cfg, dict) and aop_cfg:
            orig_cfg = dict(aop_cfg)
            aop_layer = aop_cfg.get("layer_idx", None)
            # Apply AOP on this segment only if its layer precedes the exit layer
            aop_on_mid = bool(
                _orig_enabled
                and side_enable
                and (aop_layer is not None)
                and (aop_layer < target_layer_idx)
            )
            aop_cfg_mid = dict(aop_cfg)
            aop_cfg_mid["enabled"] = aop_on_mid
            setattr(model.encoder, "aop_prune_config", aop_cfg_mid)
        # Timing: encoder up to the mid layer
        if profile_enabled:
            torch.cuda.synchronize()
            t0_mid = time.perf_counter()
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=True
        ):
            out_mid = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=target_layer_idx,
                compute_lm_head=False,
            )
        if profile_enabled:
            torch.cuda.synchronize()
            t1_mid = time.perf_counter()
            timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B
            timing_stats["mid_num"] += B
        # Restore the original AOP configuration
        if isinstance(orig_cfg, dict):
            setattr(model.encoder, "aop_prune_config", orig_cfg)
        # Hidden state & mask
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            hs_mid = out_mid.hidden_states[-1]
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None:
            am_mid = inputs.get("attention_mask", None)
        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16)
        # To also record last-layer embeddings, profiling mode runs an extra full forward
        reps_last_full = None
        if profile_enabled:
            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16, enabled=True
            ):
                out_full = model.encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    compute_lm_head=False,
                )
            hs_last_full = getattr(out_full, "last_hidden_state", None)
            if hs_last_full is None:
                hs_last_full = out_full.hidden_states[-1]
            am_last_full = getattr(out_full, "attention_mask", None)
            if am_last_full is None:
                am_last_full = inputs.get("attention_mask", None)
            reps_last_full = (
                model._pooling(hs_last_full, am_last_full)
                .detach()
                .to(dtype=torch.bfloat16)
            )
        # ---------------------------------------------------
        # 2. Feature engineering + gating
        # ---------------------------------------------------
        exit_mask = np.zeros(B, dtype=bool)
        p_need_last_batch = None  # kept only for profiling or debug output
        if method == "classifier" and classifier is not None:
            with torch.no_grad():
                # [Key fix] Feature extraction: qry_mid x cand_mid (aligned representation space)
                cos_mid = reps_mid_t @ cand_mid_t.T  # [B, N]
                backbone_ptr = (
                    model.module if hasattr(model, "module") else model
                )
                temp = getattr(backbone_ptr, "temperature", 0.02)
                scores_mid = cos_mid / temp
                probs_mid = torch.softmax(scores_mid, dim=1)  # [B, N]
                # Top-1 / top-2 similarity and margin-based confidence features
                diag_cos = cos_mid.max(dim=1)[0]
                sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
                s2_cos = (
                    sorted_cos[:, 1]
                    if sorted_cos.size(1) > 1
                    else sorted_cos[:, 0]
                )
                margin_mid = diag_cos - s2_cos
                # Batch-normalized (z-score) and robust (MAD) margin variants
                margin_mean = margin_mid.mean()
                margin_std = margin_mid.std(unbiased=False) + 1e-6
                z_margin_mid = (margin_mid - margin_mean) / margin_std
                margin_median = margin_mid.median()
                mad = (margin_mid - margin_median).abs().median() + 1e-6
                mad_margin_mid = (margin_mid - margin_median) / mad
                # Distribution-shape features over the softmax scores
                p1_mid = probs_mid.max(dim=1)[0]
                H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
                gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
                TOPK = min(16, probs_mid.size(1))
                topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
                topk_mean = topk_vals.mean(dim=1)
                topk_std = topk_vals.std(dim=1, unbiased=False)
                topk_cv = topk_std / (topk_mean + 1e-6)
                centered = topk_vals - topk_mean.unsqueeze(1)
                var = (centered ** 2).mean(dim=1) + 1e-6
                m4 = (centered ** 4).mean(dim=1)
                topk_kurt = m4 / (var ** 2)  # kurtosis of the top-K probabilities
                topk_med = topk_vals.median(dim=1).values
                row_mean_cos = cos_mid.mean(dim=1)
                row_med_cos = cos_mid.median(dim=1).values
                s1_over_mean = diag_cos - row_mean_cos
                s1_over_med = diag_cos - row_med_cos
                sorted_probs, _ = torch.sort(
                    probs_mid, dim=1, descending=True
                )
                p1 = sorted_probs[:, 0]
                p2 = (
                    sorted_probs[:, 1]
                    if sorted_probs.size(1) > 1
                    else sorted_probs[:, 0]
                )
                shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(
                    dim=1
                )
                shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
                # Log-linear slope over the top-R sorted probabilities
                R = min(10, sorted_probs.size(1))
                x = torch.arange(
                    R, device=device, dtype=sorted_probs.dtype
                )
                x_centered = x - x.mean()
                denom = (x_centered ** 2).sum()
                y = torch.log(sorted_probs[:, :R] + 1e-6)
                slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
                row_mean_p = probs_mid.mean(dim=1)
                row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
                z1 = (p1_mid - row_mean_p) / row_std_p
                center_p = probs_mid - row_mean_p.unsqueeze(1)
                m3 = (center_p ** 3).mean(dim=1)
                skew = m3 / (row_std_p ** 3 + 1e-6)
                s1_over_sk = p1_mid - skew
                TAIL_K = min(10, sorted_probs.size(1))
                tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
                HEAD_K = min(5, sorted_probs.size(1))
                head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
                # Mask statistics are zero placeholders here (features 25-27)
                mask_ratio = torch.zeros_like(diag_cos)
                mask_len = torch.zeros_like(diag_cos)
                mask_runs = torch.zeros_like(diag_cos)
                # 27-dim feature vector; order must match the classifier's training code
                scalar_inputs = torch.stack(
                    [
                        diag_cos,
                        s2_cos,
                        margin_mid,
                        z_margin_mid,
                        mad_margin_mid,
                        p1_mid,
                        H_mid,
                        gini_mid,
                        topk_mean,
                        topk_std,
                        topk_cv,
                        topk_kurt,
                        topk_med,
                        s1_over_mean,
                        s1_over_med,
                        p1,
                        p2,
                        shape_H,
                        shape_gini,
                        slope,
                        z1,
                        s1_over_sk,
                        tail_mean,
                        head5_mean,
                        mask_ratio,
                        mask_len,
                        mask_runs,
                    ],
                    dim=1,
                )
                # Modality flag: 1 when the sample carries pixel inputs, else 0
                modality_idx = torch.zeros(
                    B, dtype=torch.long, device=device
                )
                if "pixel_values" in inputs and inputs["pixel_values"] is not None:
                    pv = inputs["pixel_values"]
                    if isinstance(pv, list):
                        for i, item in enumerate(pv):
                            if item is not None:
                                modality_idx[i] = 1
                    elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                        modality_idx.fill_(1)
                logits = classifier(scalar_inputs, modality_idx)
                p_need_last = torch.sigmoid(logits)  # [B,1]
                p_need_last_batch = p_need_last.squeeze(1)  # [B]
                # Exit early when the predicted need for the last layer is below tau
                should_exit = p_need_last_batch < threshold
                exit_mask = should_exit.cpu().numpy()
                # Debug dump for the first few batches only
                if stats["total"] <= B * 3 and is_main:
                    print_master(
                        f"[EE Debug] Batch {stats['total']//B}: "
                        f"p_need_last mean={p_need_last_batch.mean().item():.4f}, "
                        f"std={p_need_last_batch.std().item():.4f}, "
                        f"Exit Rate={exit_mask.mean():.2%}, "
                        f"Top3 Feats: diag_cos={diag_cos.mean():.3f}, "
                        f"margin={margin_mid.mean():.3f}, H={H_mid.mean():.3f}"
                    )
        stats["exit"] += exit_mask.sum()
        # ---------------------------------------------------
        # 3. Branch execution
        # ---------------------------------------------------
        exit_indices = np.where(exit_mask)[0]
        cont_indices = np.where(~exit_mask)[0]
        # A. Early-exit samples: retrieve with mid-layer representations
        if len(exit_indices) > 0:
            reps_exit = reps_mid_t[exit_indices]
            if use_local:
                for i, idx in enumerate(exit_indices):
                    cand_local = infos[idx].get("cand_names", [])
                    if not cand_local:
                        cids = []
                    else:
                        rows = [
                            cand_id2row.get(str(cid), -1)
                            for cid in cand_local
                        ]
                        rows = [r for r in rows if r >= 0]
                        if len(rows) == 0:
                            cids = []
                        else:
                            # CPU FP32: single query mid x local cand_mid
                            cmat_np = cand_mid_np[rows]  # [Nc_local, D]
                            qry_np = reps_exit[i].detach().float().cpu().numpy()  # [D]
                            scores_vec = np.dot(qry_np, cmat_np.T)  # [Nc_local]
                            top_k = min(200, len(rows))
                            order_local = np.argsort(-scores_vec)[:top_k]
                            cids = [str(cand_local[o]) for o in order_local]
                    label = (
                        infos[idx].get("label_name")
                        or infos[idx].get("label")
                        or infos[idx].get("rel_docids")
                    )
                    if not isinstance(label, list):
                        label = [label]
                    rel_scores = infos[idx].get("rel_scores", None)
                    global_idx = batch_start_idx + idx
                    results_dict[global_idx] = {
                        "prediction": cids,
                        "label": label,
                        "rel_scores": rel_scores,
                    }
            else:
                # CPU FP32 numpy: reps_exit_np x cand_mid_np.T
                reps_exit_np = reps_exit.detach().float().cpu().numpy()  # [B_exit, D]
                scores_full = np.dot(reps_exit_np, cand_mid_np.T)  # [B_exit, Nc]
                top_k = min(200, len(cand_ids))
                # Full sort, then truncate to top_k
                topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k]  # [B_exit, top_k]
                for i, idx in enumerate(exit_indices):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    label = (
                        infos[idx].get("label_name")
                        or infos[idx].get("label")
                        or infos[idx].get("rel_docids")
                    )
                    if not isinstance(label, list):
                        label = [label]
                    rel_scores = infos[idx].get("rel_scores", None)
                    global_idx = batch_start_idx + idx
                    results_dict[global_idx] = {
                        "prediction": cids,
                        "label": label,
                        "rel_scores": rel_scores,
                    }
        # B. Continued samples: resume the forward pass from mid to last layer
        if len(cont_indices) > 0:
            interm = getattr(out_mid, "intermediate_state", None)
            hs = interm["hidden_states"].detach()
            am = interm["attention_mask"].detach()
            pos = interm["position_ids"].detach()
            vm = interm.get("vision_mask", None)
            tm = interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])
            # Slice the saved state down to the continuing samples only
            resume_state_subset = {
                "hidden_states": hs[cont_indices],
                "attention_mask": am[cont_indices],
                "position_ids": pos[:, cont_indices, :],
                "vision_mask": vm[cont_indices] if vm is not None else None,
                "text_mask": tm[cont_indices] if tm is not None else None,
                "next_layer_idx": next_layer,
            }
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_resume["enabled"] = bool(_orig_enabled and side_enable)
                setattr(model.encoder, "aop_prune_config", aop_resume)
            # Timing: mid -> last
            if profile_enabled:
                torch.cuda.synchronize()
                t0_tail = time.perf_counter()
            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16, enabled=True
            ):
                out_last = model.encoder(
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    resume_state=resume_state_subset,
                    compute_lm_head=False,
                )
            if profile_enabled:
                torch.cuda.synchronize()
                t1_tail = time.perf_counter()
                timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(
                    cont_indices
                )
                timing_stats["tail_num"] += len(cont_indices)
            hs_last = out_last.last_hidden_state
            if hs_last is None:
                hs_last = out_last.hidden_states[-1]
            am_last = getattr(out_last, "attention_mask", None)
            if am_last is None:
                am_last = resume_state_subset["attention_mask"]
            reps_last_t = (
                model._pooling(hs_last, am_last)
                .detach()
                .to(dtype=torch.bfloat16)
            )
            if use_local:
                for i, idx_global in enumerate(cont_indices):
                    cand_local = infos[idx_global].get("cand_names", [])
                    if not cand_local:
                        cids = []
                    else:
                        rows = [
                            cand_id2row.get(str(cid), -1)
                            for cid in cand_local
                        ]
                        rows = [r for r in rows if r >= 0]
                        if len(rows) == 0:
                            cids = []
                        else:
                            # CPU FP32: single query last x local cand_last
                            cmat_last_np = cand_last_np[rows]  # [Nc_local, D]
                            qry_last_np = reps_last_t[i].detach().float().cpu().numpy()
                            scores_vec = np.dot(qry_last_np, cmat_last_np.T)  # [Nc_local]
                            top_k = min(200, len(rows))
                            order_local = np.argsort(-scores_vec)[:top_k]
                            cids = [str(cand_local[o]) for o in order_local]
                    label = (
                        infos[idx_global].get("label_name")
                        or infos[idx_global].get("label")
                        or infos[idx_global].get("rel_docids")
                    )
                    if not isinstance(label, list):
                        label = [label]
                    rel_scores = infos[idx_global].get("rel_scores", None)
                    global_idx = batch_start_idx + idx_global
                    results_dict[global_idx] = {
                        "prediction": cids,
                        "label": label,
                        "rel_scores": rel_scores,
                    }
            else:
                # CPU FP32: all continued query last reps x the full cand_last library
                reps_last_np = reps_last_t.detach().float().cpu().numpy()  # [B_cont, D]
                scores_last = np.dot(reps_last_np, cand_last_np.T)  # [B_cont, Nc]
                top_k = min(200, len(cand_ids))
                topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k]  # [B_cont, top_k]
                for i, idx_global in enumerate(cont_indices):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    label = (
                        infos[idx_global].get("label_name")
                        or infos[idx_global].get("label")
                        or infos[idx_global].get("rel_docids")
                    )
                    if not isinstance(label, list):
                        label = [label]
                    rel_scores = infos[idx_global].get("rel_scores", None)
                    global_idx = batch_start_idx + idx_global
                    results_dict[global_idx] = {
                        "prediction": cids,
                        "label": label,
                        "rel_scores": rel_scores,
                    }
        # ---------------------------------------------------
        # 4. Profiling: save per-query top-K embeddings & similarities
        # ---------------------------------------------------
        if profile_enabled and is_main:
            K = min(topk_emb, cand_mid_t.size(0))
            # Move to float32 + CPU for serialization
            q_mid_cpu = reps_mid_t.detach().float().cpu()  # [B, D]
            q_last_cpu = (
                reps_last_full.detach().float().cpu()
                if reps_last_full is not None
                else None
            )  # [B, D]
            cand_mid_cpu = cand_mid_t.detach().float().cpu()  # [Nc, D]
            cand_last_cpu = cand_last_t.detach().float().cpu()  # [Nc, D]
            # [Key fix] mid2mid similarities (aligned representation space)
            scores_mid_full = q_mid_cpu @ cand_mid_cpu.T  # [B, Nc]
            topk_mid_vals, topk_mid_inds = torch.topk(
                scores_mid_full, k=K, dim=1
            )
            # last2last similarities (when the full forward was run)
            if q_last_cpu is not None:
                scores_last_full = q_last_cpu @ cand_last_cpu.T  # [B, Nc]
                topk_last_vals, topk_last_inds = torch.topk(
                    scores_last_full, k=K, dim=1
                )
            else:
                topk_last_vals = None
                topk_last_inds = None
            for i in range(B):
                qid = batch_start_idx + i
                rec = {
                    "qid": int(qid),
                    "early_exit": bool(exit_mask[i]),
                }
                if p_need_last_batch is not None:
                    rec["p_need_last"] = float(p_need_last_batch[i].item())
                # [Key fix] mid2mid TopK (aligned representation space)
                mid_inds = topk_mid_inds[i].tolist()
                mid_scores = topk_mid_vals[i].tolist()
                rec["mid_topk_scores"] = mid_scores
                rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds]
                rec["mid_q_emb"] = q_mid_cpu[i].tolist()
                rec["mid_cand_embs"] = cand_mid_cpu[mid_inds].tolist()
                # last2last TopK (when available)
                if topk_last_inds is not None:
                    last_inds = topk_last_inds[i].tolist()
                    last_scores = topk_last_vals[i].tolist()
                    rec["last_topk_scores"] = last_scores
                    rec["last_topk_cand_ids"] = [
                        cand_ids[j] for j in last_inds
                    ]
                    rec["last_q_emb"] = (
                        q_last_cpu[i].tolist() if q_last_cpu is not None else None
                    )
                    rec["last_cand_embs"] = cand_last_cpu[last_inds].tolist()
                else:
                    rec["last_topk_scores"] = None
                    rec["last_topk_cand_ids"] = None
                    rec["last_q_emb"] = None
                    rec["last_cand_embs"] = None
                analysis_records.append(rec)
    # =====================================================
    # 5. Collect & save results
    # =====================================================
    for idx in sorted(results_dict.keys()):
        pred_dicts.append(results_dict[idx])
    print_master(
        f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} "
        f"({stats['exit']/stats['total']:.2%})"
    )
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        with open(
            os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"),
            "w",
        ) as f:
            json.dump(score, f, indent=4)
        # Save profiling information
        if profile_enabled:
            prof_dir = os.path.join(out_dir, "profiling")
            os.makedirs(prof_dir, exist_ok=True)
            # Timing stats: convert sums to per-sample means (sec/sample)
            mid_avg = (
                timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"])
            )
            tail_avg = (
                timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"])
            )
            timing_out = {
                "mid_time_sum": timing_stats["mid_time_sum"],
                "mid_num": timing_stats["mid_num"],
                "tail_time_sum": timing_stats["tail_time_sum"],
                "tail_num": timing_stats["tail_num"],
                "avg_mid_time_per_query_sec": mid_avg,
                "avg_tail_time_per_cont_query_sec": tail_avg,
                "num_exit": int(stats["exit"]),
                "num_total": int(stats["total"]),
            }
            with open(
                os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w"
            ) as f:
                json.dump(timing_out, f, indent=2)
            # Embedding + similarity records (JSONL)
            embed_path = os.path.join(
                prof_dir, f"{dataset_name}_embeds.jsonl"
            )
            with open(embed_path, "w") as f:
                for rec in analysis_records:
                    f.write(json.dumps(rec) + "\n")
            print_master(
                f"[PROFILE] Saved timing to {prof_dir}, "
                f"embeddings to {embed_path}"
            )
    elapsed = time.time() - start_time
    return score, elapsed
# ===========================
# Helper Functions (Pre-Computation)
# ===========================
def encode_candidates_both_layers(
    model: MMEBModel, loader: DataLoader, training_args: TrainingArguments,
    model_args: ModelArguments, full_dataset: Dataset, mid_layer: int,
):
    """Encode every candidate once, pooling both mid-layer and last-layer reps.

    AOP pruning is force-disabled on the candidate side so cached candidate
    embeddings are deterministic and comparable across runs.

    Returns:
        (mid_embs, last_embs, cand_ids) — two float32 numpy arrays of
        pooled embeddings and the matching list of candidate names
        (empty array/list when the loader yields nothing).
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid, all_last, all_ids = [], [], []
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Force-disable AOP for the candidate side (restored after the forward)
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            _orig = None
            if isinstance(aop_cfg, dict):
                _orig = aop_cfg.get("enabled", False)
                aop_cfg["enabled"] = False
                setattr(model.encoder, "aop_prune_config", aop_cfg)
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                # output_hidden_states=True so the mid layer can be read below
                out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None)
            if isinstance(aop_cfg, dict) and _orig is not None:
                aop_cfg["enabled"] = _orig  # Restore
            mid_hs = out.hidden_states[mid_layer]
            last_hs = out.hidden_states[-1]
            am = inputs.get("attention_mask", None)
            if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device)
            reps_mid = model._pooling(mid_hs, am)
            reps_last = model._pooling(last_hs, am)
            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])
    if not all_mid: return np.array([]), np.array([]), []
    return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids
# ===========================
# Main
# ===========================
def main():
    """Evaluate every dataset in the YAML config with early-exit retrieval.

    For each dataset: (1) candidate embeddings are cached on disk at both
    the mid layer and the last layer, (2) early-exit query encoding and
    ranking run on rank 0 only, and (3) per-dataset score files are
    written under ``data_args.encode_output_path``.
    """
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    ee_cfg = get_env_ee_config()
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
    # Rank 0 loads first so the other ranks hit a warm cache after the barrier.
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    if torch.distributed.is_initialized(): torch.distributed.barrier()
    if local_rank != 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    # Debug: inspect the model's scoring-related configuration
    print_master(f"[DEBUG] Model normalize={model.normalize}, pooling={model.pooling}, temperature={getattr(model, 'temperature', 'N/A')}")
    # Inject the AOP (pruning) configuration into the encoder
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        model.set_inference_layers(qry_layers=None, tgt_layers=None)
        print_master(f"[AOP] Enabled: Layer={aop_cfg['layer_idx']}, Ratio={aop_cfg['keep_ratio']}")
    # Load the early-exit classifier
    classifier = None
    if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
        classifier_path = ee_cfg['classifier_path']
        print_master(f"[EE] Loading Classifier from {classifier_path}...")
        # [Key] 27-dim input features — must match the training code exactly
        classifier = EarlyExitClassifier(input_dim=27, hidden_dim=64)
        state_dict = None
        # Try several checkpoint layouts in turn
        if os.path.isdir(classifier_path):
            # Option 1: safetensors format (Trainer default)
            safetensors_file = os.path.join(classifier_path, "model.safetensors")
            if os.path.exists(safetensors_file):
                from safetensors.torch import load_file
                state_dict = load_file(safetensors_file)
                print_master(f"[EE] Loaded from model.safetensors")
            else:
                # Option 2: pytorch_model.bin
                pt_file = os.path.join(classifier_path, "pytorch_model.bin")
                if os.path.exists(pt_file):
                    state_dict = torch.load(pt_file, map_location=training_args.device)
                    print_master(f"[EE] Loaded from pytorch_model.bin")
                else:
                    # Option 3: standalone per-layer .pt file
                    layer_idx = ee_cfg.get('layer', 12)
                    pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
                    if os.path.exists(pt_file):
                        state_dict = torch.load(pt_file, map_location=training_args.device)
                        print_master(f"[EE] Loaded from {os.path.basename(pt_file)}")
                    else:
                        raise FileNotFoundError(f"Cannot find classifier weights in {classifier_path}")
        elif os.path.isfile(classifier_path):
            # Option 4: a weights file was given directly
            if classifier_path.endswith('.safetensors'):
                from safetensors.torch import load_file
                state_dict = load_file(classifier_path)
                print_master(f"[EE] Loaded from .safetensors file")
            else:
                state_dict = torch.load(classifier_path, map_location=training_args.device)
                print_master(f"[EE] Loaded from .pt file")
        else:
            raise FileNotFoundError(f"Classifier path not found: {classifier_path}")
        classifier.load_state_dict(state_dict)
        classifier.to(training_args.device)
        classifier.eval()
        print_master(f"[EE] Classifier loaded successfully. Threshold={ee_cfg['threshold']}")
    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)
    for dataset_name, task_config in dataset_configs.items():
        if dist.is_initialized(): dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")
        # ===== Pick the dataset-specific threshold τ_ds =====
        base_tau = float(ee_cfg["threshold"])  # env var EE_THRESHOLD is the default
        ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau)
        # Copy ee_cfg for this dataset and override only the threshold
        ee_cfg_ds = dict(ee_cfg)
        ee_cfg_ds["threshold"] = float(ds_tau)
        print_master(
            f"[EE] Dataset '{dataset_name}' use threshold τ={ee_cfg_ds['threshold']:.6e} "
            f"(default={base_tau:.6e})"
        )
        if data_args.data_basedir:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_config.get(key):
                    task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(
            model_args=model_args, data_args=data_args, **task_config
        )
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
        mid_layer = int(ee_cfg_ds["layer"])
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast")
        force_regenerate = False  # set True to ignore cached candidate embeddings
        if force_regenerate or (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
            print_master(f"[INFO] Regenerating candidates with normalize={model.normalize}...")
            eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
            eval_cand_loader = DataLoader(
                full_eval_cand_dataset,
                batch_size=training_args.per_device_eval_batch_size,
                collate_fn=eval_cand_collator,
                num_workers=training_args.dataloader_num_workers,
            )
            cand_mid, cand_last, cand_ids = encode_candidates_both_layers(
                model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer
            )
            if local_rank == 0:
                with open(cand_mid_path, "wb") as f:
                    pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
                with open(cand_last_path, "wb") as f:
                    pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
        if dist.is_initialized():
            dist.barrier()
        if local_rank == 0:
            with open(cand_mid_path, "rb") as f:
                cand_mid_dict = pickle.load(f)
            with open(cand_last_path, "rb") as f:
                cand_last_dict = pickle.load(f)
            # eval_type from the task config selects global vs local ranking
            rank_global = task_config.get("eval_type", "global") == "global"
            print_master(f"[{dataset_name}] Eval type: {'global' if rank_global else 'local'} ranking")
            # Key: pass the per-dataset ee_cfg_ds, not the global ee_cfg
            run_early_exit_queries(
                model, classifier, processor, model_args, data_args, training_args,
                full_eval_qry_dataset, cand_mid_dict, cand_last_dict,
                ee_cfg_ds, dataset_name, data_args.encode_output_path,
                global_ranking=rank_global,
            )
        # BUGFIX: the end-of-iteration barrier must be executed by every rank.
        # Previously an extra barrier sat inside the `local_rank == 0` branch,
        # so rank 0 issued one more collective per dataset than the other
        # ranks, eventually leaving rank 0 blocked on an unmatched barrier.
        if dist.is_initialized(): dist.barrier()
# Script entry point: run the full multi-dataset evaluation.
if __name__ == '__main__':
    main()