| import datetime | |
| import logging | |
| import json | |
| import random | |
| import time | |
| import numpy as np | |
| import os | |
| import pickle | |
| import sys | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn.functional as F | |
| import yaml | |
| import transformers | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from transformers import HfArgumentParser, AutoConfig, AutoTokenizer | |
| from datasets import Dataset, concatenate_datasets | |
| from datasets.distributed import split_dataset_by_node | |
| from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl_train_tokrnpooling import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src | |
| from src.arguments import ModelArguments, DataArguments, TrainingArguments | |
| from src.data.collator.eval_collator import MultimodalEvalDataCollator | |
| from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset | |
| from src.eval_utils.metrics import RankingMetrics | |
| from src.model.model_cut_layer import MMEBModel | |
| from src.model.processor import get_backbone_name, load_processor, COLPALI | |
| from src.utils import batch_to_device, print_rank, print_master | |
| import math | |
| def get_env_mid_layer(): | |
| v = os.environ.get("MID_LM_LAYER", "").strip() | |
| if v == "" or v.lower() in {"none", "null"}: | |
| return None | |
try:
return int(v)
except ValueError:
logger.warning(f"Invalid MID_LM_LAYER={v}; ignoring.")
return None
| def get_env_eval_layers(): | |
| """ | |
Parse the LM_LAYERS env var (preferred) or fall back to the legacy MID_LM_LAYER.
- LM_LAYERS example: "4,8,12,last"; 'last'/'none'/'null'/'-1' all denote the final layer (None).
- If LM_LAYERS is unset, fall back to the old logic: MID_LM_LAYER=None -> [None]; otherwise [mid, None].
Returns: list[int | None], e.g. [4, 8, 12, None]; None means the final layer.
| """ | |
| v = os.environ.get("LM_LAYERS", None) | |
| if v is not None: | |
| v = v.strip() | |
| if v: | |
| toks = [t.strip() for t in v.split(',') if t.strip() != ""] | |
| layers = [] | |
| for tok in toks: | |
| tl = tok.lower() | |
| if tl in {"last", "none", "null", "-1"}: | |
| layers.append(None) | |
| else: | |
| try: | |
| val = int(tok) | |
| if val > 0: | |
| layers.append(val) | |
| else: | |
| logger.warning(f"Ignoring non-positive layer '{tok}' in LM_LAYERS.") | |
| except Exception: | |
| logger.warning(f"Invalid token '{tok}' in LM_LAYERS; must be int or 'last'/'none'.") | |
# Deduplicate while preserving order
| seen = set() | |
| uniq = [] | |
| for l in layers: | |
| key = -1 if l is None else l | |
| if key in seen: | |
| continue | |
| seen.add(key) | |
| uniq.append(l) | |
| if not uniq: | |
| return [None] | |
| return uniq | |
| else: | |
| # 兼容旧逻辑 | |
| mid = get_env_mid_layer() | |
| return [None] if mid is None else [mid, None] | |
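# A sketch of the parsing above: LM_LAYERS="4,8,12,last" -> [4, 8, 12, None];
# LM_LAYERS unset with MID_LM_LAYER=8 -> [8, None]; both unset -> [None].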
| def make_layer_tag(keep_layers: int | None): | |
| return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast" | |
| def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray: | |
| # a: [Nq, D], b: [Nc, D], both L2-normalized already if normalize=true | |
| return a @ b.T | |
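# Note: on L2-normalized rows this dot product equals cosine similarity, bounded in [-1, 1].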
| def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray): | |
| return { | |
| "qid": int(qid), | |
| "cand_scores": [ | |
| {"cand_id": str(cand_ids[i]), "score": float(score_vec[i])} | |
| for i in ranked_indices | |
| ] | |
| } | |
| def top1_top2_margin(score_vec: np.ndarray) -> float: | |
| if len(score_vec) < 2: | |
| return float("inf") # 只有一个候选时视作极大margin | |
| top2 = np.partition(score_vec, -2)[-2:] | |
| top2.sort() | |
| return float(top2[-1] - top2[-2]) | |
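# e.g. top1_top2_margin(np.array([0.9, 0.7, 0.2])) -> 0.2 (top-1 score minus top-2).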
| def simulate_early_exit_by_margin( | |
| sims_mid: list[dict], sims_last: list[dict], labels: list[list[str]], metrics_to_report: list[str], | |
| taus: list[float], rank_global: bool | |
| ): | |
| """ | |
sims_mid / sims_last: one dict per query: {cand_id: score}
labels: list of positive cand_ids per query
rank_global: accepted for interface parity; not used inside this simulation
Returns: coverage and metric scores for each tau
| """ | |
| assert len(sims_mid) == len(sims_last) == len(labels) | |
| N = len(labels) | |
| results = [] | |
| metrics = RankingMetrics(metrics_to_report) | |
# Pre-build the pred_dicts consumed by metrics.evaluate
| def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]: | |
| pred_dicts = [] | |
| for qid in range(N): | |
| sims_use = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid] | |
# Sort candidates by descending score
| ranked = sorted(sims_use.items(), key=lambda x: -x[1]) | |
| pred_dicts.append({ | |
| "prediction": [cid for cid, _ in ranked], | |
| "label": labels[qid], | |
| "rel_scores": None | |
| }) | |
| return pred_dicts | |
# Compute mid-layer top1-top2 margins
| margins = [] | |
| for qid in range(N): | |
# Margin between the two highest mid-layer scores
| if len(sims_mid[qid]) == 0: | |
| margins.append(0.0) | |
| continue | |
| scores = np.array(list(sims_mid[qid].values()), dtype=np.float32) | |
| margins.append(top1_top2_margin(scores)) | |
| margins = np.array(margins, dtype=np.float32) | |
| for tau in taus: | |
| use_mid_mask = (margins >= tau).tolist() | |
| pred_dicts = to_pred_dicts(use_mid_mask) | |
| score_dict = metrics.evaluate(pred_dicts) | |
coverage = float(np.mean(use_mid_mask))  # fraction of queries that exit early
| results.append({ | |
| "tau": tau, | |
| "coverage": coverage, | |
| **score_dict | |
| }) | |
| return results | |
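# Minimal usage sketch (hypothetical data): a query keeps its mid-layer ranking when its
# top1-top2 margin >= tau, and otherwise falls back to the last-layer ranking.
# results = simulate_early_exit_by_margin(
#     sims_mid=[{"a": 0.9, "b": 0.2}], sims_last=[{"a": 0.6, "b": 0.7}],
#     labels=[["a"]], metrics_to_report=["hit"], taus=[0.1, 0.9], rank_global=True)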
| def top1_top2_margin_from_array(score_vec: np.ndarray) -> float: | |
| if score_vec is None or len(score_vec) == 0: | |
| return 0.0 | |
| if len(score_vec) == 1: | |
| return float('inf') | |
# Take the two largest scores
| top2 = np.partition(score_vec, -2)[-2:] | |
| top2.sort() | |
| return float(top2[-1] - top2[-2]) | |
| logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') | |
| logger = logging.getLogger(__name__) | |
| # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) --- | |
| timing_info = {} | |
| token_info = { | |
| "vision_tokens": 0, | |
| "text_input_tokens": 0, # Refers to the original text token count | |
| "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0. | |
| "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text) | |
| } | |
| # --- Hook Functions Definition --- | |
| def timing_pre_hook(module, input): | |
| module_id = id(module) | |
| if module_id not in timing_info: | |
| timing_info[module_id] = [] | |
| timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__)) | |
| def timing_post_hook(module, input, output): | |
| module_id = id(module) | |
| if module_id not in timing_info: | |
| # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})") | |
| return | |
| timing_info[module_id].append((time.time(), 'post', module.__class__.__name__)) | |
| # Collect vision token count (only from Vision Transformer module's post hook) | |
| module_name = module.__class__.__name__ | |
| if "vision" in module_name.lower() and "transformer" in module_name.lower(): | |
| if isinstance(output, torch.Tensor): | |
| token_info["vision_tokens"] = output.shape[0] # For visual features, usually (batch_size, num_tokens, hidden_dim) | |
| elif hasattr(output, 'last_hidden_state'): | |
| token_info["vision_tokens"] = output.last_hidden_state.shape[1] | |
| def register_model_hooks(model): | |
| registered_modules = [] | |
| core_model = model | |
| # print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}") | |
| if hasattr(model, 'encoder') and model.encoder is not None: | |
| # print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}") | |
# Check against the Qwen2VLForConditionalGeneration imported from the 'src' path
| if isinstance(model.encoder, _Qwen2VLForConditionalGeneration_src): | |
| # print_master("Detected MMEBModel structure, registering hooks on model.encoder's sub-modules.") | |
| core_model = model.encoder | |
| else: | |
| print_master(f"WARNING: model.encoder is not an instance of _Qwen2VLForConditionalGeneration_src. Its type is {type(model.encoder)}. Hooks will be registered on top-level model if applicable.") | |
| else: | |
| print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.") | |
| # Vision module | |
| if hasattr(core_model, 'visual') and core_model.visual is not None: | |
| vision_module = core_model.visual | |
| vision_module.register_forward_pre_hook(timing_pre_hook) | |
| vision_module.register_forward_hook(timing_post_hook) | |
| registered_modules.append(vision_module) | |
| print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}") | |
| else: | |
| print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).") | |
| # Merger module (if inside visual) - it's part of the vision component | |
| if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None: | |
| merger_module = core_model.visual.merger | |
| merger_module.register_forward_pre_hook(timing_pre_hook) | |
| merger_module.register_forward_hook(timing_post_hook) | |
| registered_modules.append(merger_module) | |
| print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}") | |
| else: | |
| print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).") | |
| # Language model body | |
| if hasattr(core_model, 'model') and core_model.model is not None: | |
| llm_main_module = core_model.model | |
| llm_main_module.register_forward_pre_hook(timing_pre_hook) | |
| llm_main_module.register_forward_hook(timing_post_hook) | |
| registered_modules.append(llm_main_module) | |
| print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}") | |
| else: | |
| print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).") | |
| # LM Head | |
| if hasattr(core_model, 'lm_head') and core_model.lm_head is not None: | |
| lm_head_module = core_model.lm_head | |
| lm_head_module.register_forward_pre_hook(timing_pre_hook) | |
| lm_head_module.register_forward_hook(timing_post_hook) | |
| registered_modules.append(lm_head_module) | |
| print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}") | |
| else: | |
| print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).") | |
| if not registered_modules: | |
| print_master("Warning: No major modules found for hook registration. Check model architecture.") | |
| return registered_modules | |
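# After a forward pass, timing_info maps id(module) -> [(t, 'pre', name), (t, 'post', name), ...];
# pre/post pairs are matched downstream to derive per-module durations.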
| def pad_dataset_to_divisible(dataset, world_size): | |
| num_samples = len(dataset) | |
| if num_samples % world_size == 0: | |
| return dataset, num_samples | |
| num_to_add = world_size - (num_samples % world_size) | |
| padded_size = num_samples + num_to_add | |
| padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) | |
| padded_dataset = concatenate_datasets([dataset, padding_data]) | |
| return padded_dataset, padded_size | |
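# Example: 10 samples with world_size=4 appends two wrapped-around samples (indices 0 and 1)
# for 12 rows total; the duplicated rows are trimmed downstream by slicing to the original length.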
| def encode_embeddings( | |
| model: MMEBModel, | |
| loader: DataLoader, | |
| training_args: TrainingArguments, | |
| model_args: ModelArguments, | |
| full_dataset: Dataset, | |
| encode_side: str, | |
| description: str = "Encoding" | |
| ) -> tuple[np.ndarray, list, list, list]: # CHANGED: + list for img_token_masks | |
| """ | |
| Encodes embeddings for a given dataset using the model, handling both standard and | |
| late-interaction models in a DDP-safe manner. | |
| Returns: | |
| - embeddings: np.ndarray | |
| - infos_or_ids: list | |
| - batch_stats_list: list | |
| - img_token_masks: list[None | list[bool]] # NEW | |
| """ | |
| local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| world_size = dist.get_world_size() if dist.is_initialized() else 1 | |
| # Check if the model is a late-interaction type | |
| is_late_interaction = (model_args.model_backbone == COLPALI) | |
| local_embeds = [] | |
| local_gt_infos = [] | |
| local_max_len = 0 | |
| # --- New: List to store statistics for each batch --- | |
| batch_stats_list = [] | |
| # --- NEW: Collect image token masks locally --- | |
local_img_token_masks = []  # one entry per sample: None or [bool, ...]
| model.eval() | |
# Register hooks for the model once per encode_embeddings call.
# Note: the hook handles are not stored or removed, so repeated calls on the same model stack duplicate hooks.
registered_hooks = register_model_hooks(model)
| # --- Helpers to extract the masks and make them serializable --- | |
| def _search_key(obj, key: str): | |
| # Recursively search nested dict/list/tuple structures for the given key | |
| if isinstance(obj, dict): | |
| if key in obj: | |
| return obj[key] | |
| for v in obj.values(): | |
| r = _search_key(v, key) | |
| if r is not None: | |
| return r | |
| elif isinstance(obj, (list, tuple)): | |
| for v in obj: | |
| r = _search_key(v, key) | |
| if r is not None: | |
| return r | |
| return None | |
| def _to_serializable_mask_list(mask_list, batch_size: int): | |
| # Convert masks returned by the model (list/tensor/ndarray/None) into [None | list[bool]] * B | |
| if mask_list is None: | |
| return [None] * batch_size | |
| out = [] | |
| if isinstance(mask_list, (list, tuple)): | |
| for m in mask_list: | |
| if m is None: | |
| out.append(None) | |
| elif torch.is_tensor(m): | |
| out.append(m.detach().cpu().tolist()) | |
| elif isinstance(m, np.ndarray): | |
| out.append(m.tolist()) | |
| else: | |
| # already python list/bool | |
| out.append(m) | |
| elif torch.is_tensor(mask_list): | |
| # For a 2D tensor (B, L), tolist() directly yields list[list[bool/int]] | |
| out = mask_list.detach().cpu().tolist() | |
| elif isinstance(mask_list, np.ndarray): | |
| out = mask_list.tolist() | |
| else: | |
| # Unknown type: conservatively return None placeholders | |
| out = [None] * batch_size | |
| # Align the length to batch_size | |
| if isinstance(out, list): | |
| if len(out) < batch_size: | |
| out = out + [None] * (batch_size - len(out)) | |
| elif len(out) > batch_size: | |
| out = out[:batch_size] | |
| return out | |
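| # Illustrative behavior of the normalizer above (kept as comments since it is a | |
| # local helper; the values are made up): | |
| # | |
| #   mixed = [torch.tensor([True, False]), None, np.array([1, 0, 1])] | |
| #   _to_serializable_mask_list(mixed, 3)  # -> [[True, False], None, [1, 0, 1]] | |
| #   _to_serializable_mask_list(torch.zeros(2, 4, dtype=torch.bool), 3) | |
| #   # -> two rows of [False, False, False, False], padded with a trailing None to length 3 | |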
| with torch.no_grad(): | |
| for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0): | |
| # --- Reset statistics for each inference pass --- | |
| timing_info.clear() | |
| token_info["vision_tokens"] = 0 | |
| token_info["text_input_tokens"] = 0 | |
| token_info["text_output_tokens"] = 0 | |
| token_info["total_llm_input_tokens"] = 0 | |
| inputs = batch_to_device(inputs, training_args.device) | |
| current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1 | |
| with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): | |
| start_inference_time = time.time() | |
| if encode_side == "qry": | |
| output = model(qry=inputs) | |
| reps = output["qry_reps"].detach() | |
| local_gt_infos.extend(dataset_info) | |
| else: | |
| output = model(tgt=inputs) | |
| reps = output["tgt_reps"].detach() | |
| local_gt_infos.extend([info["cand_name"] for info in dataset_info]) | |
| end_inference_time = time.time() | |
| # --- Extract and store this batch's image_token_bool_masks --- | |
| # The MMEBModel output is expected to contain 'image_token_bool_masks', directly or nested | |
| img_masks_raw = None | |
| if isinstance(output, dict): | |
| img_masks_raw = _search_key(output, "image_token_bool_masks") | |
| # Optional: if the masks were attached as an attribute on MMEBModel, try reading them there | |
| if img_masks_raw is None and hasattr(model, "image_token_bool_masks"): | |
| img_masks_raw = getattr(model, "image_token_bool_masks") | |
| img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size) | |
| local_img_token_masks.extend(img_masks_serializable) | |
| # --- Update total LLM input tokens after the model call --- | |
| if 'input_ids' in inputs and inputs['input_ids'] is not None: | |
| token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1] | |
| token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"] | |
| token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"]) | |
| # --- Collect and Store Batch Statistics --- | |
| batch_inference_time = end_inference_time - start_inference_time | |
| current_batch_stats = { | |
| "batch_size": current_batch_size, | |
| "total_inference_time_seconds": batch_inference_time, | |
| "module_inference_times": {}, | |
| "token_counts": { | |
| "visual_tokens": token_info["vision_tokens"], | |
| "language_input_tokens_raw": token_info["text_input_tokens"], | |
| "llm_total_input_tokens": token_info["total_llm_input_tokens"], | |
| "language_output_tokens": token_info["text_output_tokens"], | |
| } | |
| } | |
| # Calculate and store module timings for the current batch | |
| for module_obj in registered_hooks: | |
| module_id = id(module_obj) | |
| module_name = module_obj.__class__.__name__ | |
| times = timing_info.get(module_id, []) | |
| durations = [] | |
| pre_times = {} | |
| for t, event_type, _ in times: | |
| if event_type == 'pre': | |
| pre_times[module_id] = t | |
| elif event_type == 'post' and module_id in pre_times: | |
| duration = t - pre_times.pop(module_id) | |
| durations.append(duration) | |
| if durations: | |
| current_batch_stats["module_inference_times"][module_name] = { | |
| "total": sum(durations), | |
| "count": len(durations), | |
| "avg": sum(durations) / len(durations) | |
| } | |
| else: | |
| current_batch_stats["module_inference_times"][module_name] = { | |
| "total": 0.0, | |
| "count": 0, | |
| "avg": 0.0 | |
| } | |
| batch_stats_list.append(current_batch_stats) | |
| # --- Debug prints (optional) --- | |
| print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---") | |
| print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds") | |
| print_rank("--- Module Inference Timing Statistics ---") | |
| for module_name, stats in current_batch_stats["module_inference_times"].items(): | |
| print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s") | |
| print_rank("--- Token Count Statistics ---") | |
| print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}") | |
| print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}") | |
| print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}") | |
| print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}") | |
| if is_late_interaction and reps.dim() == 3: | |
| local_max_len = max(local_max_len, reps.shape[1]) | |
| local_embeds.append(reps) | |
| if not local_embeds: | |
| # Handle cases where a rank gets no data | |
| return np.array([]), [], [], []  # empty placeholders for all four return values | |
| # === DDP Synchronization and Padding for Late-Interaction Models === | |
| if is_late_interaction: | |
| if dist.is_initialized(): | |
| # 1: global max length | |
| local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device) | |
| dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX) | |
| global_max_len = local_max_len_tensor.item() | |
| else: | |
| global_max_len = local_max_len | |
| # 2: pad to global max length | |
| padded_embeds = [] | |
| for reps_batch in local_embeds: | |
| if reps_batch.dim() == 3: | |
| B, L, H = reps_batch.shape | |
| padding_size = global_max_len - L | |
| padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0) | |
| padded_embeds.append(padded_batch) | |
| else: | |
| padded_embeds.append(reps_batch) | |
| embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous() | |
| else: | |
| embeds_tensor = torch.cat(local_embeds, dim=0).contiguous() | |
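| # Note on the F.pad call above: the pad tuple is consumed from the last dimension | |
| # backwards, so (0, 0, 0, padding_size) leaves the hidden dim untouched and | |
| # right-pads the token dimension. Minimal sketch: | |
| # | |
| #   x = torch.ones(2, 3, 4)                    # (B, L, H) | |
| #   y = F.pad(x, (0, 0, 0, 2), "constant", 0)  # -> shape (2, 5, 4) | |
| #   assert y.shape == (2, 5, 4) and y[:, 3:].abs().sum() == 0 | |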
| # === Gather embeddings and keys from all ranks === | |
| if dist.is_initialized() and full_dataset.num_rows >= world_size: | |
| print_master(f"Gathering {encode_side} embeddings across all ranks...") | |
| # tensor gather | |
| output_shape = list(embeds_tensor.shape) | |
| output_shape[0] = full_dataset.num_rows | |
| embeds_tensor = embeds_tensor.to(training_args.device) | |
| gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device) | |
| dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor) | |
| final_embeddings = gathered_embeds_tensor.cpu().float().numpy() | |
| # object gather for infos and stats | |
| gathered_gt_infos = [None for _ in range(world_size)] | |
| dist.all_gather_object(gathered_gt_infos, local_gt_infos) | |
| all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys] | |
| gathered_batch_stats = [None for _ in range(world_size)] | |
| dist.all_gather_object(gathered_batch_stats, batch_stats_list) | |
| all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats] | |
| # --- Gather masks --- | |
| gathered_masks = [None for _ in range(world_size)] | |
| dist.all_gather_object(gathered_masks, local_img_token_masks) | |
| all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list] | |
| else: | |
| all_gt_infos = local_gt_infos | |
| final_embeddings = embeds_tensor.cpu().float().numpy() | |
| all_batch_stats = batch_stats_list | |
| all_img_token_masks = local_img_token_masks | |
| return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks | |
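| # Why the dataset padding matters: dist.all_gather_into_tensor requires every rank | |
| # to contribute an identically shaped tensor, and the output's first dim must be | |
| # exactly world_size * local_rows. Hedged single-process sketch (assuming a backend | |
| # that supports this collective): | |
| # | |
| #   dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1) | |
| #   local = torch.arange(4.0).reshape(2, 2) | |
| #   out = torch.empty(2, 2) | |
| #   dist.all_gather_into_tensor(out, local)  # with world_size == 1, out == local | |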
| def main(): | |
| # ----------------------- Distributed init ----------------------- | |
| if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): | |
| dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) | |
| local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| world_size = dist.get_world_size() if dist.is_initialized() else 1 | |
| print_master("Distributed init debug info:") | |
| print_master(f"RANK: {os.environ.get('RANK')}") | |
| print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}") | |
| print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}") | |
| print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}") | |
| print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}") | |
| if dist.is_initialized(): | |
| print_rank(f"dist.get_rank(): {dist.get_rank()}") | |
| print_rank(f"dist.get_world_size(): {dist.get_world_size()}") | |
| # Compatibility shim for torchrun's --local-rank argument | |
| # (iterate over a copy: mutating sys.argv while iterating over it skips elements) | |
| for arg in list(sys.argv): | |
| if arg.startswith("--local-rank="): | |
| rank = arg.split("=")[1] | |
| sys.argv.remove(arg) | |
| sys.argv.append('--local_rank') | |
| sys.argv.append(rank) | |
| # ----------------------- Parse args ----------------------- | |
| parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) | |
| model_args, data_args, training_args = parser.parse_args_into_dataclasses() | |
| model_args: ModelArguments | |
| data_args: DataArguments | |
| training_args: TrainingArguments | |
| os.makedirs(data_args.encode_output_path, exist_ok=True) | |
| # Support evaluating multiple layers (LM_LAYERS takes precedence; MID_LM_LAYER is kept for compatibility) | |
| layers_to_eval = get_env_eval_layers() | |
| print_master(f"Eval layers (qry/tgt): {layers_to_eval}") | |
| # ----------------------- Model loading ----------------------- | |
| hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) | |
| if not getattr(model_args, "model_backbone", None): | |
| model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) | |
| setattr(model_args, 'model_backbone', model_backbone) | |
| setattr(training_args, 'model_backbone', model_backbone) | |
| print_master(f'Model Backbone: {model_args.model_backbone}') | |
| # Only rank 0 downloads; the other ranks wait and then load from the cache | |
| if local_rank == 0: | |
| processor = load_processor(model_args, data_args) | |
| model = MMEBModel.load(model_args, is_trainable=False, processor=processor) | |
| print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...") | |
| if torch.distributed.is_initialized(): | |
| torch.distributed.barrier() | |
| if local_rank != 0: | |
| print_rank(f"Loading the model from cache...") | |
| processor = load_processor(model_args, data_args) | |
| time.sleep(random.randint(2 * local_rank, 3 * local_rank)) | |
| model = MMEBModel.load(model_args, is_trainable=False, processor=processor) | |
| model.eval() | |
| model = model.to(training_args.device, dtype=torch.bfloat16) | |
| # Ensure no layers are cut for the "last layer" setting (avoids the class default of keeping only 20 layers) | |
| model.set_inference_layers(qry_layers=None, tgt_layers=None) | |
| with open(data_args.dataset_config, 'r') as yaml_file: | |
| dataset_configs = yaml.safe_load(yaml_file) | |
| # ----------------------- Main evaluation loop ----------------------- | |
| for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()): | |
| if dist.is_initialized(): | |
| dist.barrier() | |
| print_master(f"\n--- Evaluating {dataset_name} ---") | |
| # Prefix dataset paths with data_basedir when provided | |
| if data_args.data_basedir is not None: | |
| for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: | |
| if task_config.get(key): | |
| task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) | |
| # Build the datasets | |
| full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) | |
| full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) | |
| eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset | |
| if dist.is_initialized(): | |
| world_size = dist.get_world_size() | |
| padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size) | |
| padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size) | |
| eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size) | |
| eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size) | |
| else: | |
| padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset | |
| # Index of saved embedding paths | |
| saved_paths = {}  # {(side, tag): path} | |
| # --------- Encode and save separately for each layer setting (intermediate / last layer) --------- | |
| for keep_layers in layers_to_eval: | |
| tag = make_layer_tag(keep_layers) | |
| print_master(f"[{dataset_name}] Start encoding for tag={tag} (keep_layers={keep_layers})") | |
| # Configure how many layers the model keeps for this pass | |
| model.set_inference_layers(qry_layers=keep_layers, tgt_layers=keep_layers) | |
| # Output file paths | |
| query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_{tag}") | |
| cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{tag}") | |
| dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") | |
| query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats_{tag}.json") | |
| cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats_{tag}.json") | |
| qry_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_img_token_masks_{tag}.jsonl") | |
| cand_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_img_token_masks_{tag}.jsonl") | |
| saved_paths[("qry", tag)] = query_embed_path | |
| saved_paths[("tgt", tag)] = cand_embed_path | |
| do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path) | |
| do_cand = not os.path.exists(cand_embed_path) | |
| # Helpers for accumulating statistics across batches | |
| def init_total_stats(): | |
| return { | |
| "total_inference_time_seconds": 0.0, | |
| "module_inference_times": {}, # 动态模块名 -> {"total": float, "count": int} | |
| "token_counts": { | |
| "visual_tokens": 0, | |
| "language_input_tokens_raw": 0, | |
| "llm_total_input_tokens": 0, | |
| "language_output_tokens": 0, | |
| }, | |
| "data_point_count": 0 | |
| } | |
| def accumulate_stats(total_stats, batch_stats): | |
| batch_size = batch_stats["batch_size"] | |
| total_stats["total_inference_time_seconds"] += batch_stats["total_inference_time_seconds"] | |
| # Module timings | |
| for mname, mstats in batch_stats["module_inference_times"].items(): | |
| if mname not in total_stats["module_inference_times"]: | |
| total_stats["module_inference_times"][mname] = {"total": 0.0, "count": 0} | |
| total_stats["module_inference_times"][mname]["total"] += mstats.get("total", 0.0) | |
| total_stats["module_inference_times"][mname]["count"] += mstats.get("count", 0) | |
| # Token counts (per-batch counts are scaled by batch_size, then accumulated) | |
| total_stats["token_counts"]["visual_tokens"] += batch_stats["token_counts"]["visual_tokens"] * batch_size | |
| total_stats["token_counts"]["language_input_tokens_raw"] += batch_stats["token_counts"]["language_input_tokens_raw"] * batch_size | |
| total_stats["token_counts"]["llm_total_input_tokens"] += batch_stats["token_counts"]["llm_total_input_tokens"] * batch_size | |
| total_stats["token_counts"]["language_output_tokens"] += batch_stats["token_counts"]["language_output_tokens"] * batch_size | |
| total_stats["data_point_count"] += batch_size | |
| def finalize_and_save_stats(total_stats, out_path, task_name, encode_side): | |
| if local_rank != 0: | |
| return | |
| if total_stats["data_point_count"] <= 0: | |
| print_master(f"No data processed for {task_name} [{encode_side}], skip saving stats.") | |
| return | |
| final_stats = { | |
| "task_name": task_name, | |
| "encode_side": encode_side, | |
| "data_point_count": total_stats["data_point_count"], | |
| "inference_times": { | |
| "total_inference_time_seconds": total_stats["total_inference_time_seconds"], | |
| "avg_inference_time_per_item_seconds": total_stats["total_inference_time_seconds"] / max(1, total_stats["data_point_count"]), | |
| "module_average_times_per_call": {}, | |
| "module_total_times_seconds": {}, | |
| "module_calls_count": {}, | |
| }, | |
| "token_counts": { | |
| "total_visual_tokens": total_stats["token_counts"]["visual_tokens"], | |
| "avg_visual_tokens_per_item": total_stats["token_counts"]["visual_tokens"] / max(1, total_stats["data_point_count"]), | |
| "total_language_input_tokens_raw": total_stats["token_counts"]["language_input_tokens_raw"], | |
| "avg_language_input_tokens_raw_per_item": total_stats["token_counts"]["language_input_tokens_raw"] / max(1, total_stats["data_point_count"]), | |
| "total_llm_total_input_tokens": total_stats["token_counts"]["llm_total_input_tokens"], | |
| "avg_llm_total_input_tokens_per_item": total_stats["token_counts"]["llm_total_input_tokens"] / max(1, total_stats["data_point_count"]), | |
| "total_language_output_tokens": total_stats["token_counts"]["language_output_tokens"], | |
| "avg_language_output_tokens_per_item": total_stats["token_counts"]["language_output_tokens"] / max(1, total_stats["data_point_count"]), | |
| } | |
| } | |
| for mname, mstats in total_stats["module_inference_times"].items(): | |
| total = mstats.get("total", 0.0) | |
| count = mstats.get("count", 0) | |
| final_stats["inference_times"]["module_total_times_seconds"][mname] = total | |
| final_stats["inference_times"]["module_calls_count"][mname] = count | |
| final_stats["inference_times"]["module_average_times_per_call"][mname] = (total / count) if count > 0 else 0.0 | |
| with open(out_path, 'w', encoding='utf-8') as f: | |
| json.dump(final_stats, f, ensure_ascii=False, indent=4) | |
| print_master(f"[{task_name}] {encode_side} inference statistics saved to: {out_path}") | |
| # ------- Encode queries ------- | |
| if do_query: | |
| print_master(f"[{tag}] Encoding queries...") | |
| eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") | |
| eval_qry_loader = DataLoader( | |
| eval_qry_dataset, | |
| batch_size=training_args.per_device_eval_batch_size, | |
| collate_fn=eval_qry_collator, | |
| num_workers=training_args.dataloader_num_workers | |
| ) | |
| query_embeds, gt_infos, qry_batch_stats, qry_img_masks = encode_embeddings( | |
| model, eval_qry_loader, training_args, model_args, padded_qry_dataset, | |
| encode_side="qry", description=f"Queries[{tag}] for {dataset_name}" | |
| ) | |
| # Truncate back to the true (unpadded) length | |
| query_embeds = query_embeds[:len(full_eval_qry_dataset)] | |
| gt_infos = gt_infos[:len(full_eval_qry_dataset)] | |
| qry_img_masks = qry_img_masks[:len(full_eval_qry_dataset)] | |
| # Accumulate statistics and save them | |
| qry_total_stats = init_total_stats() | |
| for bs in qry_batch_stats: | |
| accumulate_stats(qry_total_stats, bs) | |
| if local_rank == 0: | |
| with open(query_embed_path, 'wb') as f: | |
| pickle.dump(query_embeds, f) | |
| # dataset_info only needs to be written once; write it on the first pass | |
| if not os.path.exists(dataset_info_path): | |
| with open(dataset_info_path, 'w') as f: | |
| for info in gt_infos: | |
| f.write(json.dumps(info) + '\n') | |
| # Save the masks | |
| with open(qry_img_masks_path, 'w', encoding='utf-8') as f: | |
| for i, m in enumerate(qry_img_masks): | |
| f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n") | |
| print_master(f"Saved query embeddings to {query_embed_path}") | |
| print_master(f"Saved query image token masks to {qry_img_masks_path}") | |
| finalize_and_save_stats(qry_total_stats, query_inference_stats_path, dataset_name, f"query[{tag}]") | |
| if dist.is_initialized(): | |
| dist.barrier() | |
| # ------- Encode candidates ------- | |
| if do_cand: | |
| print_master(f"[{tag}] Encoding candidates...") | |
| eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") | |
| eval_cand_loader = DataLoader( | |
| eval_cand_dataset, | |
| batch_size=training_args.per_device_eval_batch_size, | |
| collate_fn=eval_cand_collator, | |
| num_workers=training_args.dataloader_num_workers | |
| ) | |
| cand_embeds, all_cand_ids, cand_batch_stats, cand_img_masks = encode_embeddings( | |
| model, eval_cand_loader, training_args, model_args, padded_cand_dataset, | |
| encode_side="cand", description=f"Candidates[{tag}] for {dataset_name}" | |
| ) | |
| cand_embeds = cand_embeds[:len(full_eval_cand_dataset)] | |
| all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)] | |
| cand_img_masks = cand_img_masks[:len(full_eval_cand_dataset)] | |
| cand_total_stats = init_total_stats() | |
| for bs in cand_batch_stats: | |
| accumulate_stats(cand_total_stats, bs) | |
| if local_rank == 0: | |
| cand_embed_dict = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds)} | |
| with open(cand_embed_path, 'wb') as f: | |
| pickle.dump(cand_embed_dict, f) | |
| with open(cand_img_masks_path, 'w', encoding='utf-8') as f: | |
| for cid, m in zip(all_cand_ids, cand_img_masks): | |
| f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n") | |
| print_master(f"Saved candidate embeddings to {cand_embed_path}") | |
| print_master(f"Saved candidate image token masks to {cand_img_masks_path}") | |
| finalize_and_save_stats(cand_total_stats, cand_inference_stats_path, dataset_name, f"candidate[{tag}]") | |
| if dist.is_initialized(): | |
| dist.barrier() | |
| # --------- Scoring per layer + combined + early-exit curve --------- | |
| if local_rank == 0: | |
| dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") | |
| with open(dataset_info_path, 'r') as f: | |
| gt_infos = [json.loads(line) for line in f] | |
| rank_against_all_candidates = task_config.get("eval_type", "global") == "global" | |
| metrics_to_report = task_config.get("metrics", ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]) | |
| layer_tags = [make_layer_tag(l) for l in layers_to_eval] | |
| sims_by_layer = {} # tag -> list[ dict(cand_id->score) ] | |
| for tag in layer_tags: | |
| query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_{tag}") | |
| cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{tag}") | |
| with open(query_embed_path, 'rb') as f: | |
| qry_embeds = pickle.load(f) | |
| with open(cand_embed_path, 'rb') as f: | |
| cand_embed_dict = pickle.load(f) | |
| pred_dicts = [] | |
| score_detail_dicts = [] | |
| sims_for_exit = [] | |
| if rank_against_all_candidates: | |
| cand_keys = list(cand_embed_dict.keys()) | |
| cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) | |
| if isinstance(qry_embeds, np.ndarray) and qry_embeds.ndim == 3: | |
| # Late-interaction | |
| qry_embed_t = torch.from_numpy(qry_embeds) | |
| cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds] | |
| sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy() | |
| else: | |
| sim_matrix = np.dot(qry_embeds, cand_embeds.T) | |
| ranked_all = np.argsort(-sim_matrix, axis=1) | |
| for qid, gt_info in tqdm(enumerate(gt_infos), total=len(gt_infos), desc=f"[{tag}] scoring(all) {dataset_name}"): | |
| ranked_indices = ranked_all[qid] | |
| rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] | |
| rel_scores = gt_info.get("rel_scores") | |
| pred_dicts.append({ | |
| "prediction": [cand_keys[i] for i in ranked_indices], | |
| "label": rel_docids, | |
| "rel_scores": rel_scores, | |
| }) | |
| score_detail_dicts.append(build_score_details(qid, cand_keys, sim_matrix[qid], ranked_indices)) | |
| sims_for_exit.append({cand_keys[i]: float(sim_matrix[qid][i]) for i in range(len(cand_keys))}) | |
| else: | |
| # Non-global: score each query only against the candidate subset in gt_info["cand_names"] | |
| for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), total=len(gt_infos), desc=f"[{tag}] scoring(local) {dataset_name}"): | |
| cand_ids_local = gt_info["cand_names"] | |
| cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local]) | |
| if isinstance(qry_embeds, np.ndarray) and qry_embeds.ndim == 3: | |
| qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # [1, Lq, H] | |
| cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds] | |
| sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0] | |
| else: | |
| sim_vec = np.dot(qry_embed, cand_embeds.T) | |
| ranked_indices = np.argsort(-sim_vec) | |
| rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] | |
| rel_scores = gt_info.get("rel_scores") | |
| pred_dicts.append({ | |
| "prediction": [cand_ids_local[i] for i in ranked_indices], | |
| "label": rel_docids, | |
| "rel_scores": rel_scores, | |
| }) | |
| score_detail_dicts.append(build_score_details(qid, cand_ids_local, sim_vec, ranked_indices)) | |
| sims_for_exit.append({cid: float(s) for cid, s in zip(cand_ids_local, sim_vec.tolist())}) | |
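| # Minimal sketch of the dense (non-late-interaction) ranking used above: | |
| # | |
| #   q = np.array([[1.0, 0.0], [0.0, 1.0]])              # 2 query embeddings | |
| #   c = np.array([[0.9, 0.1], [0.2, 0.8], [1.0, 1.0]])  # 3 candidate embeddings | |
| #   sim = np.dot(q, c.T)                                 # (2, 3) similarity matrix | |
| #   order = np.argsort(-sim, axis=1)                     # descending rank per query | |
| #   # order[0] == [2, 0, 1]: candidate 2 scores highest for query 0 | |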
| # Save per-layer metrics and details | |
| layer_score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_{tag}.json") | |
| layer_pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred_{tag}.jsonl") | |
| layer_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details_{tag}.jsonl") | |
| metrics = RankingMetrics(metrics_to_report) | |
| score_dict = metrics.evaluate(pred_dicts) | |
| score_dict["num_pred"] = len(pred_dicts) | |
| score_dict["num_data"] = len(gt_infos) | |
| with open(layer_score_path, "w") as f: | |
| json.dump(score_dict, f, indent=4) | |
| with open(layer_pred_path, "w") as f: | |
| for pred in pred_dicts: | |
| f.write(json.dumps(pred) + '\n') | |
| with open(layer_detail_path, "w") as f: | |
| for detail in score_detail_dicts: | |
| f.write(json.dumps(detail) + "\n") | |
| print_master(f"[{dataset_name}] {tag} score: " + json.dumps({k: (f"{v:.4f}" if isinstance(v, (int, float)) else v) for k, v in score_dict.items()})) | |
| sims_by_layer[tag] = sims_for_exit | |
| # Combined comparison file + early-exit curve (only when an intermediate layer is present) | |
| if len(layer_tags) == 2 and "layerlast" in layer_tags: | |
| mid_tag = [t for t in layer_tags if t != "layerlast"][0] | |
| last_tag = "layerlast" | |
| # Combined details: each query row carries cand_scores, top1, and margin for both mid and last | |
| combined_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details_both_layers.jsonl") | |
| with open(combined_path, "w", encoding='utf-8') as f: | |
| for qid in range(len(gt_infos)): | |
| sims_mid = sims_by_layer[mid_tag][qid] | |
| sims_last = sims_by_layer[last_tag][qid] | |
| def top1_cid(sims: dict): | |
| return max(sims.items(), key=lambda x: x[1])[0] if sims else None | |
| def margin_of(sims: dict): | |
| vals = np.array(list(sims.values()), dtype=np.float32) | |
| return top1_top2_margin_from_array(vals) | |
| row = { | |
| "qid": int(qid), | |
| "label": gt_infos[qid]["label_name"] if isinstance(gt_infos[qid]["label_name"], list) else [gt_infos[qid]["label_name"]], | |
| "mid": { | |
| "top1": top1_cid(sims_mid), | |
| "margin": margin_of(sims_mid), | |
| "cand_scores": sims_mid | |
| }, | |
| "last": { | |
| "top1": top1_cid(sims_last), | |
| "margin": margin_of(sims_last), | |
| "cand_scores": sims_last | |
| } | |
| } | |
| f.write(json.dumps(row, ensure_ascii=False) + "\n") | |
| print_master(f"[{dataset_name}] combined details saved to {combined_path}") | |
| # Early-exit curve (sweep of margin thresholds) | |
| taus = [round(x, 3) for x in np.linspace(0.0, 0.6, 31).tolist()] | |
| labels = [ | |
| gi["label_name"] if isinstance(gi["label_name"], list) else [gi["label_name"]] | |
| for gi in gt_infos | |
| ] | |
| exit_curve = simulate_early_exit_by_margin( | |
| sims_by_layer[mid_tag], sims_by_layer[last_tag], labels, metrics_to_report, taus, rank_against_all_candidates | |
| ) | |
| curve_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_early_exit_curve_margin.json") | |
| with open(curve_path, "w") as f: | |
| json.dump(exit_curve, f, indent=4) | |
| print_master(f"[{dataset_name}] early-exit curve saved to {curve_path}") | |
| if dist.is_initialized(): | |
| dist.barrier() | |
| if __name__ == '__main__': | |
| main() |