# code_SAS_VLM2Vec/eval_test_time.py
# import datetime
# import logging
# import json
# import random
# import time
# import numpy as np
# import os
# import pickle
# import sys
# import torch
# import torch.distributed as dist
# import torch.nn.functional as F
# import yaml
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from transformers import HfArgumentParser, AutoConfig
# from datasets import Dataset, concatenate_datasets
# from datasets.distributed import split_dataset_by_node
# from src.arguments import ModelArguments, DataArguments, TrainingArguments
# from src.data.collator.eval_collator import MultimodalEvalDataCollator
# from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
# from src.eval_utils.metrics import RankingMetrics
# from src.model.model import MMEBModel
# from src.model.processor import get_backbone_name, load_processor, COLPALI
# from src.utils import batch_to_device, print_rank, print_master
# import multiprocessing
# from multiprocessing import Pool, cpu_count
# logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
# logger = logging.getLogger(__name__)
# ###############################################
# # Start timing
# def start_timer(name, timing_dict):
# timing_dict[name] = time.time()
# # End timing
# def end_timer(name, timing_dict):
# end_time = time.time()
# if name in timing_dict:
# timing_dict[name] = end_time - timing_dict[name]
# # # Place this before main(), or keep it in a standalone utils.py
# # def register_hooks(model, timing_dict):
# # # --- vision_encoder hook ---
# # def vision_forward_hook(module, input, output):
# # print_master(f"[vision_encoder] output shape: {output.shape}")
# # print_master(f"[vision_encoder] num_image_tokens: {output.shape[1]}")
# # start_timer("vision_encoder", timing_dict)
# # def vision_forward_post_hook(module, input, output):
# # end_timer("vision_encoder", timing_dict)
# # model.encoder.visual.register_forward_hook(vision_forward_hook)
# # model.encoder.visual.register_forward_hook(vision_forward_post_hook)
# # # --- merger hook ---
# # def merger_forward_hook(module, input, output):
# # print_master(f"[merger] before merger - input shape: {input[0].shape}")
# # print_master(f"[merger] before merger - num_tokens: {input[0].shape[1]}")
# # start_timer("merger", timing_dict)
# # def merger_forward_post_hook(module, input, output):
# # print_master(f"[merger] after merger - output shape: {output.shape}")
# # print_master(f"[merger] after merger - num_tokens: {output.shape[1]}")
# # end_timer("merger", timing_dict)
# # if hasattr(model.encoder.visual, 'merger'):
# # model.encoder.visual.merger.register_forward_hook(merger_forward_hook)
# # model.encoder.visual.merger.register_forward_hook(merger_forward_post_hook)
# # # --- decoder hook ---
# # def decoder_forward_hook(module, input, output):
# # # Updated here to receive both the input and output arguments
# # if isinstance(input, tuple) and len(input) > 0:
# # print_master(f"[llm_decoder] input shape: {input[0].shape}")
# # print_master(f"[llm_decoder] total_tokens (image+text): {input[0].shape[1]}")
# # start_timer("llm_decoder", timing_dict)
# # def decoder_forward_post_hook(module, input, output):
# # end_timer("llm_decoder", timing_dict)
# # model.encoder.model.register_forward_hook(decoder_forward_hook)
# # model.encoder.model.register_forward_hook(decoder_forward_post_hook)
# # # --- lm_head hook ---
# # def lm_head_forward_hook(module, input, output):
# # start_timer("lm_head", timing_dict)
# # def lm_head_forward_post_hook(module, input, output):
# # end_timer("lm_head", timing_dict)
# # model.encoder.lm_head.register_forward_hook(lm_head_forward_hook)
# # model.encoder.lm_head.register_forward_hook(lm_head_forward_post_hook)
# def register_timing_hooks(model, timing_dict):
# def make_hooks(name):
# def pre_hook(module, input):
# timing_dict[f"{name}_start"] = time.time()
# def forward_hook(module, input, output):
# elapsed = time.time() - timing_dict[f"{name}_start"]
# timing_dict[name] = elapsed  # record elapsed time
# print_master(f"[{name}] took {elapsed * 1000:.2f} ms")
# return pre_hook, forward_hook
# # vision encoder
# pre, post = make_hooks("vision_encoder")
# model.encoder.visual.register_forward_pre_hook(pre)
# model.encoder.visual.register_forward_hook(post)
# # merger
# if hasattr(model.encoder.visual, 'merger'):
# pre, post = make_hooks("merger")
# model.encoder.visual.merger.register_forward_pre_hook(pre)
# model.encoder.visual.merger.register_forward_hook(post)
# # decoder
# pre, post = make_hooks("llm_decoder")
# model.encoder.model.register_forward_pre_hook(pre)
# model.encoder.model.register_forward_hook(post)
# # lm_head
# pre, post = make_hooks("lm_head")
# model.encoder.lm_head.register_forward_pre_hook(pre)
# model.encoder.lm_head.register_forward_hook(post)
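# Minimal standalone sketch (not part of the original script) of the pre/forward
# hook pattern used above; the nn.Linear toy module and the helper name are
# assumptions used only for illustration.
# import time
# import torch
# import torch.nn as nn
#
# def make_timing_hooks(name, timing_dict):
#     def pre_hook(module, inputs):
#         timing_dict[f"{name}_start"] = time.time()
#     def post_hook(module, inputs, output):
#         timing_dict[name] = time.time() - timing_dict[f"{name}_start"]
#     return pre_hook, post_hook
#
# timings = {}
# layer = nn.Linear(16, 16)
# pre, post = make_timing_hooks("linear", timings)
# layer.register_forward_pre_hook(pre)
# layer.register_forward_hook(post)
# layer(torch.randn(4, 16))
# print(f"linear forward took {timings['linear'] * 1000:.2f} ms")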
# #####################################################
# def pad_dataset_to_divisible(dataset, world_size):
# num_samples = len(dataset)
# if num_samples % world_size == 0:
# return dataset, num_samples
# num_to_add = world_size - (num_samples % world_size)
# padded_size = num_samples + num_to_add
# padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
# padded_dataset = concatenate_datasets([dataset, padding_data])
# return padded_dataset, padded_size
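# Toy example (illustration only, relies on the helper above) of the padding
# logic: with 10 samples and world_size=4, two samples are repeated so every
# rank receives an equal shard.
# from datasets import Dataset
# toy = Dataset.from_dict({"id": list(range(10))})
# padded, padded_size = pad_dataset_to_divisible(toy, world_size=4)
# assert padded_size == 12 and len(padded) == 12
# # The two extra rows are copies of rows 0 and 1 (i % len(dataset) wraps around).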
# def encode_embeddings(
# model: MMEBModel,
# loader: DataLoader,
# training_args: TrainingArguments,
# model_args: ModelArguments,
# full_dataset: Dataset,
# encode_side: str,
# description: str = "Encoding",
# timing_dict: dict | None = None
# ) -> tuple[np.ndarray, list]:
# """
# Encodes embeddings for a given dataset using the model, handling both standard and
# late-interaction models in a DDP-safe manner.
# """
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# world_size = dist.get_world_size() if dist.is_initialized() else 1
# # Check if the model is a late-interaction type
# is_late_interaction = (model_args.model_backbone == COLPALI)
# local_embeds = []
# local_gt_infos = []
# local_max_len = 0
# model.eval()
# with torch.no_grad():
# for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
# inputs = batch_to_device(inputs, training_args.device)
# with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
# # Determine if encoding query or target based on available keys
# if encode_side == "qry":
# output = model(qry=inputs)
# reps = output["qry_reps"].detach()
# local_gt_infos.extend(dataset_info) # to retain all information per query
# else:
# output = model(tgt=inputs)
# reps = output["tgt_reps"].detach()
# local_gt_infos.extend([info["cand_name"] for info in dataset_info]) # to retain ground-truth labels
# if is_late_interaction and reps.dim() == 3:
# local_max_len = max(local_max_len, reps.shape[1])
# local_embeds.append(reps)
# if not local_embeds:
# # Handle cases where a rank gets no data
# return np.array([]), []
# # === DDP Synchronization and Padding for Late-Interaction Models ===
# if is_late_interaction:
# if dist.is_initialized():
# # 1. Find the global maximum sequence length across all ranks
# local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
# dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
# global_max_len = local_max_len_tensor.item()
# else:
# global_max_len = local_max_len
# # 2. Pad all local embeddings to the global max length
# padded_embeds = []
# for reps_batch in local_embeds:
# if reps_batch.dim() == 3:
# B, L, H = reps_batch.shape
# padding_size = global_max_len - L
# padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
# padded_embeds.append(padded_batch)
# else: # Should not happen if model is consistently late-interaction
# padded_embeds.append(reps_batch)
# embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
# else: # Standard dense models
# embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
# # === Gather embeddings and keys from all ranks ===
# if dist.is_initialized() and full_dataset.num_rows >= world_size:
# print_master(f"Gathering {encode_side} embeddings across all ranks...")
# # Use the more efficient all_gather_into_tensor for tensors
# output_shape = list(embeds_tensor.shape)
# output_shape[0] = full_dataset.num_rows
# embeds_tensor = embeds_tensor.to(training_args.device)
# gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
# dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
# final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
# # Gather metadata, for which all_gather_object is appropriate
# gathered_gt_infos = [None for _ in range(world_size)]
# dist.all_gather_object(gathered_gt_infos, local_gt_infos)
# all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
# else:
# all_gt_infos = local_gt_infos
# final_embeddings = embeds_tensor.cpu().float().numpy()
# print_master(f"Timing results for {description}:")
# for k, v in timing_dict.items():
# if not k.startswith('_'):
# print_master(f" {k}: {v:.4f} sec")
# return final_embeddings, all_gt_infos
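# Sketch (illustration only, toy tensors) of the late-interaction padding step
# performed in encode_embeddings above: batches with different token counts are
# right-padded along the token dimension to a shared max length so they can be
# concatenated and gathered with all_gather_into_tensor.
# import torch
# import torch.nn.functional as F
# batches = [torch.randn(2, 5, 8), torch.randn(3, 9, 8)]    # [B, L, H] with varying L
# global_max_len = max(b.shape[1] for b in batches)          # 9 (all_reduce MAX across ranks under DDP)
# padded = [F.pad(b, (0, 0, 0, global_max_len - b.shape[1])) for b in batches]
# stacked = torch.cat(padded, dim=0)
# assert stacked.shape == (5, 9, 8)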
# def main():
# if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
# dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# world_size = dist.get_world_size() if dist.is_initialized() else 1
# # DEBUG PRINTS for Distributed Setup
# print_master("Distributed init debug info:")
# print_master(f"RANK: {os.environ.get('RANK')}")
# print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
# print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
# print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
# print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
# if dist.is_initialized():
# print_rank(f"dist.get_rank(): {dist.get_rank()}")
# print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
# for arg in sys.argv:
# if arg.startswith("--local-rank="):
# rank = arg.split("=")[1]
# sys.argv.remove(arg)
# sys.argv.append('--local_rank')
# sys.argv.append(rank)
# parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# model_args: ModelArguments
# data_args: DataArguments
# training_args: TrainingArguments
# os.makedirs(data_args.encode_output_path, exist_ok=True)
# # --- Model Loading ---
# hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
# if not getattr(model_args, "model_backbone", None):
# model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
# setattr(model_args, 'model_backbone', model_backbone)
# setattr(training_args, 'model_backbone', model_backbone)
# print_master(f'Model Backbone: {model_args.model_backbone}')
# # --- DDP-Safe Model Loading ---
# # Step 1: Only the master process (rank 0) downloads the model.
# if local_rank == 0:
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
# # Step 2: All processes wait here. The non-master processes will pause
# # until the master process (rank 0) finishes downloading and exits this barrier.
# if torch.distributed.is_initialized():
# torch.distributed.barrier()
# # Step 3: Now that the model is cached, the non-master processes load it from the local cache.
# if local_rank != 0:
# print_rank(f"Loading the model from cache...")
# processor = load_processor(model_args, data_args)
# time.sleep(random.randint(2 * local_rank, 3 * local_rank))
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# model.eval()
# model = model.to(training_args.device, dtype=torch.bfloat16)
# with open(data_args.dataset_config, 'r') as yaml_file:
# dataset_configs = yaml.safe_load(yaml_file)
# #############################################################################
# timing_dict = {}
# register_timing_hooks(model, timing_dict)  # register hooks to start per-module timing
# ##############################################################################
# # --- Main Evaluation Loop ---
# for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
# # 0. load dataset
# if dist.is_initialized():
# dist.barrier()
# print_master(f"--- Evaluating {dataset_name} ---")
# query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry")
# cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt")
# dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
# do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
# do_cand = not os.path.exists(cand_embed_path)
# if do_query or do_cand:
# if data_args.data_basedir is not None:
# # Construct full paths for data files if --data_basedir is provided
# for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
# if data_args.data_basedir and task_config.get(key):
# task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
# full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
# full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
# eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# # Pad datasets to be divisible by world_size before splitting
# if dist.is_initialized():
# padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
# padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
# eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
# eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
# else:
# padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# # --- 1. Compute Query Embeddings ---
# if do_query:
# print_master("Encoding queries...")
# eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
# eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers)
# query_embeds, gt_infos = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}", timing_dict=timing_dict)
# query_embeds = query_embeds[:len(full_eval_qry_dataset)] # world_size>1, trim the padded data points
# gt_infos = gt_infos[:len(full_eval_qry_dataset)]
# if local_rank == 0:
# with open(query_embed_path, 'wb') as f:
# pickle.dump(query_embeds, f)
# with open(dataset_info_path, 'w') as f:
# for info in gt_infos:
# f.write(json.dumps(info) + '\n')
# print_master(f"Saved query embeddings to {query_embed_path}")
# if dist.is_initialized():
# dist.barrier()
# # --- 2. Compute Candidate Embeddings ---
# if do_cand:
# print_master("Encoding candidates...")
# eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
# eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# cand_embeds, all_cand_ids = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}", timing_dict=timing_dict)
# cand_embeds = cand_embeds[:len(full_eval_cand_dataset)] # world_size>1, trim the padded data points
# all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)]
# if local_rank == 0:
# cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)}
# with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f)
# print_master(f"Saved candidate embeddings to {cand_embed_path}")
# if dist.is_initialized():
# dist.barrier()
# # --- 3. Compute Scores (on master rank only) ---
# if local_rank == 0:
# score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
# if os.path.exists(score_path):
# try:
# with open(score_path, "r") as f:
# score_dict = json.load(f)
# print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}")
# formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
# print_master(formatted)
# continue
# except Exception as e:
# print_master(f"Failed to load score for {dataset_name}, skipping {dataset_name}")
# with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f)
# with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f)
# gt_infos = [json.loads(l) for l in open(dataset_info_path)]
# pred_dicts = []
# rank_against_all_candidates = task_config.get("eval_type", "global") == "global"
# if rank_against_all_candidates:
# cand_keys = list(cand_embed_dict.keys())
# cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
# # Handle late-interaction scoring
# if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
# qry_embed = torch.from_numpy(qry_embeds)
# cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function
# ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()
# else: # Dense
# cosine_scores = np.dot(qry_embeds, cand_embeds.T)
# ranked_candids = np.argsort(-cosine_scores, axis=1)
# for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# assert rel_scores is None or len(rel_docids) == len(rel_scores)
# pred_dicts.append({
# "prediction": [cand_keys[i] for i in ranked_candid],
# "label": rel_docids,
# "rel_scores": rel_scores,
# })
# else:
# for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]])
# if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
# qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0)
# cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function
# ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0]
# else:
# cosine_score = np.dot(qry_embed, cand_embeds.T)
# ranked_candids = np.argsort(-cosine_score)
# rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# assert rel_scores is None or len(rel_docids) == len(rel_scores)
# pred_dicts.append({
# "prediction": [gt_info["cand_names"][i] for i in ranked_candids],
# "label": rel_docids,
# "rel_scores": rel_scores,
# })
# score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
# pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl")
# metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
# metrics = RankingMetrics(metrics_to_report)
# score_dict = metrics.evaluate(pred_dicts)
# formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
# score_dict["num_pred"] = len(pred_dicts)
# score_dict["num_data"] = len(gt_infos)
# print_master(f"Score of {dataset_name}:")
# print_master(formatted)
# print_master(f"Outputting final score to: {score_path}")
# with open(score_path, "w") as f:
# json.dump(score_dict, f, indent=4)
# with open(pred_path, "w") as f:
# for pred in pred_dicts:
# f.write(json.dumps(pred) + '\n')
# if __name__ == "__main__":
# main()
###################################################################################################
# Directly print each module's inference time and token counts
# import datetime
# import logging
# import json
# import random
# import time
# import numpy as np
# import os
# import pickle
# import sys
# import torch
# import torch.distributed as dist
# import torch.nn.functional as F
# import yaml
# import transformers
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from transformers import HfArgumentParser, AutoConfig, AutoTokenizer#, Qwen2VLForConditionalGeneration
# from datasets import Dataset, concatenate_datasets
# from datasets.distributed import split_dataset_by_node
# from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src
# from src.arguments import ModelArguments, DataArguments, TrainingArguments
# from src.data.collator.eval_collator import MultimodalEvalDataCollator
# from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
# from src.eval_utils.metrics import RankingMetrics
# from src.model.model import MMEBModel
# from src.model.processor import get_backbone_name, load_processor, COLPALI
# from src.utils import batch_to_device, print_rank, print_master
# import multiprocessing
# from multiprocessing import Pool, cpu_count
# logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
# logger = logging.getLogger(__name__)
# # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
# timing_info = {}
# token_info = {
# "vision_tokens": 0,
# "text_input_tokens": 0, # Refers to the original text token count
# "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0.
# "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text)
# }
# # --- Hook Functions Definition ---
# def timing_pre_hook(module, input):
# module_id = id(module)
# if module_id not in timing_info:
# timing_info[module_id] = []
# timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))
# def timing_post_hook(module, input, output):
# module_id = id(module)
# if module_id not in timing_info:
# print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
# return
# timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))
# # Collect vision token count (only from Vision Transformer module's post hook)
# module_name = module.__class__.__name__
# if "vision" in module_name.lower() and "transformer" in module_name.lower():
# if isinstance(output, torch.Tensor):
# token_info["vision_tokens"] = output.shape[0]
# elif hasattr(output, 'last_hidden_state'):
# token_info["vision_tokens"] = output.last_hidden_state.shape[1]
# def register_model_hooks(model):
# registered_modules = []
# core_model = model
# print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}")
# if hasattr(model, 'encoder') and model.encoder is not None:
# print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}")
# # Check against the Qwen2VLForConditionalGeneration imported from the 'src' path
# if isinstance(model.encoder, _Qwen2VLForConditionalGeneration_src):
# print_master("Detected MMEBModel structure, registering hooks on model.encoder's sub-modules.")
# core_model = model.encoder
# else:
# print_master(f"WARNING: model.encoder is not an instance of _Qwen2VLForConditionalGeneration_src. Its type is {type(model.encoder)}. Hooks will be registered on top-level model if applicable.")
# else:
# print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.")
# # Vision module
# if hasattr(core_model, 'visual') and core_model.visual is not None:
# vision_module = core_model.visual
# vision_module.register_forward_pre_hook(timing_pre_hook)
# vision_module.register_forward_hook(timing_post_hook)
# registered_modules.append(vision_module)
# print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
# # Merger module (if inside visual) - it's part of the vision component
# if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
# merger_module = core_model.visual.merger
# merger_module.register_forward_pre_hook(timing_pre_hook)
# merger_module.register_forward_hook(timing_post_hook)
# registered_modules.append(merger_module)
# print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
# # Language model body
# if hasattr(core_model, 'model') and core_model.model is not None:
# llm_main_module = core_model.model
# llm_main_module.register_forward_pre_hook(timing_pre_hook)
# llm_main_module.register_forward_hook(timing_post_hook)
# registered_modules.append(llm_main_module)
# print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
# # LM Head
# if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
# lm_head_module = core_model.lm_head
# lm_head_module.register_forward_pre_hook(timing_pre_hook)
# lm_head_module.register_forward_hook(timing_post_hook)
# registered_modules.append(lm_head_module)
# print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
# if not registered_modules:
# print_master("Warning: No major modules found for hook registration. Check model architecture.")
# return registered_modules
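# Small self-contained sketch (illustration only, keyed by module name rather
# than id(module) for brevity) of how the (timestamp, event) lists recorded by
# timing_pre_hook/timing_post_hook are paired into durations, mirroring the
# loop used later in encode_embeddings.
# events = [(100.00, 'pre', 'VisionTransformer'), (100.25, 'post', 'VisionTransformer'),
#           (100.30, 'pre', 'VisionTransformer'), (100.70, 'post', 'VisionTransformer')]
# durations, pending = [], {}
# for t, event_type, name in events:
#     if event_type == 'pre':
#         pending[name] = t
#     elif event_type == 'post' and name in pending:
#         durations.append(t - pending.pop(name))
# # durations ≈ [0.25, 0.40]; their sum and average are what the script reports per batch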
# def pad_dataset_to_divisible(dataset, world_size):
# num_samples = len(dataset)
# if num_samples % world_size == 0:
# return dataset, num_samples
# num_to_add = world_size - (num_samples % world_size)
# padded_size = num_samples + num_to_add
# padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
# padded_dataset = concatenate_datasets([dataset, padding_data])
# return padded_dataset, padded_size
# def encode_embeddings(
# model: MMEBModel,
# loader: DataLoader,
# training_args: TrainingArguments,
# model_args: ModelArguments,
# full_dataset: Dataset,
# encode_side: str,
# description: str = "Encoding"
# ) -> tuple[np.ndarray, list]:
# """
# Encodes embeddings for a given dataset using the model, handling both standard and
# late-interaction models in a DDP-safe manner.
# """
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# world_size = dist.get_world_size() if dist.is_initialized() else 1
# # Check if the model is a late-interaction type
# is_late_interaction = (model_args.model_backbone == COLPALI)
# local_embeds = []
# local_gt_infos = []
# local_max_len = 0
# model.eval()
# # Register hooks for the model once per encode_embeddings call
# # This assumes `model` is the MMEBModel instance that wraps the actual HuggingFace model
# # You might need to adjust this if MMEBModel internally manages multiple sub-models
# registered_hooks = register_model_hooks(model)  # keep module handles so per-module timings can be looked up by id(module)
# # Initialize a tokenizer for text token counting (needs to be from the same model path)
# temp_tokenizer = AutoTokenizer.from_pretrained(model_args.model_name)
# with torch.no_grad():
# for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
# # --- Reset statistics for each inference pass ---
# timing_info.clear()
# token_info["vision_tokens"] = 0
# token_info["text_input_tokens"] = 0
# token_info["text_output_tokens"] = 0 # Encoding doesn't generate text output tokens
# token_info["total_llm_input_tokens"] = 0
# inputs = batch_to_device(inputs, training_args.device)
# with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
# # Determine if encoding query or target based on available keys
# # This is where the forward pass happens, triggering hooks
# start_inference_time = time.time()
# if encode_side == "qry":
# output = model(qry=inputs)
# reps = output["qry_reps"].detach()
# local_gt_infos.extend(dataset_info) # to retain all information per query
# else:
# output = model(tgt=inputs)
# reps = output["tgt_reps"].detach()
# local_gt_infos.extend([info["cand_name"] for info in dataset_info]) # to retain ground-truth labels
# end_inference_time = time.time()
# # --- Update total LLM input tokens after the model call ---
# # This requires knowing which part of `inputs` corresponds to the LLM's full input.
# # Assuming `inputs.input_ids` directly goes into the LLM part of Qwen2-VL.
# if 'input_ids' in inputs and inputs['input_ids'] is not None:
# token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
# # Approximation for text_input_tokens (if not explicitly available from collator)
# # This assumes visual tokens are a prefix and the rest are text/special tokens.
# token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
# # Ensure it's not negative
# token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])
# # --- Print Inference Timing and Token Statistics per Batch ---
# print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
# print_rank(f"Batch Inference took: {end_inference_time - start_inference_time:.4f} seconds")
# # Calculate and print module timings
# print_rank("--- Module Inference Timing Statistics ---")
# for module_obj in registered_hooks:
# module_id = id(module_obj)
# module_name = module_obj.__class__.__name__
# times = timing_info.get(module_id, [])
# durations = []
# pre_times = {} # Store start times for each pre-hook
# for t, event_type, _ in times:
# if event_type == 'pre':
# pre_times[module_id] = t
# elif event_type == 'post' and module_id in pre_times:
# duration = t - pre_times.pop(module_id)
# durations.append(duration)
# if durations:
# print_rank(f"**{module_name}**: Total: {sum(durations):.6f}s, Count: {len(durations)}, Avg: {sum(durations)/len(durations):.6f}s")
# else:
# print_rank(f"**{module_name}**: No complete timing data found for this batch.")
# print_rank("--- Token Count Statistics ---")
# print_rank(f"**视觉 token 数量**: {token_info['vision_tokens']}")
# print_rank(f"**语言输入 token 数量 (仅原始文本)**: {token_info['text_input_tokens']}")
# print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {token_info['total_llm_input_tokens']}")
# print_rank(f"**语言输出 token 数量**: {token_info['text_output_tokens']}") # Will be 0 for encoding
# if is_late_interaction and reps.dim() == 3:
# local_max_len = max(local_max_len, reps.shape[1])
# local_embeds.append(reps)
# if not local_embeds:
# # Handle cases where a rank gets no data
# return np.array([]), []
# # === DDP Synchronization and Padding for Late-Interaction Models ===
# if is_late_interaction:
# if dist.is_initialized():
# # 1. Find the global maximum sequence length across all ranks
# local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
# dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
# global_max_len = local_max_len_tensor.item()
# else:
# global_max_len = local_max_len
# # 2. Pad all local embeddings to the global max length
# padded_embeds = []
# for reps_batch in local_embeds:
# if reps_batch.dim() == 3:
# B, L, H = reps_batch.shape
# padding_size = global_max_len - L
# padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
# padded_embeds.append(padded_batch)
# else: # Should not happen if model is consistently late-interaction
# padded_embeds.append(reps_batch)
# embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
# else: # Standard dense models
# embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
# # === Gather embeddings and keys from all ranks ===
# if dist.is_initialized() and full_dataset.num_rows >= world_size:
# print_master(f"Gathering {encode_side} embeddings across all ranks...")
# # Use the more efficient all_gather_into_tensor for tensors
# output_shape = list(embeds_tensor.shape)
# output_shape[0] = full_dataset.num_rows
# embeds_tensor = embeds_tensor.to(training_args.device)
# gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
# dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
# final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
# # Gather metadata, for which all_gather_object is appropriate
# gathered_gt_infos = [None for _ in range(world_size)]
# dist.all_gather_object(gathered_gt_infos, local_gt_infos)
# all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
# else:
# all_gt_infos = local_gt_infos
# final_embeddings = embeds_tensor.cpu().float().numpy()
# return final_embeddings, all_gt_infos
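# Toy illustration (hypothetical numbers) of the token-accounting approximation
# used above: the LLM's total input length is the padded input_ids sequence
# length, and the raw-text token count is estimated by subtracting the visual
# tokens, clamped at zero.
# total_llm_input_tokens = 1385   # input_ids.shape[1] (hypothetical value)
# vision_tokens = 1296            # reported by the vision-tower post hook (hypothetical value)
# text_input_tokens = max(0, total_llm_input_tokens - vision_tokens)   # 89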
# def main():
# if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
# dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# world_size = dist.get_world_size() if dist.is_initialized() else 1
# # DEBUG PRINTS for Distributed Setup
# print_master("Distributed init debug info:")
# print_master(f"RANK: {os.environ.get('RANK')}")
# print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
# print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
# print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
# print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
# if dist.is_initialized():
# print_rank(f"dist.get_rank(): {dist.get_rank()}")
# print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
# for arg in sys.argv:
# if arg.startswith("--local-rank="):
# rank = arg.split("=")[1]
# sys.argv.remove(arg)
# sys.argv.append('--local_rank')
# sys.argv.append(rank)
# parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# model_args: ModelArguments
# data_args: DataArguments
# training_args: TrainingArguments
# os.makedirs(data_args.encode_output_path, exist_ok=True)
# # --- Model Loading ---
# hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
# if not getattr(model_args, "model_backbone", None):
# model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
# setattr(model_args, 'model_backbone', model_backbone)
# setattr(training_args, 'model_backbone', model_backbone)
# print_master(f'Model Backbone: {model_args.model_backbone}')
# # --- DDP-Safe Model Loading ---
# # Step 1: Only the master process (rank 0) downloads the model.
# if local_rank == 0:
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
# # Step 2: All processes wait here. The non-master processes will pause
# # until the master process (rank 0) finishes downloading and exits this barrier.
# if torch.distributed.is_initialized():
# torch.distributed.barrier()
# # Step 3: Now that the model is cached, the non-master processes load it from the local cache.
# if local_rank != 0:
# print_rank(f"Loading the model from cache...")
# processor = load_processor(model_args, data_args)
# time.sleep(random.randint(2 * local_rank, 3 * local_rank))
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# model.eval()
# model = model.to(training_args.device, dtype=torch.bfloat16)
# with open(data_args.dataset_config, 'r') as yaml_file:
# dataset_configs = yaml.safe_load(yaml_file)
# # --- Main Evaluation Loop ---
# for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
# # 0. load dataset
# if dist.is_initialized():
# dist.barrier()
# print_master(f"--- Evaluating {dataset_name} ---")
# query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry")
# cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt")
# dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
# do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
# do_cand = not os.path.exists(cand_embed_path)
# if do_query or do_cand:
# if data_args.data_basedir is not None:
# # Construct full paths for data files if --data_basedir is provided
# for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
# if data_args.data_basedir and task_config.get(key):
# task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
# full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
# full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
# eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# # Pad datasets to be divisible by world_size before splitting
# if dist.is_initialized():
# padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
# padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
# eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
# eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
# else:
# padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# # --- 1. Compute Query Embeddings ---
# if do_query:
# print_master("Encoding queries...")
# eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
# eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers)
# query_embeds, gt_infos = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}")
# query_embeds = query_embeds[:len(full_eval_qry_dataset)] # world_size>1, trim the padded data points
# gt_infos = gt_infos[:len(full_eval_qry_dataset)]
# if local_rank == 0:
# with open(query_embed_path, 'wb') as f:
# pickle.dump(query_embeds, f)
# with open(dataset_info_path, 'w') as f:
# for info in gt_infos:
# f.write(json.dumps(info) + '\n')
# print_master(f"Saved query embeddings to {query_embed_path}")
# if dist.is_initialized():
# dist.barrier()
# # --- 2. Compute Candidate Embeddings ---
# if do_cand:
# print_master("Encoding candidates...")
# eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
# eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# cand_embeds, all_cand_ids = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}")
# cand_embeds = cand_embeds[:len(full_eval_cand_dataset)] # world_size>1, trim the padded data points
# all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)]
# if local_rank == 0:
# cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)}
# with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f)
# print_master(f"Saved candidate embeddings to {cand_embed_path}")
# if dist.is_initialized():
# dist.barrier()
# # --- 3. Compute Scores (on master rank only) ---
# if local_rank == 0:
# score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
# if os.path.exists(score_path):
# try:
# with open(score_path, "r") as f:
# score_dict = json.load(f)
# print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}")
# formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
# print_master(formatted)
# continue
# except Exception as e:
# print_master(f"Failed to load score for {dataset_name}, skipping {dataset_name}")
# with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f)
# with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f)
# gt_infos = [json.loads(l) for l in open(dataset_info_path)]
# pred_dicts = []
# rank_against_all_candidates = task_config.get("eval_type", "global") == "global"
# if rank_against_all_candidates:
# cand_keys = list(cand_embed_dict.keys())
# cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
# # Handle late-interaction scoring
# if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
# qry_embed = torch.from_numpy(qry_embeds)
# cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function
# ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()
# else: # Dense
# cosine_scores = np.dot(qry_embeds, cand_embeds.T)
# ranked_candids = np.argsort(-cosine_scores, axis=1)
# for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# assert rel_scores is None or len(rel_docids) == len(rel_scores)
# pred_dicts.append({
# "prediction": [cand_keys[i] for i in ranked_candid],
# "label": rel_docids,
# "rel_scores": rel_scores,
# })
# else:
# for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]])
# if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
# qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0)
# cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function
# ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0]
# else:
# cosine_score = np.dot(qry_embed, cand_embeds.T)
# ranked_candids = np.argsort(-cosine_score)
# rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# assert rel_scores is None or len(rel_docids) == len(rel_scores)
# pred_dicts.append({
# "prediction": [gt_info["cand_names"][i] for i in ranked_candids],
# "label": rel_docids,
# "rel_scores": rel_scores,
# })
# score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
# pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl")
# metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
# metrics = RankingMetrics(metrics_to_report)
# score_dict = metrics.evaluate(pred_dicts)
# formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
# score_dict["num_pred"] = len(pred_dicts)
# score_dict["num_data"] = len(gt_infos)
# print_master(f"Score of {dataset_name}:")
# print_master(formatted)
# print_master(f"Outputting final score to: {score_path}")
# with open(score_path, "w") as f:
# json.dump(score_dict, f, indent=4)
# with open(pred_path, "w") as f:
# for pred in pred_dicts:
# f.write(json.dumps(pred) + '\n')
# if __name__ == "__main__":
# main()
##################################################################################################
###################################################################################################
# Save each task's average and total statistics to a file
# import datetime
# import logging
# import json
# import random
# import time
# import numpy as np
# import os
# import pickle
# import sys
# import torch
# import torch.distributed as dist
# import torch.nn.functional as F
# import yaml
# import transformers
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
# from datasets import Dataset, concatenate_datasets
# from datasets.distributed import split_dataset_by_node
# from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src
# from src.arguments import ModelArguments, DataArguments, TrainingArguments
# from src.data.collator.eval_collator import MultimodalEvalDataCollator
# from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
# from src.eval_utils.metrics import RankingMetrics
# from src.model.model import MMEBModel
# from src.model.processor import get_backbone_name, load_processor, COLPALI
# from src.utils import batch_to_device, print_rank, print_master
# logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
# logger = logging.getLogger(__name__)
# # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
# timing_info = {}
# token_info = {
# "vision_tokens": 0,
# "text_input_tokens": 0, # Refers to the original text token count
# "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0.
# "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text)
# }
# # --- Hook Functions Definition ---
# def timing_pre_hook(module, input):
# module_id = id(module)
# if module_id not in timing_info:
# timing_info[module_id] = []
# timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))
# def timing_post_hook(module, input, output):
# module_id = id(module)
# if module_id not in timing_info:
# # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
# return
# timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))
# # Collect vision token count (only from Vision Transformer module's post hook)
# module_name = module.__class__.__name__
# if "vision" in module_name.lower() and "transformer" in module_name.lower():
# if isinstance(output, torch.Tensor):
# token_info["vision_tokens"] = output.shape[0] # For visual features, usually (batch_size, num_tokens, hidden_dim)
# elif hasattr(output, 'last_hidden_state'):
# token_info["vision_tokens"] = output.last_hidden_state.shape[1]
# def register_model_hooks(model):
# registered_modules = []
# core_model = model
# # print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}")
# if hasattr(model, 'encoder') and model.encoder is not None:
# # print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}")
# # Check against the Qwen2VLForConditionalGeneration imported from the 'src' path
# if isinstance(model.encoder, _Qwen2VLForConditionalGeneration_src):
# # print_master("Detected MMEBModel structure, registering hooks on model.encoder's sub-modules.")
# core_model = model.encoder
# else:
# print_master(f"WARNING: model.encoder is not an instance of _Qwen2VLForConditionalGeneration_src. Its type is {type(model.encoder)}. Hooks will be registered on top-level model if applicable.")
# else:
# print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.")
# # Vision module
# if hasattr(core_model, 'visual') and core_model.visual is not None:
# vision_module = core_model.visual
# vision_module.register_forward_pre_hook(timing_pre_hook)
# vision_module.register_forward_hook(timing_post_hook)
# registered_modules.append(vision_module)
# print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
# # Merger module (if inside visual) - it's part of the vision component
# if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
# merger_module = core_model.visual.merger
# merger_module.register_forward_pre_hook(timing_pre_hook)
# merger_module.register_forward_hook(timing_post_hook)
# registered_modules.append(merger_module)
# print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
# # Language model body
# if hasattr(core_model, 'model') and core_model.model is not None:
# llm_main_module = core_model.model
# llm_main_module.register_forward_pre_hook(timing_pre_hook)
# llm_main_module.register_forward_hook(timing_post_hook)
# registered_modules.append(llm_main_module)
# print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
# # LM Head
# if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
# lm_head_module = core_model.lm_head
# lm_head_module.register_forward_pre_hook(timing_pre_hook)
# lm_head_module.register_forward_hook(timing_post_hook)
# registered_modules.append(lm_head_module)
# print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
# if not registered_modules:
# print_master("Warning: No major modules found for hook registration. Check model architecture.")
# return registered_modules
# def pad_dataset_to_divisible(dataset, world_size):
# num_samples = len(dataset)
# if num_samples % world_size == 0:
# return dataset, num_samples
# num_to_add = world_size - (num_samples % world_size)
# padded_size = num_samples + num_to_add
# padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
# padded_dataset = concatenate_datasets([dataset, padding_data])
# return padded_dataset, padded_size
# def encode_embeddings(
# model: MMEBModel,
# loader: DataLoader,
# training_args: TrainingArguments,
# model_args: ModelArguments,
# full_dataset: Dataset,
# encode_side: str,
# description: str = "Encoding"
# ) -> tuple[np.ndarray, list, list]: # Added list to return type for batch_stats
# """
# Encodes embeddings for a given dataset using the model, handling both standard and
# late-interaction models in a DDP-safe manner.
# """
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# world_size = dist.get_world_size() if dist.is_initialized() else 1
# # Check if the model is a late-interaction type
# is_late_interaction = (model_args.model_backbone == COLPALI)
# local_embeds = []
# local_gt_infos = []
# local_max_len = 0
# # --- New: List to store statistics for each batch ---
# batch_stats_list = []
# model.eval()
# # Register hooks for the model once per encode_embeddings call
# registered_hooks = register_model_hooks(model)
# with torch.no_grad():
# for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
# # --- Reset statistics for each inference pass ---
# timing_info.clear()
# token_info["vision_tokens"] = 0
# token_info["text_input_tokens"] = 0
# token_info["text_output_tokens"] = 0
# token_info["total_llm_input_tokens"] = 0
# inputs = batch_to_device(inputs, training_args.device)
# current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 # Determine actual batch size
# with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
# start_inference_time = time.time()
# if encode_side == "qry":
# output = model(qry=inputs)
# reps = output["qry_reps"].detach()
# local_gt_infos.extend(dataset_info)
# else:
# output = model(tgt=inputs)
# reps = output["tgt_reps"].detach()
# local_gt_infos.extend([info["cand_name"] for info in dataset_info])
# end_inference_time = time.time()
# # --- Update total LLM input tokens after the model call ---
# if 'input_ids' in inputs and inputs['input_ids'] is not None:
# # `inputs['input_ids'].shape[1]` is the (padded) sequence length,
# # i.e. the number of LLM input tokens per item in this batch
# # (multiply by batch size if a batch-level total is needed).
# token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
# # Approximation for text_input_tokens
# token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
# token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"]) # Ensure not negative
# # --- Collect and Store Batch Statistics ---
# batch_inference_time = end_inference_time - start_inference_time
# current_batch_stats = {
# "batch_size": current_batch_size,
# "total_inference_time_seconds": batch_inference_time,
# "module_inference_times": {},
# "token_counts": {
# "visual_tokens": token_info["vision_tokens"],
# "language_input_tokens_raw": token_info["text_input_tokens"],
# "llm_total_input_tokens": token_info["total_llm_input_tokens"],
# "language_output_tokens": token_info["text_output_tokens"],
# }
# }
# # Calculate and store module timings for the current batch
# for module_obj in registered_hooks:
# module_id = id(module_obj)
# module_name = module_obj.__class__.__name__
# times = timing_info.get(module_id, [])
# durations = []
# pre_times = {}
# for t, event_type, _ in times:
# if event_type == 'pre':
# pre_times[module_id] = t
# elif event_type == 'post' and module_id in pre_times:
# duration = t - pre_times.pop(module_id)
# durations.append(duration)
# if durations:
# current_batch_stats["module_inference_times"][module_name] = {
# "total": sum(durations),
# "count": len(durations),
# "avg": sum(durations) / len(durations)
# }
# else:
# current_batch_stats["module_inference_times"][module_name] = {
# "total": 0.0,
# "count": 0,
# "avg": 0.0
# }
# batch_stats_list.append(current_batch_stats) # Append the stats for this batch
# # --- Print Inference Timing and Token Statistics per Batch (Optional, for debugging) ---
# print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
# print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
# print_rank("--- Module Inference Timing Statistics ---")
# for module_name, stats in current_batch_stats["module_inference_times"].items():
# print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
# print_rank("--- Token Count Statistics ---")
# print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}")
# print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
# print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
# print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}")
# if is_late_interaction and reps.dim() == 3:
# local_max_len = max(local_max_len, reps.shape[1])
# local_embeds.append(reps)
# if not local_embeds:
# # Handle cases where a rank gets no data
# return np.array([]), [], [] # Return empty list for batch_stats_list as well
# # === DDP Synchronization and Padding for Late-Interaction Models ===
# if is_late_interaction:
# if dist.is_initialized():
# # 1. Find the global maximum sequence length across all ranks
# local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
# dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
# global_max_len = local_max_len_tensor.item()
# else:
# global_max_len = local_max_len
# # 2. Pad all local embeddings to the global max length
# padded_embeds = []
# for reps_batch in local_embeds:
# if reps_batch.dim() == 3:
# B, L, H = reps_batch.shape
# padding_size = global_max_len - L
# padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
# padded_embeds.append(padded_batch)
# else: # Should not happen if model is consistently late-interaction
# padded_embeds.append(reps_batch)
# embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
# else: # Standard dense models
# embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
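# # Note: a minimal illustration of the F.pad call above, with hypothetical shapes (not taken
# # from a real run). F.pad's pad tuple is ordered from the last dimension backwards, so
# # (0, 0, 0, padding_size) leaves the hidden dim untouched and right-pads the sequence dim:
# #   reps = torch.zeros(2, 5, 128)                        # [B=2, L=5, H=128]
# #   padded = F.pad(reps, (0, 0, 0, 7 - 5), "constant", 0)
# #   assert padded.shape == (2, 7, 128)                   # sequence padded to global_max_len=7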
# # === Gather embeddings and keys from all ranks ===
# if dist.is_initialized() and full_dataset.num_rows >= world_size:
# print_master(f"Gathering {encode_side} embeddings across all ranks...")
# # Use the more efficient all_gather_into_tensor for tensors
# output_shape = list(embeds_tensor.shape)
# output_shape[0] = full_dataset.num_rows
# embeds_tensor = embeds_tensor.to(training_args.device)
# gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
# dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
# final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
# # Gather metadata, for which all_gather_object is appropriate
# gathered_gt_infos = [None for _ in range(world_size)]
# dist.all_gather_object(gathered_gt_infos, local_gt_infos)
# all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
# # --- New: Gather batch_stats_list from all ranks ---
# gathered_batch_stats = [None for _ in range(world_size)]
# dist.all_gather_object(gathered_batch_stats, batch_stats_list)
# all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
# else:
# all_gt_infos = local_gt_infos
# final_embeddings = embeds_tensor.cpu().float().numpy()
# all_batch_stats = batch_stats_list # If not DDP, just use local list
# return final_embeddings, all_gt_infos, all_batch_stats
# def main():
# if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
# dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# world_size = dist.get_world_size() if dist.is_initialized() else 1
# # DEBUG PRINTS for Distributed Setup
# print_master("Distributed init debug info:")
# print_master(f"RANK: {os.environ.get('RANK')}")
# print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
# print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
# print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
# print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
# if dist.is_initialized():
# print_rank(f"dist.get_rank(): {dist.get_rank()}")
# print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
# for arg in sys.argv:
# if arg.startswith("--local-rank="):
# rank = arg.split("=")[1]
# sys.argv.remove(arg)
# sys.argv.append('--local_rank')
# sys.argv.append(rank)
# parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# model_args: ModelArguments
# data_args: DataArguments
# training_args: TrainingArguments
# os.makedirs(data_args.encode_output_path, exist_ok=True)
# # --- Model Loading ---
# hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
# if not getattr(model_args, "model_backbone", None):
# model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
# setattr(model_args, 'model_backbone', model_backbone)
# setattr(training_args, 'model_backbone', model_backbone)
# print_master(f'Model Backbone: {model_args.model_backbone}')
# # --- DDP-Safe Model Loading ---
# # Step 1: Only the master process (rank 0) downloads the model.
# if local_rank == 0:
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
# # Step 2: All processes wait here. The non-master processes will pause
# # until the master process (rank 0) finishes downloading and exits this barrier.
# if torch.distributed.is_initialized():
# torch.distributed.barrier()
# # Step 3: Now that the model is cached, the non-master processes load it from the local cache.
# if local_rank != 0:
# print_rank(f"Loading the model from cache...")
# processor = load_processor(model_args, data_args)
# time.sleep(random.randint(2 * local_rank, 3 * local_rank))
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# model.eval()
# model = model.to(training_args.device, dtype=torch.bfloat16)
# with open(data_args.dataset_config, 'r') as yaml_file:
# dataset_configs = yaml.safe_load(yaml_file)
# # --- Main Evaluation Loop ---
# for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
# # Initialize task-level statistics accumulators
# task_total_stats = {
# "total_inference_time_seconds": 0.0,
# "module_inference_times": {
# "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0},
# "PatchMerger": {"total": 0.0, "count": 0},
# "Qwen2VLModel": {"total": 0.0, "count": 0},
# "Linear": {"total": 0.0, "count": 0},
# },
# "token_counts": {
# "visual_tokens": 0,
# "language_input_tokens_raw": 0,
# "llm_total_input_tokens": 0,
# "language_output_tokens": 0,
# },
# "data_point_count": 0 # Number of image-text pairs processed
# }
# if dist.is_initialized():
# dist.barrier()
# print_master(f"\n--- Evaluating {dataset_name} ---")
# query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry")
# cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt")
# dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
# # New: Define path for inference statistics output
# inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_inference_stats.json")
# do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
# do_cand = not os.path.exists(cand_embed_path)
# if do_query or do_cand:
# if data_args.data_basedir is not None:
# # Construct full paths for data files if --data_basedir is provided
# for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
# if data_args.data_basedir and task_config.get(key):
# task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
# full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
# full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
# eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# # Pad datasets to be divisible by world_size before splitting
# if dist.is_initialized():
# padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
# padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
# eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
# eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
# else:
# padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# # --- 1. Compute Query Embeddings ---
# if do_query:
# print_master("Encoding queries...")
# eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
# eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers)
# # Modified: capture batch_stats_list
# query_embeds, gt_infos, qry_batch_stats = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}")
# # Accumulate query statistics
# for batch_stat in qry_batch_stats:
# batch_size = batch_stat["batch_size"]
# task_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"]
# for module_name, module_stats in batch_stat["module_inference_times"].items():
# if module_name in task_total_stats["module_inference_times"]:
# task_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"]
# task_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"] # count here is per-module-call, not per-item
# # The hook values are treated as per-item counts here (sequence length / visual tokens per
# # item), so each is multiplied by the batch size to accumulate a batch total; per-item
# # averages are derived later by dividing the accumulated totals by data_point_count.
# # Example: 256 visual tokens per item with batch_size=4 adds 1024 to the task total.
# task_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size
# task_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size
# task_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size
# task_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size
# task_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items
# query_embeds = query_embeds[:len(full_eval_qry_dataset)]
# gt_infos = gt_infos[:len(full_eval_qry_dataset)]
# if local_rank == 0:
# with open(query_embed_path, 'wb') as f:
# pickle.dump(query_embeds, f)
# with open(dataset_info_path, 'w') as f:
# for info in gt_infos:
# f.write(json.dumps(info) + '\n')
# print_master(f"Saved query embeddings to {query_embed_path}")
# if dist.is_initialized():
# dist.barrier()
# # --- 2. Compute Candidate Embeddings ---
# if do_cand:
# print_master("Encoding candidates...")
# eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
# eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# # Modified: capture batch_stats_list
# cand_embeds, all_cand_ids, cand_batch_stats = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}")
# # Accumulate candidate statistics (similar logic as query)
# for batch_stat in cand_batch_stats:
# batch_size = batch_stat["batch_size"]
# task_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"]
# for module_name, module_stats in batch_stat["module_inference_times"].items():
# if module_name in task_total_stats["module_inference_times"]:
# task_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"]
# task_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"]
# task_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size
# task_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size
# task_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size
# task_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size
# task_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items
# cand_embeds = cand_embeds[:len(full_eval_cand_dataset)]
# all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)]
# if local_rank == 0:
# cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)}
# with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f)
# print_master(f"Saved candidate embeddings to {cand_embed_path}")
# if dist.is_initialized():
# dist.barrier()
# # --- New: Calculate and Save Task-level Inference Statistics (on master rank only) ---
# if local_rank == 0:
# if task_total_stats["data_point_count"] > 0:
# final_task_stats = {
# "task_name": dataset_name,
# "data_point_count": task_total_stats["data_point_count"],
# "inference_times": {
# "total_inference_time_seconds": task_total_stats["total_inference_time_seconds"],
# "avg_inference_time_per_item_seconds": task_total_stats["total_inference_time_seconds"] / task_total_stats["data_point_count"],
# "module_average_times_per_call": {}, # Average per call to the module
# "module_total_times_seconds": {}, # Total time spent in the module
# "module_calls_count": {}, # Number of times the module was called
# },
# "token_counts": {
# "total_visual_tokens": task_total_stats["token_counts"]["visual_tokens"],
# "avg_visual_tokens_per_item": task_total_stats["token_counts"]["visual_tokens"] / task_total_stats["data_point_count"],
# "total_language_input_tokens_raw": task_total_stats["token_counts"]["language_input_tokens_raw"],
# "avg_language_input_tokens_raw_per_item": task_total_stats["token_counts"]["language_input_tokens_raw"] / task_total_stats["data_point_count"],
# "total_llm_total_input_tokens": task_total_stats["token_counts"]["llm_total_input_tokens"],
# "avg_llm_total_input_tokens_per_item": task_total_stats["token_counts"]["llm_total_input_tokens"] / task_total_stats["data_point_count"],
# "total_language_output_tokens": task_total_stats["token_counts"]["language_output_tokens"],
# "avg_language_output_tokens_per_item": task_total_stats["token_counts"]["language_output_tokens"] / task_total_stats["data_point_count"],
# }
# }
# for module_name, stats in task_total_stats["module_inference_times"].items():
# final_task_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"]
# final_task_stats["inference_times"]["module_calls_count"][module_name] = stats["count"]
# if stats["count"] > 0:
# final_task_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"]
# else:
# final_task_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0
# with open(inference_stats_path, 'w', encoding='utf-8') as f:
# json.dump(final_task_stats, f, ensure_ascii=False, indent=4)
# print_master(f"Inference statistics for {dataset_name} saved to: {inference_stats_path}")
# else:
# print_master(f"No data processed for {dataset_name}, skipping inference statistics output.")
# # --- 3. Compute Scores (on master rank only) ---
# score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
# if os.path.exists(score_path):
# try:
# with open(score_path, "r") as f:
# score_dict = json.load(f)
# print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}")
# formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
# print_master(formatted)
# # No `continue` here, as we want to ensure other files are processed/generated
# except Exception as e:
# print_master(f"Failed to load score for {dataset_name}, proceeding to recompute. Error: {e}")
# # Proceed with score computation if not loaded or failed to load
# with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f)
# with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f)
# gt_infos = [json.loads(l) for l in open(dataset_info_path)]
# pred_dicts = []
# rank_against_all_candidates = task_config.get("eval_type", "global") == "global"
# if rank_against_all_candidates:
# cand_keys = list(cand_embed_dict.keys())
# cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
# # Handle late-interaction scoring
# if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
# qry_embed = torch.from_numpy(qry_embeds)
# cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function
# ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()
# else: # Dense
# cosine_scores = np.dot(qry_embeds, cand_embeds.T)
# ranked_candids = np.argsort(-cosine_scores, axis=1)
# for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# assert rel_scores is None or len(rel_docids) == len(rel_scores)
# pred_dicts.append({
# "prediction": [cand_keys[i] for i in ranked_candid],
# "label": rel_docids,
# "rel_scores": rel_scores,
# })
# else:
# for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]])
# if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
# qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0)
# cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function
# ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0]
# else:
# cosine_score = np.dot(qry_embed, cand_embeds.T)
# ranked_candids = np.argsort(-cosine_score)
# rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# assert rel_scores is None or len(rel_docids) == len(rel_scores)
# pred_dicts.append({
# "prediction": [gt_info["cand_names"][i] for i in ranked_candids],
# "label": rel_docids,
# "rel_scores": rel_scores,
# })
# score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
# pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl")
# metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
# metrics = RankingMetrics(metrics_to_report)
# score_dict = metrics.evaluate(pred_dicts)
# formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
# score_dict["num_pred"] = len(pred_dicts)
# score_dict["num_data"] = len(gt_infos)
# print_master(f"Score of {dataset_name}:")
# print_master(formatted)
# print_master(f"Outputting final score to: {score_path}")
# with open(score_path, "w") as f:
# json.dump(score_dict, f, indent=4)
# with open(pred_path, "w") as f:
# for pred in pred_dicts:
# f.write(json.dumps(pred) + '\n')
# if __name__ == "__main__":
# main()
##########################################################################################
# ########################################################################################
# # Collect statistics separately for the query and cand sides
# import datetime
# import logging
# import json
# import random
# import time
# import numpy as np
# import os
# import pickle
# import sys
# import torch
# import torch.distributed as dist
# import torch.nn.functional as F
# import yaml
# import transformers
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
# from datasets import Dataset, concatenate_datasets
# from datasets.distributed import split_dataset_by_node
# from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src
# from src.arguments import ModelArguments, DataArguments, TrainingArguments
# from src.data.collator.eval_collator import MultimodalEvalDataCollator
# from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
# from src.eval_utils.metrics import RankingMetrics
# from src.model.model import MMEBModel
# from src.model.processor import get_backbone_name, load_processor, COLPALI
# from src.utils import batch_to_device, print_rank, print_master
# logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
# logger = logging.getLogger(__name__)
# # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
# timing_info = {}
# token_info = {
# "vision_tokens": 0,
# "text_input_tokens": 0, # Refers to the original text token count
# "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0.
# "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text)
# }
# # --- Hook Functions Definition ---
# def timing_pre_hook(module, input):
# module_id = id(module)
# if module_id not in timing_info:
# timing_info[module_id] = []
# timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))
# def timing_post_hook(module, input, output):
# module_id = id(module)
# if module_id not in timing_info:
# # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
# return
# timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))
# # Collect vision token count (only from Vision Transformer module's post hook)
# module_name = module.__class__.__name__
# if "vision" in module_name.lower() and "transformer" in module_name.lower():
# if isinstance(output, torch.Tensor):
# token_info["vision_tokens"] = output.shape[0]  # Qwen2-VL's vision tower returns a flattened (num_visual_tokens, hidden_dim) tensor, so shape[0] counts the visual tokens in this forward pass
# elif hasattr(output, 'last_hidden_state'):
# token_info["vision_tokens"] = output.last_hidden_state.shape[1]
# def register_model_hooks(model):
# registered_modules = []
# core_model = model
# # print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}")
# if hasattr(model, 'encoder') and model.encoder is not None:
# # print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}")
# # Check against the Qwen2VLForConditionalGeneration class imported from the 'src' path
# if isinstance(model.encoder, _Qwen2VLForConditionalGeneration_src):
# # print_master("Detected MMEBModel structure, registering hooks on model.encoder's sub-modules.")
# core_model = model.encoder
# else:
# print_master(f"WARNING: model.encoder is not an instance of _Qwen2VLForConditionalGeneration_src. Its type is {type(model.encoder)}. Hooks will be registered on top-level model if applicable.")
# else:
# print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.")
# # Vision module
# if hasattr(core_model, 'visual') and core_model.visual is not None:
# vision_module = core_model.visual
# vision_module.register_forward_pre_hook(timing_pre_hook)
# vision_module.register_forward_hook(timing_post_hook)
# registered_modules.append(vision_module)
# print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
# # Merger module (if inside visual) - it's part of the vision component
# if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
# merger_module = core_model.visual.merger
# merger_module.register_forward_pre_hook(timing_pre_hook)
# merger_module.register_forward_hook(timing_post_hook)
# registered_modules.append(merger_module)
# print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
# # Language model body
# if hasattr(core_model, 'model') and core_model.model is not None:
# llm_main_module = core_model.model
# llm_main_module.register_forward_pre_hook(timing_pre_hook)
# llm_main_module.register_forward_hook(timing_post_hook)
# registered_modules.append(llm_main_module)
# print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
# # LM Head
# if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
# lm_head_module = core_model.lm_head
# lm_head_module.register_forward_pre_hook(timing_pre_hook)
# lm_head_module.register_forward_hook(timing_post_hook)
# registered_modules.append(lm_head_module)
# print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
# else:
# print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
# if not registered_modules:
# print_master("Warning: No major modules found for hook registration. Check model architecture.")
# return registered_modules
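# # Note: register_forward_pre_hook / register_forward_hook return RemovableHandle objects.
# # A hedged sketch (not part of the original flow) of keeping those handles so that repeated
# # encode_embeddings calls do not stack duplicate hooks on the same modules:
# #   handles = []
# #   handles.append(vision_module.register_forward_pre_hook(timing_pre_hook))
# #   handles.append(vision_module.register_forward_hook(timing_post_hook))
# #   ...  # likewise for the merger, LLM body, and lm_head
# #   for h in handles:
# #       h.remove()  # detach all timing hooks once encoding is done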
# def pad_dataset_to_divisible(dataset, world_size):
# num_samples = len(dataset)
# if num_samples % world_size == 0:
# return dataset, num_samples
# num_to_add = world_size - (num_samples % world_size)
# padded_size = num_samples + num_to_add
# padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
# padded_dataset = concatenate_datasets([dataset, padding_data])
# return padded_dataset, padded_size
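# # Worked example (hypothetical sizes): with 10 samples and world_size=4, num_to_add is
# # 4 - (10 % 4) = 2, so the samples at indices [0, 1] are repeated and the padded dataset
# # has 12 rows, which split_dataset_by_node can shard evenly into 3 rows per rank. The
# # duplicated rows are dropped later by truncating to len(full_eval_*_dataset).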
# def encode_embeddings(
# model: MMEBModel,
# loader: DataLoader,
# training_args: TrainingArguments,
# model_args: ModelArguments,
# full_dataset: Dataset,
# encode_side: str,
# description: str = "Encoding"
# ) -> tuple[np.ndarray, list, list]: # Added list to return type for batch_stats
# """
# Encodes embeddings for a given dataset using the model, handling both standard and
# late-interaction models in a DDP-safe manner.
# """
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# world_size = dist.get_world_size() if dist.is_initialized() else 1
# # Check if the model is a late-interaction type
# is_late_interaction = (model_args.model_backbone == COLPALI)
# local_embeds = []
# local_gt_infos = []
# local_max_len = 0
# # --- New: List to store statistics for each batch ---
# batch_stats_list = []
# model.eval()
# # Register hooks for the model once per encode_embeddings call
# registered_hooks = register_model_hooks(model)
# with torch.no_grad():
# for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
# # --- Reset statistics for each inference pass ---
# timing_info.clear()
# token_info["vision_tokens"] = 0
# token_info["text_input_tokens"] = 0
# token_info["text_output_tokens"] = 0
# token_info["total_llm_input_tokens"] = 0
# inputs = batch_to_device(inputs, training_args.device)
# current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 # Determine actual batch size
# with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
# start_inference_time = time.time()
# if encode_side == "qry":
# output = model(qry=inputs)
# reps = output["qry_reps"].detach()
# local_gt_infos.extend(dataset_info)
# else:
# output = model(tgt=inputs)
# reps = output["tgt_reps"].detach()
# local_gt_infos.extend([info["cand_name"] for info in dataset_info])
# end_inference_time = time.time()
# # --- Update total LLM input tokens after the model call ---
# if 'input_ids' in inputs and inputs['input_ids'] is not None:
# # `inputs['input_ids'].shape[1]` is the (padded) sequence length, i.e. the number of tokens
# # the LLM receives per item in this batch. Batch totals are obtained later, when the
# # statistics are accumulated and multiplied by the batch size.
# token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
# # Approximation for text_input_tokens
# token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
# token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"]) # Ensure not negative
# # --- Collect and Store Batch Statistics ---
# batch_inference_time = end_inference_time - start_inference_time
# current_batch_stats = {
# "batch_size": current_batch_size,
# "total_inference_time_seconds": batch_inference_time,
# "module_inference_times": {},
# "token_counts": {
# "visual_tokens": token_info["vision_tokens"],
# "language_input_tokens_raw": token_info["text_input_tokens"],
# "llm_total_input_tokens": token_info["total_llm_input_tokens"],
# "language_output_tokens": token_info["text_output_tokens"],
# }
# }
# # Calculate and store module timings for the current batch
# for module_obj in registered_hooks:
# module_id = id(module_obj)
# module_name = module_obj.__class__.__name__
# times = timing_info.get(module_id, [])
# durations = []
# pre_times = {}
# for t, event_type, _ in times:
# if event_type == 'pre':
# pre_times[module_id] = t
# elif event_type == 'post' and module_id in pre_times:
# duration = t - pre_times.pop(module_id)
# durations.append(duration)
# if durations:
# current_batch_stats["module_inference_times"][module_name] = {
# "total": sum(durations),
# "count": len(durations),
# "avg": sum(durations) / len(durations)
# }
# else:
# current_batch_stats["module_inference_times"][module_name] = {
# "total": 0.0,
# "count": 0,
# "avg": 0.0
# }
# batch_stats_list.append(current_batch_stats) # Append the stats for this batch
# # --- Print Inference Timing and Token Statistics per Batch (Optional, for debugging) ---
# print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
# print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
# print_rank("--- Module Inference Timing Statistics ---")
# for module_name, stats in current_batch_stats["module_inference_times"].items():
# print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
# print_rank("--- Token Count Statistics ---")
# print_rank(f"**Visual token count**: {current_batch_stats['token_counts']['visual_tokens']}")
# print_rank(f"**Language input token count (raw text only)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
# print_rank(f"**Total LLM input token count (visual + formatted text)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
# print_rank(f"**Language output token count**: {current_batch_stats['token_counts']['language_output_tokens']}")
# if is_late_interaction and reps.dim() == 3:
# local_max_len = max(local_max_len, reps.shape[1])
# local_embeds.append(reps)
# if not local_embeds:
# # Handle cases where a rank gets no data
# return np.array([]), [], [] # Return empty list for batch_stats_list as well
# # === DDP Synchronization and Padding for Late-Interaction Models ===
# if is_late_interaction:
# if dist.is_initialized():
# # 1. Find the global maximum sequence length across all ranks
# local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
# dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
# global_max_len = local_max_len_tensor.item()
# else:
# global_max_len = local_max_len
# # 2. Pad all local embeddings to the global max length
# padded_embeds = []
# for reps_batch in local_embeds:
# if reps_batch.dim() == 3:
# B, L, H = reps_batch.shape
# padding_size = global_max_len - L
# padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
# padded_embeds.append(padded_batch)
# else: # Should not happen if model is consistently late-interaction
# padded_embeds.append(reps_batch)
# embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
# else: # Standard dense models
# embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
# # === Gather embeddings and keys from all ranks ===
# if dist.is_initialized() and full_dataset.num_rows >= world_size:
# print_master(f"Gathering {encode_side} embeddings across all ranks...")
# # Use the more efficient all_gather_into_tensor for tensors
# output_shape = list(embeds_tensor.shape)
# output_shape[0] = full_dataset.num_rows
# embeds_tensor = embeds_tensor.to(training_args.device)
# gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
# dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
# final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
# # Gather metadata, for which all_gather_object is appropriate
# gathered_gt_infos = [None for _ in range(world_size)]
# dist.all_gather_object(gathered_gt_infos, local_gt_infos)
# all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
# # --- New: Gather batch_stats_list from all ranks ---
# gathered_batch_stats = [None for _ in range(world_size)]
# dist.all_gather_object(gathered_batch_stats, batch_stats_list)
# all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
# else:
# all_gt_infos = local_gt_infos
# final_embeddings = embeds_tensor.cpu().float().numpy()
# all_batch_stats = batch_stats_list # If not DDP, just use local list
# return final_embeddings, all_gt_infos, all_batch_stats
# def main():
# if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
# dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# world_size = dist.get_world_size() if dist.is_initialized() else 1
# # DEBUG PRINTS for Distributed Setup
# print_master("Distributed init debug info:")
# print_master(f"RANK: {os.environ.get('RANK')}")
# print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
# print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
# print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
# print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
# if dist.is_initialized():
# print_rank(f"dist.get_rank(): {dist.get_rank()}")
# print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
# for arg in sys.argv:
# if arg.startswith("--local-rank="):
# rank = arg.split("=")[1]
# sys.argv.remove(arg)
# sys.argv.append('--local_rank')
# sys.argv.append(rank)
# parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# model_args: ModelArguments
# data_args: DataArguments
# training_args: TrainingArguments
# os.makedirs(data_args.encode_output_path, exist_ok=True)
# # --- Model Loading ---
# hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
# if not getattr(model_args, "model_backbone", None):
# model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
# setattr(model_args, 'model_backbone', model_backbone)
# setattr(training_args, 'model_backbone', model_backbone)
# print_master(f'Model Backbone: {model_args.model_backbone}')
# # --- DDP-Safe Model Loading ---
# # Step 1: Only the master process (rank 0) downloads the model.
# if local_rank == 0:
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
# # Step 2: All processes wait here. The non-master processes will pause
# # until the master process (rank 0) finishes downloading and exits this barrier.
# if torch.distributed.is_initialized():
# torch.distributed.barrier()
# # Step 3: Now that the model is cached, the non-master processes load it from the local cache.
# if local_rank != 0:
# print_rank(f"Loading the model from cache...")
# processor = load_processor(model_args, data_args)
# time.sleep(random.randint(2 * local_rank, 3 * local_rank))
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# model.eval()
# model = model.to(training_args.device, dtype=torch.bfloat16)
# with open(data_args.dataset_config, 'r') as yaml_file:
# dataset_configs = yaml.safe_load(yaml_file)
# # --- Main Evaluation Loop ---
# for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
# # Initialize task-level statistics accumulators for QUERY
# query_total_stats = {
# "total_inference_time_seconds": 0.0,
# "module_inference_times": {
# "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0},
# "PatchMerger": {"total": 0.0, "count": 0},
# "Qwen2VLModel": {"total": 0.0, "count": 0},
# "Linear": {"total": 0.0, "count": 0},
# },
# "token_counts": {
# "visual_tokens": 0,
# "language_input_tokens_raw": 0,
# "llm_total_input_tokens": 0,
# "language_output_tokens": 0,
# },
# "data_point_count": 0 # Number of image-text pairs processed
# }
# # Initialize task-level statistics accumulators for CANDIDATE
# cand_total_stats = {
# "total_inference_time_seconds": 0.0,
# "module_inference_times": {
# "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0},
# "PatchMerger": {"total": 0.0, "count": 0},
# "Qwen2VLModel": {"total": 0.0, "count": 0},
# "Linear": {"total": 0.0, "count": 0},
# },
# "token_counts": {
# "visual_tokens": 0,
# "language_input_tokens_raw": 0,
# "llm_total_input_tokens": 0,
# "language_output_tokens": 0,
# },
# "data_point_count": 0 # Number of image-text pairs processed
# }
# if dist.is_initialized():
# dist.barrier()
# print_master(f"\n--- Evaluating {dataset_name} ---")
# query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry")
# cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt")
# dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
# # New: Define distinct paths for query and candidate inference statistics output
# query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats.json")
# cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats.json")
# do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
# do_cand = not os.path.exists(cand_embed_path)
# if do_query or do_cand:
# if data_args.data_basedir is not None:
# # Construct full paths for data files if --data_basedir is provided
# for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
# if data_args.data_basedir and task_config.get(key):
# task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
# full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
# full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
# eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# # Pad datasets to be divisible by world_size before splitting
# if dist.is_initialized():
# padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
# padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
# eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
# eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
# else:
# padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# # --- 1. Compute Query Embeddings ---
# if do_query:
# print_master("Encoding queries...")
# eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
# eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers)
# # Modified: capture batch_stats_list
# query_embeds, gt_infos, qry_batch_stats = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}")
# # Accumulate query statistics
# for batch_stat in qry_batch_stats:
# batch_size = batch_stat["batch_size"]
# query_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"]
# for module_name, module_stats in batch_stat["module_inference_times"].items():
# if module_name in query_total_stats["module_inference_times"]:
# query_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"]
# query_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"]
# query_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size
# query_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size
# query_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size
# query_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size
# query_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items
# query_embeds = query_embeds[:len(full_eval_qry_dataset)]
# gt_infos = gt_infos[:len(full_eval_qry_dataset)]
# if local_rank == 0:
# with open(query_embed_path, 'wb') as f:
# pickle.dump(query_embeds, f)
# with open(dataset_info_path, 'w') as f:
# for info in gt_infos:
# f.write(json.dumps(info) + '\n')
# print_master(f"Saved query embeddings to {query_embed_path}")
# # Save query-specific inference statistics
# if query_total_stats["data_point_count"] > 0:
# final_query_stats = {
# "task_name": dataset_name,
# "encode_side": "query",
# "data_point_count": query_total_stats["data_point_count"],
# "inference_times": {
# "total_inference_time_seconds": query_total_stats["total_inference_time_seconds"],
# "avg_inference_time_per_item_seconds": query_total_stats["total_inference_time_seconds"] / query_total_stats["data_point_count"],
# "module_average_times_per_call": {},
# "module_total_times_seconds": {},
# "module_calls_count": {},
# },
# "token_counts": {
# "total_visual_tokens": query_total_stats["token_counts"]["visual_tokens"],
# "avg_visual_tokens_per_item": query_total_stats["token_counts"]["visual_tokens"] / query_total_stats["data_point_count"],
# "total_language_input_tokens_raw": query_total_stats["token_counts"]["language_input_tokens_raw"],
# "avg_language_input_tokens_raw_per_item": query_total_stats["token_counts"]["language_input_tokens_raw"] / query_total_stats["data_point_count"],
# "total_llm_total_input_tokens": query_total_stats["token_counts"]["llm_total_input_tokens"],
# "avg_llm_total_input_tokens_per_item": query_total_stats["token_counts"]["llm_total_input_tokens"] / query_total_stats["data_point_count"],
# "total_language_output_tokens": query_total_stats["token_counts"]["language_output_tokens"],
# "avg_language_output_tokens_per_item": query_total_stats["token_counts"]["language_output_tokens"] / query_total_stats["data_point_count"],
# }
# }
# for module_name, stats in query_total_stats["module_inference_times"].items():
# final_query_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"]
# final_query_stats["inference_times"]["module_calls_count"][module_name] = stats["count"]
# if stats["count"] > 0:
# final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"]
# else:
# final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0
# with open(query_inference_stats_path, 'w', encoding='utf-8') as f:
# json.dump(final_query_stats, f, ensure_ascii=False, indent=4)
# print_master(f"Query inference statistics for {dataset_name} saved to: {query_inference_stats_path}")
# else:
# print_master(f"No query data processed for {dataset_name}, skipping query inference statistics output.")
# if dist.is_initialized():
# dist.barrier()
# # --- 2. Compute Candidate Embeddings ---
# if do_cand:
# print_master("Encoding candidates...")
# eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
# eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# # Modified: capture batch_stats_list
# cand_embeds, all_cand_ids, cand_batch_stats = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}")
# # Accumulate candidate statistics (similar logic as query)
# for batch_stat in cand_batch_stats:
# batch_size = batch_stat["batch_size"]
# cand_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"]
# for module_name, module_stats in batch_stat["module_inference_times"].items():
# if module_name in cand_total_stats["module_inference_times"]:
# cand_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"]
# cand_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"]
# cand_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size
# cand_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size
# cand_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size
# cand_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size
# cand_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items
# cand_embeds = cand_embeds[:len(full_eval_cand_dataset)]
# all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)]
# if local_rank == 0:
# cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)}
# with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f)
# print_master(f"Saved candidate embeddings to {cand_embed_path}")
# # Save candidate-specific inference statistics
# if cand_total_stats["data_point_count"] > 0:
# final_cand_stats = {
# "task_name": dataset_name,
# "encode_side": "candidate",
# "data_point_count": cand_total_stats["data_point_count"],
# "inference_times": {
# "total_inference_time_seconds": cand_total_stats["total_inference_time_seconds"],
# "avg_inference_time_per_item_seconds": cand_total_stats["total_inference_time_seconds"] / cand_total_stats["data_point_count"],
# "module_average_times_per_call": {},
# "module_total_times_seconds": {},
# "module_calls_count": {},
# },
# "token_counts": {
# "total_visual_tokens": cand_total_stats["token_counts"]["visual_tokens"],
# "avg_visual_tokens_per_item": cand_total_stats["token_counts"]["visual_tokens"] / cand_total_stats["data_point_count"],
# "total_language_input_tokens_raw": cand_total_stats["token_counts"]["language_input_tokens_raw"],
# "avg_language_input_tokens_raw_per_item": cand_total_stats["token_counts"]["language_input_tokens_raw"] / cand_total_stats["data_point_count"],
# "total_llm_total_input_tokens": cand_total_stats["token_counts"]["llm_total_input_tokens"],
# "avg_llm_total_input_tokens_per_item": cand_total_stats["token_counts"]["llm_total_input_tokens"] / cand_total_stats["data_point_count"],
# "total_language_output_tokens": cand_total_stats["token_counts"]["language_output_tokens"],
# "avg_language_output_tokens_per_item": cand_total_stats["token_counts"]["language_output_tokens"] / cand_total_stats["data_point_count"],
# }
# }
# for module_name, stats in cand_total_stats["module_inference_times"].items():
# final_cand_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"]
# final_cand_stats["inference_times"]["module_calls_count"][module_name] = stats["count"]
# if stats["count"] > 0:
# final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"]
# else:
# final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0
# with open(cand_inference_stats_path, 'w', encoding='utf-8') as f:
# json.dump(final_cand_stats, f, ensure_ascii=False, indent=4)
# print_master(f"Candidate inference statistics for {dataset_name} saved to: {cand_inference_stats_path}")
# else:
# print_master(f"No candidate data processed for {dataset_name}, skipping candidate inference statistics output.")
# if dist.is_initialized():
# dist.barrier()
# # --- 3. Compute Scores (on master rank only) ---
# score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
# ####################################################################################
# pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl")
# score_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details.jsonl") # new file that stores the per-candidate similarity scores
# def append_score_detail(score_detail_list, qid, ranked_indices, score_vector, cand_ids, labels):
# """Append the per-candidate score details for a single query."""
# score_detail_list.append({
# "qid": int(qid),
# "cand_scores": [
# {"cand_id": str(cand_ids[i]), "score": float(score_vector[i])}
# for i in ranked_indices
# ],
# "label": labels
# })
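# # Example of one emitted record (hypothetical ids and scores), one JSON object per query
# # written to {dataset_name}_score_details.jsonl:
# #   {"qid": 0,
# #    "cand_scores": [{"cand_id": "doc_12", "score": 0.83}, {"cand_id": "doc_7", "score": 0.41}],
# #    "label": ["doc_12"]}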
# ####################################################################################
# if local_rank == 0:
# if os.path.exists(score_path):
# try:
# with open(score_path, "r") as f:
# score_dict = json.load(f)
# print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}")
# formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
# print_master(formatted)
# # No `continue` here, as we want to ensure other files are processed/generated
# except Exception as e:
# print_master(f"Failed to load score for {dataset_name}, proceeding to recompute. Error: {e}")
# # Proceed with score computation if not loaded or failed to load
# with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f)
# with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f)
# gt_infos = [json.loads(l) for l in open(dataset_info_path)]
# pred_dicts = []
# score_detail_dicts = []###################################
# rank_against_all_candidates = task_config.get("eval_type", "global") == "global"
# # if rank_against_all_candidates:
# # cand_keys = list(cand_embed_dict.keys())
# # cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
# # # Handle late-interaction scoring
# # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
# # qry_embed = torch.from_numpy(qry_embeds)
# # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# # scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function
# # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()
# # scores = scores.cpu().numpy()
# # else: # Dense
# # cosine_scores = np.dot(qry_embeds, cand_embeds.T)
# # ranked_candids = np.argsort(-cosine_scores, axis=1)
# #####################################################
# if rank_against_all_candidates:
# cand_keys = list(cand_embed_dict.keys())
# cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
# if qry_embeds.ndim == 3: # Late-interaction
# qry_embed_t = torch.from_numpy(qry_embeds)
# cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy() # [N_q, N_c]
# else: # Dense
# sim_matrix = np.dot(qry_embeds, cand_embeds.T) # [N_q, N_c]
# ranked_all = np.argsort(-sim_matrix, axis=1)
# #########################################################
# for qid, (ranked_indices, gt_info) in tqdm(enumerate(zip(ranked_all, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# assert rel_scores is None or len(rel_docids) == len(rel_scores)
# pred_dicts.append({
# "prediction": [cand_keys[i] for i in ranked_indices],
# "label": rel_docids,
# "rel_scores": rel_scores,
# })
# ################################# New: record the detailed per-candidate similarity scores
# append_score_detail(score_detail_dicts, qid, ranked_indices, sim_matrix[qid], cand_keys, rel_docids)
# ########################################
# # else:
# # for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# # cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]])
# # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
# # qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0)
# # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# # scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function
# # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0]
# # else:
# # cosine_score = np.dot(qry_embed, cand_embeds.T)
# # ranked_candids = np.argsort(-cosine_score)
# # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# # assert rel_scores is None or len(rel_docids) == len(rel_scores)
# # pred_dicts.append({
# # "prediction": [gt_info["cand_names"][i] for i in ranked_candids],
# # "label": rel_docids,
# # "rel_scores": rel_scores,
# # })
# #######################################################################
# else:  # non-global: rank each query only against its own candidate list
# for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
# cand_ids_local = gt_info["cand_names"]
# cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local])
# if qry_embeds.ndim == 3: # Late-interaction
# qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # [1, Lq, H]
# cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
# sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0] # [N_c]
# else: # Dense
# sim_vec = np.dot(qry_embed, cand_embeds.T) # [N_c]
# ranked_indices = np.argsort(-sim_vec)
# rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
# rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
# assert rel_scores is None or len(rel_docids) == len(rel_scores)
# pred_dicts.append({
# "prediction": [cand_ids_local[i] for i in ranked_indices],
# "label": rel_docids,
# "rel_scores": rel_scores,
# })
# # New: per-candidate score details for this query
# append_score_detail(score_detail_dicts, qid, ranked_indices, sim_vec, cand_ids_local, rel_docids)
# ########################################## Save predictions and scores
# with open(score_detail_path, "w") as f:  # New: per-query similarity details
# for detail in score_detail_dicts:
# f.write(json.dumps(detail) + '\n')
# print_master(f"Detailed score file saved to: {score_detail_path}")
# metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
# metrics = RankingMetrics(metrics_to_report)
# score_dict = metrics.evaluate(pred_dicts)
# formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
# score_dict["num_pred"] = len(pred_dicts)
# score_dict["num_data"] = len(gt_infos)
# print_master(f"Score of {dataset_name}:")
# print_master(formatted)
# print_master(f"Outputting final score to: {score_path}")
# with open(score_path, "w") as f:
# json.dump(score_dict, f, indent=4)
# with open(pred_path, "w") as f:
# for pred in pred_dicts:
# f.write(json.dumps(pred) + '\n')
# ####################################################################
# if __name__ == '__main__':
# main()
######################################################################################################
#######################################################################################################
# Also output the mask values for visualization
########################################################################################
# Collect statistics separately for the query and cand sides
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
# --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
timing_info = {}
token_info = {
"vision_tokens": 0,
"text_input_tokens": 0, # Refers to the original text token count
"text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0.
"total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text)
}
# --- Hook Functions Definition ---
def timing_pre_hook(module, input):
module_id = id(module)
if module_id not in timing_info:
timing_info[module_id] = []
timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))
def timing_post_hook(module, input, output):
module_id = id(module)
if module_id not in timing_info:
# print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
return
timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))
# Collect vision token count (only from Vision Transformer module's post hook)
module_name = module.__class__.__name__
if "vision" in module_name.lower() and "transformer" in module_name.lower():
if isinstance(output, torch.Tensor):
token_info["vision_tokens"] = output.shape[0] # For visual features, usually (batch_size, num_tokens, hidden_dim)
elif hasattr(output, 'last_hidden_state'):
token_info["vision_tokens"] = output.last_hidden_state.shape[1]
def register_model_hooks(model):
registered_modules = []
core_model = model
# print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}")
if hasattr(model, 'encoder') and model.encoder is not None:
print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}")
core_model = model.encoder # hook the wrapped HF backbone (which owns .visual/.model/.lm_head) rather than the MMEBModel wrapper
else:
print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.")
# Vision module
if hasattr(core_model, 'visual') and core_model.visual is not None:
vision_module = core_model.visual
vision_module.register_forward_pre_hook(timing_pre_hook)
vision_module.register_forward_hook(timing_post_hook)
registered_modules.append(vision_module)
print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
# Merger module (if inside visual) - it's part of the vision component
if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
merger_module = core_model.visual.merger
merger_module.register_forward_pre_hook(timing_pre_hook)
merger_module.register_forward_hook(timing_post_hook)
registered_modules.append(merger_module)
print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
# Language model body
if hasattr(core_model, 'model') and core_model.model is not None:
llm_main_module = core_model.model
llm_main_module.register_forward_pre_hook(timing_pre_hook)
llm_main_module.register_forward_hook(timing_post_hook)
registered_modules.append(llm_main_module)
print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
# LM Head
if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
lm_head_module = core_model.lm_head
lm_head_module.register_forward_pre_hook(timing_pre_hook)
lm_head_module.register_forward_hook(timing_post_hook)
registered_modules.append(lm_head_module)
print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
if not registered_modules:
print_master("Warning: No major modules found for hook registration. Check model architecture.")
return registered_modules
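# Hedged usage sketch (not part of the current flow): register_forward_pre_hook /
# register_forward_hook return RemovableHandle objects, so a hook could be
# detached once profiling is finished, e.g.:
#   handle = vision_module.register_forward_hook(timing_post_hook)
#   ...
#   handle.remove()  # stop timing this module
# register_model_hooks() currently returns the hooked modules instead of the
# handles, so the hooks stay attached for the lifetime of the model.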
def pad_dataset_to_divisible(dataset, world_size):
num_samples = len(dataset)
if num_samples % world_size == 0:
return dataset, num_samples
num_to_add = world_size - (num_samples % world_size)
padded_size = num_samples + num_to_add
padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
padded_dataset = concatenate_datasets([dataset, padding_data])
return padded_dataset, padded_size
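# Example (illustrative): with 10 samples and world_size=4, num_to_add=2 and the
# rows at indices 0 and 1 are re-appended, giving a 12-row dataset so that
# split_dataset_by_node hands each rank exactly 3 samples. The duplicated rows
# are trimmed again in main() via the `[:len(full_eval_*_dataset)]` slices.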
def encode_embeddings(
model: MMEBModel,
loader: DataLoader,
training_args: TrainingArguments,
model_args: ModelArguments,
full_dataset: Dataset,
encode_side: str,
description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list]: # CHANGED: + list for img_token_masks
"""
Encodes embeddings for a given dataset using the model, handling both standard and
late-interaction models in a DDP-safe manner.
Returns:
- embeddings: np.ndarray
- infos_or_ids: list
- batch_stats_list: list
- img_token_masks: list[None | list[bool]] # NEW
"""
local_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
# Check if the model is a late-interaction type
is_late_interaction = (model_args.model_backbone == COLPALI)
local_embeds = []
local_gt_infos = []
local_max_len = 0
# --- New: List to store statistics for each batch ---
batch_stats_list = []
# --- NEW: Collect image token masks locally ---
local_img_token_masks = [] # one entry per sample: None or [bool, ...]
model.eval()
# Register hooks for the model once per encode_embeddings call
registered_hooks = register_model_hooks(model)
# --- NEW: helpers to extract the masks and make them JSON-serializable ---
def _search_key(obj, key: str):
# Recursively search nested dict/list/tuple structures for the given key
if isinstance(obj, dict):
if key in obj:
return obj[key]
for v in obj.values():
r = _search_key(v, key)
if r is not None:
return r
elif isinstance(obj, (list, tuple)):
for v in obj:
r = _search_key(v, key)
if r is not None:
return r
return None
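# Example (illustrative): _search_key({"out": {"image_token_bool_masks": [[True, False]]}},
#                                     "image_token_bool_masks") -> [[True, False]]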
def _to_serializable_mask_list(mask_list, batch_size: int):
# Convert the masks returned by the model (list/tensor/ndarray/None) into [None | list[bool]] * B
if mask_list is None:
return [None] * batch_size
out = []
if isinstance(mask_list, (list, tuple)):
for m in mask_list:
if m is None:
out.append(None)
elif torch.is_tensor(m):
out.append(m.detach().cpu().tolist())
elif isinstance(m, np.ndarray):
out.append(m.tolist())
else:
# already python list/bool
out.append(m)
elif torch.is_tensor(mask_list):
# A 2D tensor (B, L): tolist() directly gives list[list[bool/int]]
out = mask_list.detach().cpu().tolist()
elif isinstance(mask_list, np.ndarray):
out = mask_list.tolist()
else:
# Unknown type: conservatively fall back to None placeholders
out = [None] * batch_size
# Align the length with batch_size
if isinstance(out, list):
if len(out) < batch_size:
out = out + [None] * (batch_size - len(out))
elif len(out) > batch_size:
out = out[:batch_size]
return out
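# Examples (illustrative):
#   _to_serializable_mask_list(None, 2)                           -> [None, None]
#   _to_serializable_mask_list(torch.tensor([[1, 0], [0, 1]]), 2) -> [[1, 0], [0, 1]]
#   _to_serializable_mask_list([torch.tensor([True]), None], 2)   -> [[True], None]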
with torch.no_grad():
for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
# --- Reset statistics for each inference pass ---
timing_info.clear()
token_info["vision_tokens"] = 0
token_info["text_input_tokens"] = 0
token_info["text_output_tokens"] = 0
token_info["total_llm_input_tokens"] = 0
inputs = batch_to_device(inputs, training_args.device)
current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
start_inference_time = time.time()
if encode_side == "qry":
output = model(qry=inputs)
# torch.set_printoptions(threshold=10000)
# print('output:', output)
# exit()
reps = output["qry_reps"].detach()
local_gt_infos.extend(dataset_info)
else:
output = model(tgt=inputs)
reps = output["tgt_reps"].detach()
local_gt_infos.extend([info["cand_name"] for info in dataset_info])
end_inference_time = time.time()
# --- NEW: extract and store this batch's image_token_bool_masks ---
# The MMEBModel output is expected to contain 'image_token_bool_masks', either directly or nested
img_masks_raw = None
if isinstance(output, dict):
img_masks_raw = _search_key(output, "image_token_bool_masks")
# Optional: if the masks are also attached as an attribute on MMEBModel, try reading them from there
if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
img_masks_raw = getattr(model, "image_token_bool_masks")
img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
local_img_token_masks.extend(img_masks_serializable)
# --- Update total LLM input tokens after the model call ---
if 'input_ids' in inputs and inputs['input_ids'] is not None:
token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])
# --- Collect and Store Batch Statistics ---
batch_inference_time = end_inference_time - start_inference_time
current_batch_stats = {
"batch_size": current_batch_size,
"total_inference_time_seconds": batch_inference_time,
"module_inference_times": {},
"token_counts": {
"visual_tokens": token_info["vision_tokens"],
"language_input_tokens_raw": token_info["text_input_tokens"],
"llm_total_input_tokens": token_info["total_llm_input_tokens"],
"language_output_tokens": token_info["text_output_tokens"],
}
}
# Calculate and store module timings for the current batch
for module_obj in registered_hooks:
module_id = id(module_obj)
module_name = module_obj.__class__.__name__
times = timing_info.get(module_id, [])
durations = []
pre_times = {}
for t, event_type, _ in times:
if event_type == 'pre':
pre_times[module_id] = t
elif event_type == 'post' and module_id in pre_times:
duration = t - pre_times.pop(module_id)
durations.append(duration)
if durations:
current_batch_stats["module_inference_times"][module_name] = {
"total": sum(durations),
"count": len(durations),
"avg": sum(durations) / len(durations)
}
else:
current_batch_stats["module_inference_times"][module_name] = {
"total": 0.0,
"count": 0,
"avg": 0.0
}
batch_stats_list.append(current_batch_stats)
# --- Debug prints (optional) ---
print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
print_rank("--- Module Inference Timing Statistics ---")
for module_name, stats in current_batch_stats["module_inference_times"].items():
print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
print_rank("--- Token Count Statistics ---")
print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}")
print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}")
if is_late_interaction and reps.dim() == 3:
local_max_len = max(local_max_len, reps.shape[1])
local_embeds.append(reps)
if not local_embeds:
# Handle cases where a rank gets no data
return np.array([]), [], [], [] # CHANGED: four return values (embeds, infos, batch stats, image-token masks)
# === DDP Synchronization and Padding for Late-Interaction Models ===
if is_late_interaction:
if dist.is_initialized():
# 1: global max length
local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
global_max_len = local_max_len_tensor.item()
else:
global_max_len = local_max_len
# 2: pad to global max length
padded_embeds = []
for reps_batch in local_embeds:
if reps_batch.dim() == 3:
B, L, H = reps_batch.shape
padding_size = global_max_len - L
padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
padded_embeds.append(padded_batch)
else:
padded_embeds.append(reps_batch)
embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
else:
embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
# === Gather embeddings and keys from all ranks ===
if dist.is_initialized() and full_dataset.num_rows >= world_size:
print_master(f"Gathering {encode_side} embeddings across all ranks...")
# tensor gather
output_shape = list(embeds_tensor.shape)
output_shape[0] = full_dataset.num_rows
embeds_tensor = embeds_tensor.to(training_args.device)
gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
# object gather for infos and stats
gathered_gt_infos = [None for _ in range(world_size)]
dist.all_gather_object(gathered_gt_infos, local_gt_infos)
all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
gathered_batch_stats = [None for _ in range(world_size)]
dist.all_gather_object(gathered_batch_stats, batch_stats_list)
all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
# --- NEW: gather masks ---
gathered_masks = [None for _ in range(world_size)]
dist.all_gather_object(gathered_masks, local_img_token_masks)
all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]
else:
all_gt_infos = local_gt_infos
final_embeddings = embeds_tensor.cpu().float().numpy()
all_batch_stats = batch_stats_list
all_img_token_masks = local_img_token_masks # NEW
return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks # CHANGED
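# DDP note: dist.all_gather_into_tensor requires every rank to contribute a
# shard of identical shape, which is why pad_dataset_to_divisible() pads each
# split to a multiple of world_size and why late-interaction embeddings are
# padded to a shared global_max_len before gathering.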
def main():
if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
local_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
# DEBUG PRINTS for Distributed Setup
print_master("Distributed init debug info:")
print_master(f"RANK: {os.environ.get('RANK')}")
print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
if dist.is_initialized():
print_rank(f"dist.get_rank(): {dist.get_rank()}")
print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
for arg in sys.argv:
if arg.startswith("--local-rank="):
rank = arg.split("=")[1]
sys.argv.remove(arg)
sys.argv.append('--local_rank')
sys.argv.append(rank)
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
os.makedirs(data_args.encode_output_path, exist_ok=True)
# --- Model Loading ---
hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
if not getattr(model_args, "model_backbone", None):
model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
setattr(training_args, 'model_backbone', model_backbone)
print_master(f'Model Backbone: {model_args.model_backbone}')
# --- DDP-Safe Model Loading ---
# Step 1: Only the master process (rank 0) downloads the model.
if local_rank == 0:
processor = load_processor(model_args, data_args)
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
# Step 2: All processes wait here. The non-master processes will pause
# until the master process (rank 0) finishes downloading and exits this barrier.
if torch.distributed.is_initialized():
torch.distributed.barrier()
# Step 3: Now that the model is cached, the non-master processes load it from the local cache.
if local_rank != 0:
print_rank(f"Loading the model from cache...")
processor = load_processor(model_args, data_args)
time.sleep(random.randint(2 * local_rank, 3 * local_rank))
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
model.eval()
model = model.to(training_args.device, dtype=torch.bfloat16)
with open(data_args.dataset_config, 'r') as yaml_file:
dataset_configs = yaml.safe_load(yaml_file)
# --- Main Evaluation Loop ---
for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
# Initialize task-level statistics accumulators for QUERY
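# NOTE (assumption): the module class names hard-coded in the accumulators below
# (Qwen2VisionTransformerPretrainedModel, PatchMerger, Qwen2VLModel, Linear)
# match a Qwen2-VL style backbone; per-batch timings for other backbones are
# still collected, but only these names are rolled up into the task-level totals.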
query_total_stats = {
"total_inference_time_seconds": 0.0,
"module_inference_times": {
"Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0},
"PatchMerger": {"total": 0.0, "count": 0},
"Qwen2VLModel": {"total": 0.0, "count": 0},
"Linear": {"total": 0.0, "count": 0},
},
"token_counts": {
"visual_tokens": 0,
"language_input_tokens_raw": 0,
"llm_total_input_tokens": 0,
"language_output_tokens": 0,
},
"data_point_count": 0 # Number of image-text pairs processed
}
# Initialize task-level statistics accumulators for CANDIDATE
cand_total_stats = {
"total_inference_time_seconds": 0.0,
"module_inference_times": {
"Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0},
"PatchMerger": {"total": 0.0, "count": 0},
"Qwen2VLModel": {"total": 0.0, "count": 0},
"Linear": {"total": 0.0, "count": 0},
},
"token_counts": {
"visual_tokens": 0,
"language_input_tokens_raw": 0,
"llm_total_input_tokens": 0,
"language_output_tokens": 0,
},
"data_point_count": 0 # Number of image-text pairs processed
}
if dist.is_initialized():
dist.barrier()
print_master(f"\n--- Evaluating {dataset_name} ---")
query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry")
cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt")
dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
# New: Define distinct paths for query and candidate inference statistics output
query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats.json")
cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats.json")
do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
do_cand = not os.path.exists(cand_embed_path)
if do_query or do_cand:
if data_args.data_basedir is not None:
# Construct full paths for data files if --data_basedir is provided
for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
if data_args.data_basedir and task_config.get(key):
task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# Pad datasets to be divisible by world_size before splitting
if dist.is_initialized():
padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
else:
padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
# --- 1. Compute Query Embeddings ---
if do_query:
print_master("Encoding queries...")
eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers)
# Modified: capture batch_stats_list
query_embeds, gt_infos, qry_batch_stats, qry_img_masks = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}")
# Accumulate query statistics
for batch_stat in qry_batch_stats:
batch_size = batch_stat["batch_size"]
query_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"]
for module_name, module_stats in batch_stat["module_inference_times"].items():
if module_name in query_total_stats["module_inference_times"]:
query_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"]
query_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"]
query_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size
query_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size
query_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size
query_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size
query_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items
query_embeds = query_embeds[:len(full_eval_qry_dataset)]
gt_infos = gt_infos[:len(full_eval_qry_dataset)]
if local_rank == 0:
with open(query_embed_path, 'wb') as f:
pickle.dump(query_embeds, f)
with open(dataset_info_path, 'w') as f:
for info in gt_infos:
f.write(json.dumps(info) + '\n')
print_master(f"Saved query embeddings to {query_embed_path}")
qry_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_img_token_masks.jsonl")
with open(qry_img_masks_path, 'w', encoding='utf-8') as f:
for i, m in enumerate(qry_img_masks[:len(full_eval_qry_dataset)]):
f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n")
print_master(f"Saved query image token masks to {qry_img_masks_path}")
# Save query-specific inference statistics
if query_total_stats["data_point_count"] > 0:
final_query_stats = {
"task_name": dataset_name,
"encode_side": "query",
"data_point_count": query_total_stats["data_point_count"],
"inference_times": {
"total_inference_time_seconds": query_total_stats["total_inference_time_seconds"],
"avg_inference_time_per_item_seconds": query_total_stats["total_inference_time_seconds"] / query_total_stats["data_point_count"],
"module_average_times_per_call": {},
"module_total_times_seconds": {},
"module_calls_count": {},
},
"token_counts": {
"total_visual_tokens": query_total_stats["token_counts"]["visual_tokens"],
"avg_visual_tokens_per_item": query_total_stats["token_counts"]["visual_tokens"] / query_total_stats["data_point_count"],
"total_language_input_tokens_raw": query_total_stats["token_counts"]["language_input_tokens_raw"],
"avg_language_input_tokens_raw_per_item": query_total_stats["token_counts"]["language_input_tokens_raw"] / query_total_stats["data_point_count"],
"total_llm_total_input_tokens": query_total_stats["token_counts"]["llm_total_input_tokens"],
"avg_llm_total_input_tokens_per_item": query_total_stats["token_counts"]["llm_total_input_tokens"] / query_total_stats["data_point_count"],
"total_language_output_tokens": query_total_stats["token_counts"]["language_output_tokens"],
"avg_language_output_tokens_per_item": query_total_stats["token_counts"]["language_output_tokens"] / query_total_stats["data_point_count"],
}
}
for module_name, stats in query_total_stats["module_inference_times"].items():
final_query_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"]
final_query_stats["inference_times"]["module_calls_count"][module_name] = stats["count"]
if stats["count"] > 0:
final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"]
else:
final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0
with open(query_inference_stats_path, 'w', encoding='utf-8') as f:
json.dump(final_query_stats, f, ensure_ascii=False, indent=4)
print_master(f"Query inference statistics for {dataset_name} saved to: {query_inference_stats_path}")
else:
print_master(f"No query data processed for {dataset_name}, skipping query inference statistics output.")
if dist.is_initialized():
dist.barrier()
# --- 2. Compute Candidate Embeddings ---
if do_cand:
print_master("Encoding candidates...")
eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# Modified: capture batch_stats_list
cand_embeds, all_cand_ids, cand_batch_stats, cand_img_masks = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}")
# Accumulate candidate statistics (similar logic as query)
for batch_stat in cand_batch_stats:
batch_size = batch_stat["batch_size"]
cand_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"]
for module_name, module_stats in batch_stat["module_inference_times"].items():
if module_name in cand_total_stats["module_inference_times"]:
cand_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"]
cand_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"]
cand_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size
cand_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size
cand_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size
cand_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size
cand_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items
cand_embeds = cand_embeds[:len(full_eval_cand_dataset)]
all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)]
if local_rank == 0:
cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)}
with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f)
print_master(f"Saved candidate embeddings to {cand_embed_path}")
cand_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_img_token_masks.jsonl")
with open(cand_img_masks_path, 'w', encoding='utf-8') as f:
for cid, m in zip(all_cand_ids[:len(full_eval_cand_dataset)], cand_img_masks[:len(full_eval_cand_dataset)]):
f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n")
print_master(f"Saved candidate image token masks to {cand_img_masks_path}")
# Save candidate-specific inference statistics
if cand_total_stats["data_point_count"] > 0:
final_cand_stats = {
"task_name": dataset_name,
"encode_side": "candidate",
"data_point_count": cand_total_stats["data_point_count"],
"inference_times": {
"total_inference_time_seconds": cand_total_stats["total_inference_time_seconds"],
"avg_inference_time_per_item_seconds": cand_total_stats["total_inference_time_seconds"] / cand_total_stats["data_point_count"],
"module_average_times_per_call": {},
"module_total_times_seconds": {},
"module_calls_count": {},
},
"token_counts": {
"total_visual_tokens": cand_total_stats["token_counts"]["visual_tokens"],
"avg_visual_tokens_per_item": cand_total_stats["token_counts"]["visual_tokens"] / cand_total_stats["data_point_count"],
"total_language_input_tokens_raw": cand_total_stats["token_counts"]["language_input_tokens_raw"],
"avg_language_input_tokens_raw_per_item": cand_total_stats["token_counts"]["language_input_tokens_raw"] / cand_total_stats["data_point_count"],
"total_llm_total_input_tokens": cand_total_stats["token_counts"]["llm_total_input_tokens"],
"avg_llm_total_input_tokens_per_item": cand_total_stats["token_counts"]["llm_total_input_tokens"] / cand_total_stats["data_point_count"],
"total_language_output_tokens": cand_total_stats["token_counts"]["language_output_tokens"],
"avg_language_output_tokens_per_item": cand_total_stats["token_counts"]["language_output_tokens"] / cand_total_stats["data_point_count"],
}
}
for module_name, stats in cand_total_stats["module_inference_times"].items():
final_cand_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"]
final_cand_stats["inference_times"]["module_calls_count"][module_name] = stats["count"]
if stats["count"] > 0:
final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"]
else:
final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0
with open(cand_inference_stats_path, 'w', encoding='utf-8') as f:
json.dump(final_cand_stats, f, ensure_ascii=False, indent=4)
print_master(f"Candidate inference statistics for {dataset_name} saved to: {cand_inference_stats_path}")
else:
print_master(f"No candidate data processed for {dataset_name}, skipping candidate inference statistics output.")
if dist.is_initialized():
dist.barrier()
# --- 3. Compute Scores (on master rank only) ---
score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
####################################################################################
pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl")
score_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details.jsonl") # new file storing per-candidate similarity scores
def append_score_detail(score_detail_list, qid, ranked_indices, score_vector, cand_ids, labels):
"""追加一个 query 的候选分数详情"""
score_detail_list.append({
"qid": int(qid),
"cand_scores": [
{"cand_id": str(cand_ids[i]), "score": float(score_vector[i])}
for i in ranked_indices
],
"label": labels
})
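# Example line written to score_detail_path (illustrative ids/scores):
#   {"qid": 0, "cand_scores": [{"cand_id": "doc_12", "score": 0.8731}, {"cand_id": "doc_7", "score": 0.6204}], "label": ["doc_12"]}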
####################################################################################
if local_rank == 0:
if os.path.exists(score_path):
try:
with open(score_path, "r") as f:
score_dict = json.load(f)
print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}")
formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
print_master(formatted)
# No `continue` here, as we want to ensure other files are processed/generated
except Exception as e:
print_master(f"Failed to load score for {dataset_name}, proceeding to recompute. Error: {e}")
# Proceed with score computation if not loaded or failed to load
with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f)
with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f)
gt_infos = [json.loads(l) for l in open(dataset_info_path)]
pred_dicts = []
score_detail_dicts = [] # detailed per-candidate similarity scores (new)
rank_against_all_candidates = task_config.get("eval_type", "global") == "global"
#####################################################
if rank_against_all_candidates:
cand_keys = list(cand_embed_dict.keys())
cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
if qry_embeds.ndim == 3: # Late-interaction
qry_embed_t = torch.from_numpy(qry_embeds)
cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy() # [N_q, N_c]
else: # Dense
sim_matrix = np.dot(qry_embeds, cand_embeds.T) # [N_q, N_c]
ranked_candids = np.argsort(-sim_matrix, axis=1)
#########################################################
for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
assert rel_scores is None or len(rel_docids) == len(rel_scores)
pred_dicts.append({
"prediction": [cand_keys[i] for i in ranked_candid],
"label": rel_docids,
"rel_scores": rel_scores,
})
# New: record per-candidate similarity details
append_score_detail(score_detail_dicts, qid, ranked_candid, sim_matrix[qid], cand_keys, rel_docids)
########################################
else: # non-global: rank each query only against its own candidate list
for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
cand_ids_local = gt_info["cand_names"]
cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local])
if qry_embeds.ndim == 3: # Late-interaction
qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # [1, Lq, H]
cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0] # [N_c]
else: # Dense
sim_vec = np.dot(qry_embed, cand_embeds.T) # [N_c]
ranked_indices = np.argsort(-sim_vec)
rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
assert rel_scores is None or len(rel_docids) == len(rel_scores)
pred_dicts.append({
"prediction": [cand_ids_local[i] for i in ranked_indices],
"label": rel_docids,
"rel_scores": rel_scores,
})
# New: per-query score details
append_score_detail(score_detail_dicts, qid, ranked_indices, sim_vec, cand_ids_local, rel_docids)
########################################## Save predictions and scores
with open(score_detail_path, "w") as f: # new: per-candidate similarity details
for detail in score_detail_dicts:
f.write(json.dumps(detail) + '\n')
print_master(f"Detailed score file saved to: {score_detail_path}")
metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
metrics = RankingMetrics(metrics_to_report)
score_dict = metrics.evaluate(pred_dicts)
formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
score_dict["num_pred"] = len(pred_dicts)
score_dict["num_data"] = len(gt_infos)
print_master(f"Score of {dataset_name}:")
print_master(formatted)
print_master(f"Outputting final score to: {score_path}")
with open(score_path, "w") as f:
json.dump(score_dict, f, indent=4)
with open(pred_path, "w") as f:
for pred in pred_dicts:
f.write(json.dumps(pred) + '\n')
if __name__ == '__main__':
main()