# import datetime # import logging # import json # import random # import time # import numpy as np # import os # import pickle # import sys # import torch # import torch.distributed as dist # import torch.nn.functional as F # import yaml # from torch.utils.data import DataLoader # from tqdm import tqdm # from transformers import HfArgumentParser, AutoConfig # from datasets import Dataset, concatenate_datasets # from datasets.distributed import split_dataset_by_node # from src.arguments import ModelArguments, DataArguments, TrainingArguments # from src.data.collator.eval_collator import MultimodalEvalDataCollator # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset # from src.eval_utils.metrics import RankingMetrics # from src.model.model import MMEBModel # from src.model.processor import get_backbone_name, load_processor, COLPALI # from src.utils import batch_to_device, print_rank, print_master # import multiprocessing # from multiprocessing import Pool, cpu_count # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') # logger = logging.getLogger(__name__) # ############################################### # # 计时开始 # def start_timer(name, timing_dict): # timing_dict[name] = time.time() # # 计时结束 # def end_timer(name, timing_dict): # end_time = time.time() # if name in timing_dict: # timing_dict[name] = end_time - timing_dict[name] # # # 放在 main 函数之前,或者单独放在 utils.py 中 # # def register_hooks(model, timing_dict): # # # --- vision_encoder hook --- # # def vision_forward_hook(module, input, output): # # print_master(f"[vision_encoder] output shape: {output.shape}") # # print_master(f"[vision_encoder] num_image_tokens: {output.shape[1]}") # # start_timer("vision_encoder", timing_dict) # # def vision_forward_post_hook(module, input, output): # # end_timer("vision_encoder", timing_dict) # # model.encoder.visual.register_forward_hook(vision_forward_hook) # # model.encoder.visual.register_forward_hook(vision_forward_post_hook) # # # --- merger hook --- # # def merger_forward_hook(module, input, output): # # print_master(f"[merger] before merger - input shape: {input[0].shape}") # # print_master(f"[merger] before merger - num_tokens: {input[0].shape[1]}") # # start_timer("merger", timing_dict) # # def merger_forward_post_hook(module, input, output): # # print_master(f"[merger] after merger - output shape: {output.shape}") # # print_master(f"[merger] after merger - num_tokens: {output.shape[1]}") # # end_timer("merger", timing_dict) # # if hasattr(model.encoder.visual, 'merger'): # # model.encoder.visual.merger.register_forward_hook(merger_forward_hook) # # model.encoder.visual.merger.register_forward_hook(merger_forward_post_hook) # # # --- decoder hook --- # # def decoder_forward_hook(module, input, output): # # # 这里更新为接收 input 和 output 参数 # # if isinstance(input, tuple) and len(input) > 0: # # print_master(f"[llm_decoder] input shape: {input[0].shape}") # # print_master(f"[llm_decoder] total_tokens (image+text): {input[0].shape[1]}") # # start_timer("llm_decoder", timing_dict) # # def decoder_forward_post_hook(module, input, output): # # end_timer("llm_decoder", timing_dict) # # model.encoder.model.register_forward_hook(decoder_forward_hook) # # model.encoder.model.register_forward_hook(decoder_forward_post_hook) # # # --- lm_head hook --- # # def lm_head_forward_hook(module, input, output): # # start_timer("lm_head", timing_dict) # # def lm_head_forward_post_hook(module, input, output): # # end_timer("lm_head", timing_dict) # # model.encoder.lm_head.register_forward_hook(lm_head_forward_hook) # # model.encoder.lm_head.register_forward_hook(lm_head_forward_post_hook) # def register_timing_hooks(model, timing_dict): # def make_hooks(name): # def pre_hook(module, input): # timing_dict[f"{name}_start"] = time.time() # def forward_hook(module, input, output): # elapsed = time.time() - timing_dict[f"{name}_start"] # timing_dict[name] = elapsed # 记录时间 # print_master(f"[{name}] took {elapsed * 1000:.2f} ms") # return pre_hook, forward_hook # # vision encoder # pre, post = make_hooks("vision_encoder") # model.encoder.visual.register_forward_pre_hook(pre) # model.encoder.visual.register_forward_hook(post) # # merger # if hasattr(model.encoder.visual, 'merger'): # pre, post = make_hooks("merger") # model.encoder.visual.merger.register_forward_pre_hook(pre) # model.encoder.visual.merger.register_forward_hook(post) # # decoder # pre, post = make_hooks("llm_decoder") # model.encoder.model.register_forward_pre_hook(pre) # model.encoder.model.register_forward_hook(post) # # lm_head # pre, post = make_hooks("lm_head") # model.encoder.lm_head.register_forward_pre_hook(pre) # model.encoder.lm_head.register_forward_hook(post) # ##################################################### # def pad_dataset_to_divisible(dataset, world_size): # num_samples = len(dataset) # if num_samples % world_size == 0: # return dataset, num_samples # num_to_add = world_size - (num_samples % world_size) # padded_size = num_samples + num_to_add # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) # padded_dataset = concatenate_datasets([dataset, padding_data]) # return padded_dataset, padded_size # def encode_embeddings( # model: MMEBModel, # loader: DataLoader, # training_args: TrainingArguments, # model_args: ModelArguments, # full_dataset: Dataset, # encode_side: str, # description: str = "Encoding", # timing_dict: dict | None = None # ) -> tuple[np.ndarray, list]: # """ # Encodes embeddings for a given dataset using the model, handling both standard and # late-interaction models in a DDP-safe manner. # """ # local_rank = dist.get_rank() if dist.is_initialized() else 0 # world_size = dist.get_world_size() if dist.is_initialized() else 1 # # Check if the model is a late-interaction type # is_late_interaction = (model_args.model_backbone == COLPALI) # local_embeds = [] # local_gt_infos = [] # local_max_len = 0 # model.eval() # with torch.no_grad(): # for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0): # inputs = batch_to_device(inputs, training_args.device) # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): # # Determine if encoding query or target based on available keys # if encode_side == "qry": # output = model(qry=inputs) # reps = output["qry_reps"].detach() # local_gt_infos.extend(dataset_info) # to retain all information per query # else: # output = model(tgt=inputs) # reps = output["tgt_reps"].detach() # local_gt_infos.extend([info["cand_name"] for info in dataset_info]) # to retain ground-truth labels # if is_late_interaction and reps.dim() == 3: # local_max_len = max(local_max_len, reps.shape[1]) # local_embeds.append(reps) # if not local_embeds: # # Handle cases where a rank gets no data # return np.array([]), [] # # === DDP Synchronization and Padding for Late-Interaction Models === # if is_late_interaction: # if dist.is_initialized(): # # 1. Find the global maximum sequence length across all ranks # local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device) # dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX) # global_max_len = local_max_len_tensor.item() # else: # global_max_len = local_max_len # # 2. Pad all local embeddings to the global max length # padded_embeds = [] # for reps_batch in local_embeds: # if reps_batch.dim() == 3: # B, L, H = reps_batch.shape # padding_size = global_max_len - L # padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0) # padded_embeds.append(padded_batch) # else: # Should not happen if model is consistently late-interaction # padded_embeds.append(reps_batch) # embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous() # else: # Standard dense models # embeds_tensor = torch.cat(local_embeds, dim=0).contiguous() # # === Gather embeddings and keys from all ranks === # if dist.is_initialized() and full_dataset.num_rows >= world_size: # print_master(f"Gathering {encode_side} embeddings across all ranks...") # # Use the more efficient all_gather_into_tensor for tensors # output_shape = list(embeds_tensor.shape) # output_shape[0] = full_dataset.num_rows # embeds_tensor = embeds_tensor.to(training_args.device) # gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device) # dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor) # final_embeddings = gathered_embeds_tensor.cpu().float().numpy() # # Gather metadata, for which all_gather_object is appropriate # gathered_gt_infos = [None for _ in range(world_size)] # dist.all_gather_object(gathered_gt_infos, local_gt_infos) # all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys] # else: # all_gt_infos = local_gt_infos # final_embeddings = embeds_tensor.cpu().float().numpy() # print_master(f"Timing results for {description}:") # for k, v in timing_dict.items(): # if not k.startswith('_'): # print_master(f" {k}: {v:.4f} sec") # return final_embeddings, all_gt_infos # def main(): # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) # local_rank = dist.get_rank() if dist.is_initialized() else 0 # world_size = dist.get_world_size() if dist.is_initialized() else 1 # # DEBUG PRINTS for Distributed Setup # print_master("Distributed init debug info:") # print_master(f"RANK: {os.environ.get('RANK')}") # print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}") # print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}") # print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}") # print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}") # if dist.is_initialized(): # print_rank(f"dist.get_rank(): {dist.get_rank()}") # print_rank(f"dist.get_world_size(): {dist.get_world_size()}") # for arg in sys.argv: # if arg.startswith("--local-rank="): # rank = arg.split("=")[1] # sys.argv.remove(arg) # sys.argv.append('--local_rank') # sys.argv.append(rank) # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) # model_args, data_args, training_args = parser.parse_args_into_dataclasses() # model_args: ModelArguments # data_args: DataArguments # training_args: TrainingArguments # os.makedirs(data_args.encode_output_path, exist_ok=True) # # --- Model Loading --- # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) # if not getattr(model_args, "model_backbone", None): # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) # setattr(model_args, 'model_backbone', model_backbone) # setattr(training_args, 'model_backbone', model_backbone) # print_master(f'Model Backbone: {model_args.model_backbone}') # # --- DDP-Safe Model Loading --- # # Step 1: Only the master process (rank 0) downloads the model. # if local_rank == 0: # processor = load_processor(model_args, data_args) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...") # # Step 2: All processes wait here. The non-master processes will pause # # until the master process (rank 0) finishes downloading and exits this barrier. # if torch.distributed.is_initialized(): # torch.distributed.barrier() # # Step 3: Now that the model is cached, the non-master processes load it from the local cache. # if local_rank != 0: # print_rank(f"Loading the model from cache...") # processor = load_processor(model_args, data_args) # time.sleep(random.randint(2 * local_rank, 3 * local_rank)) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # model.eval() # model = model.to(training_args.device, dtype=torch.bfloat16) # with open(data_args.dataset_config, 'r') as yaml_file: # dataset_configs = yaml.safe_load(yaml_file) # ############################################################################# # import time # timing_dict = {} # register_hooks(model, timing_dict) # 注册 hooks,开始计时 # ############################################################################## # # --- Main Evaluation Loop --- # for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()): # # 0. load dataset # if dist.is_initialized(): # dist.barrier() # print_master(f"--- Evaluating {dataset_name} ---") # query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry") # cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt") # dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") # do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path) # do_cand = not os.path.exists(cand_embed_path) # if do_query or do_cand: # if data_args.data_basedir is not None: # # Construct full paths for data files if --data_basedir is provided # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: # if data_args.data_basedir and task_config.get(key): # task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) # eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # # Pad datasets to be divisible by world_size before splitting # if dist.is_initialized(): # padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size) # padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size) # eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size) # eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size) # else: # padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # # --- 1. Compute Query Embeddings --- # if do_query: # print_master("Encoding queries...") # eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") # eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers) # query_embeds, gt_infos = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}", timing_dict=timing_dict) # query_embeds = query_embeds[:len(full_eval_qry_dataset)] # world_size>1, trim the padded data points # gt_infos = gt_infos[:len(full_eval_qry_dataset)] # if local_rank == 0: # with open(query_embed_path, 'wb') as f: # pickle.dump(query_embeds, f) # with open(dataset_info_path, 'w') as f: # for info in gt_infos: # f.write(json.dumps(info) + '\n') # print_master(f"Saved query embeddings to {query_embed_path}") # if dist.is_initialized(): # dist.barrier() # # --- 2. Compute Candidate Embeddings --- # if do_cand: # print_master("Encoding candidates...") # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") # eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) # cand_embeds, all_cand_ids = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}", timing_dict=timing_dict) # cand_embeds = cand_embeds[:len(full_eval_cand_dataset)] # world_size>1, trim the padded data points # all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)] # if local_rank == 0: # cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)} # with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f) # print_master(f"Saved candidate embeddings to {cand_embed_path}") # if dist.is_initialized(): # dist.barrier() # # --- 3. Compute Scores (on master rank only) --- # if local_rank == 0: # score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") # if os.path.exists(score_path): # try: # with open(score_path, "r") as f: # score_dict = json.load(f) # print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}") # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # print_master(formatted) # continue # except Exception as e: # print_master(f"Failed to load score for {dataset_name}, skipping {dataset_name}") # with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f) # with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f) # gt_infos = [json.loads(l) for l in open(dataset_info_path)] # pred_dicts = [] # rank_against_all_candidates = task_config.get("eval_type", "global") == "global" # if rank_against_all_candidates: # cand_keys = list(cand_embed_dict.keys()) # cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) # # Handle late-interaction scoring # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # qry_embed = torch.from_numpy(qry_embeds) # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist() # else: # Dense # cosine_scores = np.dot(qry_embeds, cand_embeds.T) # ranked_candids = np.argsort(-cosine_scores, axis=1) # for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [cand_keys[i] for i in ranked_candid], # "label": rel_docids, # "rel_scores": rel_scores, # }) # else: # for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]]) # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0] # else: # cosine_score = np.dot(qry_embed, cand_embeds.T) # ranked_candids = np.argsort(-cosine_score) # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [gt_info["cand_names"][i] for i in ranked_candids], # "label": rel_docids, # "rel_scores": rel_scores, # }) # score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") # pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl") # metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] # metrics = RankingMetrics(metrics_to_report) # score_dict = metrics.evaluate(pred_dicts) # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # score_dict["num_pred"] = len(pred_dicts) # score_dict["num_data"] = len(gt_infos) # print_master(f"Score of {dataset_name}:") # print_master(formatted) # print_master(f"Outputting final score to: {score_path}") # with open(score_path, "w") as f: # json.dump(score_dict, f, indent=4) # with open(pred_path, "w") as f: # for pred in pred_dicts: # f.write(json.dumps(pred) + '\n') # if __name__ == "__main__": # main() ################################################################################################### #直接打印输出对应模块时间和token数量 # import datetime # import logging # import json # import random # import time # import numpy as np # import os # import pickle # import sys # import torch # import torch.distributed as dist # import torch.nn.functional as F # import yaml # import transformers # from torch.utils.data import DataLoader # from tqdm import tqdm # from transformers import HfArgumentParser, AutoConfig, AutoTokenizer#, Qwen2VLForConditionalGeneration # from datasets import Dataset, concatenate_datasets # from datasets.distributed import split_dataset_by_node # from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src # from src.arguments import ModelArguments, DataArguments, TrainingArguments # from src.data.collator.eval_collator import MultimodalEvalDataCollator # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset # from src.eval_utils.metrics import RankingMetrics # from src.model.model import MMEBModel # from src.model.processor import get_backbone_name, load_processor, COLPALI # from src.utils import batch_to_device, print_rank, print_master # import multiprocessing # from multiprocessing import Pool, cpu_count # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') # logger = logging.getLogger(__name__) # # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) --- # timing_info = {} # token_info = { # "vision_tokens": 0, # "text_input_tokens": 0, # Refers to the original text token count # "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0. # "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text) # } # # --- Hook Functions Definition --- # def timing_pre_hook(module, input): # module_id = id(module) # if module_id not in timing_info: # timing_info[module_id] = [] # timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__)) # def timing_post_hook(module, input, output): # module_id = id(module) # if module_id not in timing_info: # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})") # return # timing_info[module_id].append((time.time(), 'post', module.__class__.__name__)) # # Collect vision token count (only from Vision Transformer module's post hook) # module_name = module.__class__.__name__ # if "vision" in module_name.lower() and "transformer" in module_name.lower(): # if isinstance(output, torch.Tensor): # token_info["vision_tokens"] = output.shape[0] # elif hasattr(output, 'last_hidden_state'): # token_info["vision_tokens"] = output.last_hidden_state.shape[1] # # --- Hook Functions Definition --- # # (timing_pre_hook and timing_post_hook remain as previously corrected with debug prints) # def register_model_hooks(model): # registered_modules = [] # core_model = model # print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}") # if hasattr(model, 'encoder') and model.encoder is not None: # print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}") # # 使用从 'src' 路径导入的 Qwen2VLForConditionalGeneration 进行检查 # if isinstance(model.encoder, _Qwen2VLForConditionalGeneration_src): # print_master("Detected MMEBModel structure, registering hooks on model.encoder's sub-modules.") # core_model = model.encoder # else: # print_master(f"WARNING: model.encoder is not an instance of _Qwen2VLForConditionalGeneration_src. Its type is {type(model.encoder)}. Hooks will be registered on top-level model if applicable.") # else: # print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.") # # Vision module # if hasattr(core_model, 'visual') and core_model.visual is not None: # vision_module = core_model.visual # vision_module.register_forward_pre_hook(timing_pre_hook) # vision_module.register_forward_hook(timing_post_hook) # registered_modules.append(vision_module) # print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).") # # Merger module (if inside visual) - it's part of the vision component # if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None: # merger_module = core_model.visual.merger # merger_module.register_forward_pre_hook(timing_pre_hook) # merger_module.register_forward_hook(timing_post_hook) # registered_modules.append(merger_module) # print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).") # # Language model body # if hasattr(core_model, 'model') and core_model.model is not None: # llm_main_module = core_model.model # llm_main_module.register_forward_pre_hook(timing_pre_hook) # llm_main_module.register_forward_hook(timing_post_hook) # registered_modules.append(llm_main_module) # print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).") # # LM Head # if hasattr(core_model, 'lm_head') and core_model.lm_head is not None: # lm_head_module = core_model.lm_head # lm_head_module.register_forward_pre_hook(timing_pre_hook) # lm_head_module.register_forward_hook(timing_post_hook) # registered_modules.append(lm_head_module) # print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).") # if not registered_modules: # print_master("Warning: No major modules found for hook registration. Check model architecture.") # return registered_modules # def pad_dataset_to_divisible(dataset, world_size): # num_samples = len(dataset) # if num_samples % world_size == 0: # return dataset, num_samples # num_to_add = world_size - (num_samples % world_size) # padded_size = num_samples + num_to_add # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) # padded_dataset = concatenate_datasets([dataset, padding_data]) # return padded_dataset, padded_size # def encode_embeddings( # model: MMEBModel, # loader: DataLoader, # training_args: TrainingArguments, # model_args: ModelArguments, # full_dataset: Dataset, # encode_side: str, # description: str = "Encoding" # ) -> tuple[np.ndarray, list]: # """ # Encodes embeddings for a given dataset using the model, handling both standard and # late-interaction models in a DDP-safe manner. # """ # local_rank = dist.get_rank() if dist.is_initialized() else 0 # world_size = dist.get_world_size() if dist.is_initialized() else 1 # # Check if the model is a late-interaction type # is_late_interaction = (model_args.model_backbone == COLPALI) # local_embeds = [] # local_gt_infos = [] # local_max_len = 0 # model.eval() # # Register hooks for the model once per encode_embeddings call # # This assumes `model` is the MMEBModel instance that wraps the actual HuggingFace model # # You might need to adjust this if MMEBModel internally manages multiple sub-models # registered_hooks = register_model_hooks(model) # <--- FIX: Assign the return value here # # Initialize a tokenizer for text token counting (needs to be from the same model path) # temp_tokenizer = AutoTokenizer.from_pretrained(model_args.model_name) # with torch.no_grad(): # for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0): # # --- Reset statistics for each inference pass --- # timing_info.clear() # token_info["vision_tokens"] = 0 # token_info["text_input_tokens"] = 0 # token_info["text_output_tokens"] = 0 # Encoding doesn't generate text output tokens # token_info["total_llm_input_tokens"] = 0 # inputs = batch_to_device(inputs, training_args.device) # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): # # Determine if encoding query or target based on available keys # # This is where the forward pass happens, triggering hooks # start_inference_time = time.time() # if encode_side == "qry": # output = model(qry=inputs) # reps = output["qry_reps"].detach() # local_gt_infos.extend(dataset_info) # to retain all information per query # else: # output = model(tgt=inputs) # reps = output["tgt_reps"].detach() # local_gt_infos.extend([info["cand_name"] for info in dataset_info]) # to retain ground-truth labels # end_inference_time = time.time() # # --- Update total LLM input tokens after the model call --- # # This requires knowing which part of `inputs` corresponds to the LLM's full input. # # Assuming `inputs.input_ids` directly goes into the LLM part of Qwen2-VL. # if 'input_ids' in inputs and inputs['input_ids'] is not None: # token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1] # # Approximation for text_input_tokens (if not explicitly available from collator) # # This assumes visual tokens are a prefix and the rest are text/special tokens. # token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"] # # Ensure it's not negative # token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"]) # # --- Print Inference Timing and Token Statistics per Batch --- # print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---") # print_rank(f"Batch Inference took: {end_inference_time - start_inference_time:.4f} seconds") # # Calculate and print module timings # print_rank("--- Module Inference Timing Statistics ---") # for module_obj in registered_hooks: # module_id = id(module_obj) # module_name = module_obj.__class__.__name__ # times = timing_info.get(module_id, []) # durations = [] # pre_times = {} # Store start times for each pre-hook # for t, event_type, _ in times: # if event_type == 'pre': # pre_times[module_id] = t # elif event_type == 'post' and module_id in pre_times: # duration = t - pre_times.pop(module_id) # durations.append(duration) # if durations: # print_rank(f"**{module_name}**: Total: {sum(durations):.6f}s, Count: {len(durations)}, Avg: {sum(durations)/len(durations):.6f}s") # else: # print_rank(f"**{module_name}**: No complete timing data found for this batch.") # print_rank("--- Token Count Statistics ---") # print_rank(f"**视觉 token 数量**: {token_info['vision_tokens']}") # print_rank(f"**语言输入 token 数量 (仅原始文本)**: {token_info['text_input_tokens']}") # print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {token_info['total_llm_input_tokens']}") # print_rank(f"**语言输出 token 数量**: {token_info['text_output_tokens']}") # Will be 0 for encoding # if is_late_interaction and reps.dim() == 3: # local_max_len = max(local_max_len, reps.shape[1]) # local_embeds.append(reps) # if not local_embeds: # # Handle cases where a rank gets no data # return np.array([]), [] # # === DDP Synchronization and Padding for Late-Interaction Models === # if is_late_interaction: # if dist.is_initialized(): # # 1. Find the global maximum sequence length across all ranks # local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device) # dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX) # global_max_len = local_max_len_tensor.item() # else: # global_max_len = local_max_len # # 2. Pad all local embeddings to the global max length # padded_embeds = [] # for reps_batch in local_embeds: # if reps_batch.dim() == 3: # B, L, H = reps_batch.shape # padding_size = global_max_len - L # padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0) # padded_embeds.append(padded_batch) # else: # Should not happen if model is consistently late-interaction # padded_embeds.append(reps_batch) # embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous() # else: # Standard dense models # embeds_tensor = torch.cat(local_embeds, dim=0).contiguous() # # === Gather embeddings and keys from all ranks === # if dist.is_initialized() and full_dataset.num_rows >= world_size: # print_master(f"Gathering {encode_side} embeddings across all ranks...") # # Use the more efficient all_gather_into_tensor for tensors # output_shape = list(embeds_tensor.shape) # output_shape[0] = full_dataset.num_rows # embeds_tensor = embeds_tensor.to(training_args.device) # gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device) # dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor) # final_embeddings = gathered_embeds_tensor.cpu().float().numpy() # # Gather metadata, for which all_gather_object is appropriate # gathered_gt_infos = [None for _ in range(world_size)] # dist.all_gather_object(gathered_gt_infos, local_gt_infos) # all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys] # else: # all_gt_infos = local_gt_infos # final_embeddings = embeds_tensor.cpu().float().numpy() # return final_embeddings, all_gt_infos # def main(): # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) # local_rank = dist.get_rank() if dist.is_initialized() else 0 # world_size = dist.get_world_size() if dist.is_initialized() else 1 # # DEBUG PRINTS for Distributed Setup # print_master("Distributed init debug info:") # print_master(f"RANK: {os.environ.get('RANK')}") # print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}") # print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}") # print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}") # print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}") # if dist.is_initialized(): # print_rank(f"dist.get_rank(): {dist.get_rank()}") # print_rank(f"dist.get_world_size(): {dist.get_world_size()}") # for arg in sys.argv: # if arg.startswith("--local-rank="): # rank = arg.split("=")[1] # sys.argv.remove(arg) # sys.argv.append('--local_rank') # sys.argv.append(rank) # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) # model_args, data_args, training_args = parser.parse_args_into_dataclasses() # model_args: ModelArguments # data_args: DataArguments # training_args: TrainingArguments # os.makedirs(data_args.encode_output_path, exist_ok=True) # # --- Model Loading --- # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) # if not getattr(model_args, "model_backbone", None): # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) # setattr(model_args, 'model_backbone', model_backbone) # setattr(training_args, 'model_backbone', model_backbone) # print_master(f'Model Backbone: {model_args.model_backbone}') # # --- DDP-Safe Model Loading --- # # Step 1: Only the master process (rank 0) downloads the model. # if local_rank == 0: # processor = load_processor(model_args, data_args) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...") # # Step 2: All processes wait here. The non-master processes will pause # # until the master process (rank 0) finishes downloading and exits this barrier. # if torch.distributed.is_initialized(): # torch.distributed.barrier() # # Step 3: Now that the model is cached, the non-master processes load it from the local cache. # if local_rank != 0: # print_rank(f"Loading the model from cache...") # processor = load_processor(model_args, data_args) # time.sleep(random.randint(2 * local_rank, 3 * local_rank)) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # model.eval() # model = model.to(training_args.device, dtype=torch.bfloat16) # with open(data_args.dataset_config, 'r') as yaml_file: # dataset_configs = yaml.safe_load(yaml_file) # # --- Main Evaluation Loop --- # for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()): # # 0. load dataset # if dist.is_initialized(): # dist.barrier() # print_master(f"--- Evaluating {dataset_name} ---") # query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry") # cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt") # dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") # do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path) # do_cand = not os.path.exists(cand_embed_path) # if do_query or do_cand: # if data_args.data_basedir is not None: # # Construct full paths for data files if --data_basedir is provided # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: # if data_args.data_basedir and task_config.get(key): # task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) # eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # # Pad datasets to be divisible by world_size before splitting # if dist.is_initialized(): # padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size) # padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size) # eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size) # eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size) # else: # padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # # --- 1. Compute Query Embeddings --- # if do_query: # print_master("Encoding queries...") # eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") # eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers) # query_embeds, gt_infos = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}") # query_embeds = query_embeds[:len(full_eval_qry_dataset)] # world_size>1, trim the padded data points # gt_infos = gt_infos[:len(full_eval_qry_dataset)] # if local_rank == 0: # with open(query_embed_path, 'wb') as f: # pickle.dump(query_embeds, f) # with open(dataset_info_path, 'w') as f: # for info in gt_infos: # f.write(json.dumps(info) + '\n') # print_master(f"Saved query embeddings to {query_embed_path}") # if dist.is_initialized(): # dist.barrier() # # --- 2. Compute Candidate Embeddings --- # if do_cand: # print_master("Encoding candidates...") # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") # eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) # cand_embeds, all_cand_ids = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}") # cand_embeds = cand_embeds[:len(full_eval_cand_dataset)] # world_size>1, trim the padded data points # all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)] # if local_rank == 0: # cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)} # with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f) # print_master(f"Saved candidate embeddings to {cand_embed_path}") # if dist.is_initialized(): # dist.barrier() # # --- 3. Compute Scores (on master rank only) --- # if local_rank == 0: # score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") # if os.path.exists(score_path): # try: # with open(score_path, "r") as f: # score_dict = json.load(f) # print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}") # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # print_master(formatted) # continue # except Exception as e: # print_master(f"Failed to load score for {dataset_name}, skipping {dataset_name}") # with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f) # with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f) # gt_infos = [json.loads(l) for l in open(dataset_info_path)] # pred_dicts = [] # rank_against_all_candidates = task_config.get("eval_type", "global") == "global" # if rank_against_all_candidates: # cand_keys = list(cand_embed_dict.keys()) # cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) # # Handle late-interaction scoring # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # qry_embed = torch.from_numpy(qry_embeds) # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist() # else: # Dense # cosine_scores = np.dot(qry_embeds, cand_embeds.T) # ranked_candids = np.argsort(-cosine_scores, axis=1) # for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [cand_keys[i] for i in ranked_candid], # "label": rel_docids, # "rel_scores": rel_scores, # }) # else: # for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]]) # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0] # else: # cosine_score = np.dot(qry_embed, cand_embeds.T) # ranked_candids = np.argsort(-cosine_score) # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [gt_info["cand_names"][i] for i in ranked_candids], # "label": rel_docids, # "rel_scores": rel_scores, # }) # score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") # pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl") # metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] # metrics = RankingMetrics(metrics_to_report) # score_dict = metrics.evaluate(pred_dicts) # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # score_dict["num_pred"] = len(pred_dicts) # score_dict["num_data"] = len(gt_infos) # print_master(f"Score of {dataset_name}:") # print_master(formatted) # print_master(f"Outputting final score to: {score_path}") # with open(score_path, "w") as f: # json.dump(score_dict, f, indent=4) # with open(pred_path, "w") as f: # for pred in pred_dicts: # f.write(json.dumps(pred) + '\n') # if __name__ == "__main__": # main() ################################################################################################## ################################################################################################### #将每个任务的平均值和总值输出保存到文件中 # import datetime # import logging # import json # import random # import time # import numpy as np # import os # import pickle # import sys # import torch # import torch.distributed as dist # import torch.nn.functional as F # import yaml # import transformers # from torch.utils.data import DataLoader # from tqdm import tqdm # from transformers import HfArgumentParser, AutoConfig, AutoTokenizer # from datasets import Dataset, concatenate_datasets # from datasets.distributed import split_dataset_by_node # from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src # from src.arguments import ModelArguments, DataArguments, TrainingArguments # from src.data.collator.eval_collator import MultimodalEvalDataCollator # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset # from src.eval_utils.metrics import RankingMetrics # from src.model.model import MMEBModel # from src.model.processor import get_backbone_name, load_processor, COLPALI # from src.utils import batch_to_device, print_rank, print_master # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') # logger = logging.getLogger(__name__) # # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) --- # timing_info = {} # token_info = { # "vision_tokens": 0, # "text_input_tokens": 0, # Refers to the original text token count # "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0. # "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text) # } # # --- Hook Functions Definition --- # def timing_pre_hook(module, input): # module_id = id(module) # if module_id not in timing_info: # timing_info[module_id] = [] # timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__)) # def timing_post_hook(module, input, output): # module_id = id(module) # if module_id not in timing_info: # # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})") # return # timing_info[module_id].append((time.time(), 'post', module.__class__.__name__)) # # Collect vision token count (only from Vision Transformer module's post hook) # module_name = module.__class__.__name__ # if "vision" in module_name.lower() and "transformer" in module_name.lower(): # if isinstance(output, torch.Tensor): # token_info["vision_tokens"] = output.shape[0] # For visual features, usually (batch_size, num_tokens, hidden_dim) # elif hasattr(output, 'last_hidden_state'): # token_info["vision_tokens"] = output.last_hidden_state.shape[1] # def register_model_hooks(model): # registered_modules = [] # core_model = model # # print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}") # if hasattr(model, 'encoder') and model.encoder is not None: # # print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}") # # 使用从 'src' 路径导入的 Qwen2VLForConditionalGeneration 进行检查 # if isinstance(model.encoder, _Qwen2VLForConditionalGeneration_src): # # print_master("Detected MMEBModel structure, registering hooks on model.encoder's sub-modules.") # core_model = model.encoder # else: # print_master(f"WARNING: model.encoder is not an instance of _Qwen2VLForConditionalGeneration_src. Its type is {type(model.encoder)}. Hooks will be registered on top-level model if applicable.") # else: # print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.") # # Vision module # if hasattr(core_model, 'visual') and core_model.visual is not None: # vision_module = core_model.visual # vision_module.register_forward_pre_hook(timing_pre_hook) # vision_module.register_forward_hook(timing_post_hook) # registered_modules.append(vision_module) # print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).") # # Merger module (if inside visual) - it's part of the vision component # if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None: # merger_module = core_model.visual.merger # merger_module.register_forward_pre_hook(timing_pre_hook) # merger_module.register_forward_hook(timing_post_hook) # registered_modules.append(merger_module) # print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).") # # Language model body # if hasattr(core_model, 'model') and core_model.model is not None: # llm_main_module = core_model.model # llm_main_module.register_forward_pre_hook(timing_pre_hook) # llm_main_module.register_forward_hook(timing_post_hook) # registered_modules.append(llm_main_module) # print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).") # # LM Head # if hasattr(core_model, 'lm_head') and core_model.lm_head is not None: # lm_head_module = core_model.lm_head # lm_head_module.register_forward_pre_hook(timing_pre_hook) # lm_head_module.register_forward_hook(timing_post_hook) # registered_modules.append(lm_head_module) # print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).") # if not registered_modules: # print_master("Warning: No major modules found for hook registration. Check model architecture.") # return registered_modules # def pad_dataset_to_divisible(dataset, world_size): # num_samples = len(dataset) # if num_samples % world_size == 0: # return dataset, num_samples # num_to_add = world_size - (num_samples % world_size) # padded_size = num_samples + num_to_add # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) # padded_dataset = concatenate_datasets([dataset, padding_data]) # return padded_dataset, padded_size # def encode_embeddings( # model: MMEBModel, # loader: DataLoader, # training_args: TrainingArguments, # model_args: ModelArguments, # full_dataset: Dataset, # encode_side: str, # description: str = "Encoding" # ) -> tuple[np.ndarray, list, list]: # Added list to return type for batch_stats # """ # Encodes embeddings for a given dataset using the model, handling both standard and # late-interaction models in a DDP-safe manner. # """ # local_rank = dist.get_rank() if dist.is_initialized() else 0 # world_size = dist.get_world_size() if dist.is_initialized() else 1 # # Check if the model is a late-interaction type # is_late_interaction = (model_args.model_backbone == COLPALI) # local_embeds = [] # local_gt_infos = [] # local_max_len = 0 # # --- New: List to store statistics for each batch --- # batch_stats_list = [] # model.eval() # # Register hooks for the model once per encode_embeddings call # registered_hooks = register_model_hooks(model) # with torch.no_grad(): # for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0): # # --- Reset statistics for each inference pass --- # timing_info.clear() # token_info["vision_tokens"] = 0 # token_info["text_input_tokens"] = 0 # token_info["text_output_tokens"] = 0 # token_info["total_llm_input_tokens"] = 0 # inputs = batch_to_device(inputs, training_args.device) # current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 # Determine actual batch size # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): # start_inference_time = time.time() # if encode_side == "qry": # output = model(qry=inputs) # reps = output["qry_reps"].detach() # local_gt_infos.extend(dataset_info) # else: # output = model(tgt=inputs) # reps = output["tgt_reps"].detach() # local_gt_infos.extend([info["cand_name"] for info in dataset_info]) # end_inference_time = time.time() # # --- Update total LLM input tokens after the model call --- # if 'input_ids' in inputs and inputs['input_ids'] is not None: # # `inputs['input_ids'].shape[1]` gives the sequence length, # # which is the number of tokens per item in the batch. # # To get total tokens for the batch, multiply by batch size. # token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1] # # Approximation for text_input_tokens # token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"] # token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"]) # Ensure not negative # # --- Collect and Store Batch Statistics --- # batch_inference_time = end_inference_time - start_inference_time # current_batch_stats = { # "batch_size": current_batch_size, # "total_inference_time_seconds": batch_inference_time, # "module_inference_times": {}, # "token_counts": { # "visual_tokens": token_info["vision_tokens"], # "language_input_tokens_raw": token_info["text_input_tokens"], # "llm_total_input_tokens": token_info["total_llm_input_tokens"], # "language_output_tokens": token_info["text_output_tokens"], # } # } # # Calculate and store module timings for the current batch # for module_obj in registered_hooks: # module_id = id(module_obj) # module_name = module_obj.__class__.__name__ # times = timing_info.get(module_id, []) # durations = [] # pre_times = {} # for t, event_type, _ in times: # if event_type == 'pre': # pre_times[module_id] = t # elif event_type == 'post' and module_id in pre_times: # duration = t - pre_times.pop(module_id) # durations.append(duration) # if durations: # current_batch_stats["module_inference_times"][module_name] = { # "total": sum(durations), # "count": len(durations), # "avg": sum(durations) / len(durations) # } # else: # current_batch_stats["module_inference_times"][module_name] = { # "total": 0.0, # "count": 0, # "avg": 0.0 # } # batch_stats_list.append(current_batch_stats) # Append the stats for this batch # # --- Print Inference Timing and Token Statistics per Batch (Optional, for debugging) --- # print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---") # print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds") # print_rank("--- Module Inference Timing Statistics ---") # for module_name, stats in current_batch_stats["module_inference_times"].items(): # print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s") # print_rank("--- Token Count Statistics ---") # print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}") # print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}") # print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}") # print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}") # if is_late_interaction and reps.dim() == 3: # local_max_len = max(local_max_len, reps.shape[1]) # local_embeds.append(reps) # if not local_embeds: # # Handle cases where a rank gets no data # return np.array([]), [], [] # Return empty list for batch_stats_list as well # # === DDP Synchronization and Padding for Late-Interaction Models === # if is_late_interaction: # if dist.is_initialized(): # # 1. Find the global maximum sequence length across all ranks # local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device) # dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX) # global_max_len = local_max_len_tensor.item() # else: # global_max_len = local_max_len # # 2. Pad all local embeddings to the global max length # padded_embeds = [] # for reps_batch in local_embeds: # if reps_batch.dim() == 3: # B, L, H = reps_batch.shape # padding_size = global_max_len - L # padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0) # padded_embeds.append(padded_batch) # else: # Should not happen if model is consistently late-interaction # padded_embeds.append(reps_batch) # embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous() # else: # Standard dense models # embeds_tensor = torch.cat(local_embeds, dim=0).contiguous() # # === Gather embeddings and keys from all ranks === # if dist.is_initialized() and full_dataset.num_rows >= world_size: # print_master(f"Gathering {encode_side} embeddings across all ranks...") # # Use the more efficient all_gather_into_tensor for tensors # output_shape = list(embeds_tensor.shape) # output_shape[0] = full_dataset.num_rows # embeds_tensor = embeds_tensor.to(training_args.device) # gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device) # dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor) # final_embeddings = gathered_embeds_tensor.cpu().float().numpy() # # Gather metadata, for which all_gather_object is appropriate # gathered_gt_infos = [None for _ in range(world_size)] # dist.all_gather_object(gathered_gt_infos, local_gt_infos) # all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys] # # --- New: Gather batch_stats_list from all ranks --- # gathered_batch_stats = [None for _ in range(world_size)] # dist.all_gather_object(gathered_batch_stats, batch_stats_list) # all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats] # else: # all_gt_infos = local_gt_infos # final_embeddings = embeds_tensor.cpu().float().numpy() # all_batch_stats = batch_stats_list # If not DDP, just use local list # return final_embeddings, all_gt_infos, all_batch_stats # def main(): # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) # local_rank = dist.get_rank() if dist.is_initialized() else 0 # world_size = dist.get_world_size() if dist.is_initialized() else 1 # # DEBUG PRINTS for Distributed Setup # print_master("Distributed init debug info:") # print_master(f"RANK: {os.environ.get('RANK')}") # print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}") # print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}") # print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}") # print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}") # if dist.is_initialized(): # print_rank(f"dist.get_rank(): {dist.get_rank()}") # print_rank(f"dist.get_world_size(): {dist.get_world_size()}") # for arg in sys.argv: # if arg.startswith("--local-rank="): # rank = arg.split("=")[1] # sys.argv.remove(arg) # sys.argv.append('--local_rank') # sys.argv.append(rank) # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) # model_args, data_args, training_args = parser.parse_args_into_dataclasses() # model_args: ModelArguments # data_args: DataArguments # training_args: TrainingArguments # os.makedirs(data_args.encode_output_path, exist_ok=True) # # --- Model Loading --- # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) # if not getattr(model_args, "model_backbone", None): # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) # setattr(model_args, 'model_backbone', model_backbone) # setattr(training_args, 'model_backbone', model_backbone) # print_master(f'Model Backbone: {model_args.model_backbone}') # # --- DDP-Safe Model Loading --- # # Step 1: Only the master process (rank 0) downloads the model. # if local_rank == 0: # processor = load_processor(model_args, data_args) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...") # # Step 2: All processes wait here. The non-master processes will pause # # until the master process (rank 0) finishes downloading and exits this barrier. # if torch.distributed.is_initialized(): # torch.distributed.barrier() # # Step 3: Now that the model is cached, the non-master processes load it from the local cache. # if local_rank != 0: # print_rank(f"Loading the model from cache...") # processor = load_processor(model_args, data_args) # time.sleep(random.randint(2 * local_rank, 3 * local_rank)) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # model.eval() # model = model.to(training_args.device, dtype=torch.bfloat16) # with open(data_args.dataset_config, 'r') as yaml_file: # dataset_configs = yaml.safe_load(yaml_file) # # --- Main Evaluation Loop --- # for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()): # # Initialize task-level statistics accumulators # task_total_stats = { # "total_inference_time_seconds": 0.0, # "module_inference_times": { # "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0}, # "PatchMerger": {"total": 0.0, "count": 0}, # "Qwen2VLModel": {"total": 0.0, "count": 0}, # "Linear": {"total": 0.0, "count": 0}, # }, # "token_counts": { # "visual_tokens": 0, # "language_input_tokens_raw": 0, # "llm_total_input_tokens": 0, # "language_output_tokens": 0, # }, # "data_point_count": 0 # Number of image-text pairs processed # } # if dist.is_initialized(): # dist.barrier() # print_master(f"\n--- Evaluating {dataset_name} ---") # query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry") # cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt") # dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") # # New: Define path for inference statistics output # inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_inference_stats.json") # do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path) # do_cand = not os.path.exists(cand_embed_path) # if do_query or do_cand: # if data_args.data_basedir is not None: # # Construct full paths for data files if --data_basedir is provided # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: # if data_args.data_basedir and task_config.get(key): # task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) # eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # # Pad datasets to be divisible by world_size before splitting # if dist.is_initialized(): # padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size) # padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size) # eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size) # eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size) # else: # padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # # --- 1. Compute Query Embeddings --- # if do_query: # print_master("Encoding queries...") # eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") # eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers) # # Modified: capture batch_stats_list # query_embeds, gt_infos, qry_batch_stats = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}") # # Accumulate query statistics # for batch_stat in qry_batch_stats: # batch_size = batch_stat["batch_size"] # task_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"] # for module_name, module_stats in batch_stat["module_inference_times"].items(): # if module_name in task_total_stats["module_inference_times"]: # task_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"] # task_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"] # count here is per-module-call, not per-item # # Token counts are per-item for 'visual_tokens', 'llm_total_input_tokens'. # # For 'text_input_tokens', it's calculated based on sequence length, so it's also total tokens in the batch. # # We need to average it later by total data_point_count. # # However, your current hook logic collects the token count for a *single* item if batch_size=1, # # or for the full batch if it processes sequentially. # # Let's assume the `token_info` collected by hooks reflects the *current batch*. # # To get per-item average later, we sum up totals and divide by total data points. # task_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size # Corrected assumption: visual_tokens are per-item, multiplied by batch_size to get total for batch # task_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size # Corrected # task_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size # Corrected # task_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size # Corrected # task_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items # query_embeds = query_embeds[:len(full_eval_qry_dataset)] # gt_infos = gt_infos[:len(full_eval_qry_dataset)] # if local_rank == 0: # with open(query_embed_path, 'wb') as f: # pickle.dump(query_embeds, f) # with open(dataset_info_path, 'w') as f: # for info in gt_infos: # f.write(json.dumps(info) + '\n') # print_master(f"Saved query embeddings to {query_embed_path}") # if dist.is_initialized(): # dist.barrier() # # --- 2. Compute Candidate Embeddings --- # if do_cand: # print_master("Encoding candidates...") # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") # eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) # # Modified: capture batch_stats_list # cand_embeds, all_cand_ids, cand_batch_stats = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}") # # Accumulate candidate statistics (similar logic as query) # for batch_stat in cand_batch_stats: # batch_size = batch_stat["batch_size"] # task_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"] # for module_name, module_stats in batch_stat["module_inference_times"].items(): # if module_name in task_total_stats["module_inference_times"]: # task_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"] # task_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"] # task_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size # task_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size # task_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size # task_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size # task_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items # cand_embeds = cand_embeds[:len(full_eval_cand_dataset)] # all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)] # if local_rank == 0: # cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)} # with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f) # print_master(f"Saved candidate embeddings to {cand_embed_path}") # if dist.is_initialized(): # dist.barrier() # # --- New: Calculate and Save Task-level Inference Statistics (on master rank only) --- # if local_rank == 0: # if task_total_stats["data_point_count"] > 0: # final_task_stats = { # "task_name": dataset_name, # "data_point_count": task_total_stats["data_point_count"], # "inference_times": { # "total_inference_time_seconds": task_total_stats["total_inference_time_seconds"], # "avg_inference_time_per_item_seconds": task_total_stats["total_inference_time_seconds"] / task_total_stats["data_point_count"], # "module_average_times_per_call": {}, # Average per call to the module # "module_total_times_seconds": {}, # Total time spent in the module # "module_calls_count": {}, # Number of times the module was called # }, # "token_counts": { # "total_visual_tokens": task_total_stats["token_counts"]["visual_tokens"], # "avg_visual_tokens_per_item": task_total_stats["token_counts"]["visual_tokens"] / task_total_stats["data_point_count"], # "total_language_input_tokens_raw": task_total_stats["token_counts"]["language_input_tokens_raw"], # "avg_language_input_tokens_raw_per_item": task_total_stats["token_counts"]["language_input_tokens_raw"] / task_total_stats["data_point_count"], # "total_llm_total_input_tokens": task_total_stats["token_counts"]["llm_total_input_tokens"], # "avg_llm_total_input_tokens_per_item": task_total_stats["token_counts"]["llm_total_input_tokens"] / task_total_stats["data_point_count"], # "total_language_output_tokens": task_total_stats["token_counts"]["language_output_tokens"], # "avg_language_output_tokens_per_item": task_total_stats["token_counts"]["language_output_tokens"] / task_total_stats["data_point_count"], # } # } # for module_name, stats in task_total_stats["module_inference_times"].items(): # final_task_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"] # final_task_stats["inference_times"]["module_calls_count"][module_name] = stats["count"] # if stats["count"] > 0: # final_task_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"] # else: # final_task_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0 # with open(inference_stats_path, 'w', encoding='utf-8') as f: # json.dump(final_task_stats, f, ensure_ascii=False, indent=4) # print_master(f"Inference statistics for {dataset_name} saved to: {inference_stats_path}") # else: # print_master(f"No data processed for {dataset_name}, skipping inference statistics output.") # # --- 3. Compute Scores (on master rank only) --- # score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") # if os.path.exists(score_path): # try: # with open(score_path, "r") as f: # score_dict = json.load(f) # print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}") # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # print_master(formatted) # # No `continue` here, as we want to ensure other files are processed/generated # except Exception as e: # print_master(f"Failed to load score for {dataset_name}, proceeding to recompute. Error: {e}") # # Proceed with score computation if not loaded or failed to load # with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f) # with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f) # gt_infos = [json.loads(l) for l in open(dataset_info_path)] # pred_dicts = [] # rank_against_all_candidates = task_config.get("eval_type", "global") == "global" # if rank_against_all_candidates: # cand_keys = list(cand_embed_dict.keys()) # cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) # # Handle late-interaction scoring # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # qry_embed = torch.from_numpy(qry_embeds) # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist() # else: # Dense # cosine_scores = np.dot(qry_embeds, cand_embeds.T) # ranked_candids = np.argsort(-cosine_scores, axis=1) # for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [cand_keys[i] for i in ranked_candid], # "label": rel_docids, # "rel_scores": rel_scores, # }) # else: # for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]]) # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0] # else: # cosine_score = np.dot(qry_embed, cand_embeds.T) # ranked_candids = np.argsort(-cosine_score) # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [gt_info["cand_names"][i] for i in ranked_candids], # "label": rel_docids, # "rel_scores": rel_scores, # }) # score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") # pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl") # metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] # metrics = RankingMetrics(metrics_to_report) # score_dict = metrics.evaluate(pred_dicts) # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # score_dict["num_pred"] = len(pred_dicts) # score_dict["num_data"] = len(gt_infos) # print_master(f"Score of {dataset_name}:") # print_master(formatted) # print_master(f"Outputting final score to: {score_path}") # with open(score_path, "w") as f: # json.dump(score_dict, f, indent=4) # with open(pred_path, "w") as f: # for pred in pred_dicts: # f.write(json.dumps(pred) + '\n') # if __name__ == "__main__": # main() ########################################################################################## # ######################################################################################## # #分query和cand进行统计 # import datetime # import logging # import json # import random # import time # import numpy as np # import os # import pickle # import sys # import torch # import torch.distributed as dist # import torch.nn.functional as F # import yaml # import transformers # from torch.utils.data import DataLoader # from tqdm import tqdm # from transformers import HfArgumentParser, AutoConfig, AutoTokenizer # from datasets import Dataset, concatenate_datasets # from datasets.distributed import split_dataset_by_node # from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src # from src.arguments import ModelArguments, DataArguments, TrainingArguments # from src.data.collator.eval_collator import MultimodalEvalDataCollator # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset # from src.eval_utils.metrics import RankingMetrics # from src.model.model import MMEBModel # from src.model.processor import get_backbone_name, load_processor, COLPALI # from src.utils import batch_to_device, print_rank, print_master # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') # logger = logging.getLogger(__name__) # # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) --- # timing_info = {} # token_info = { # "vision_tokens": 0, # "text_input_tokens": 0, # Refers to the original text token count # "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0. # "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text) # } # # --- Hook Functions Definition --- # def timing_pre_hook(module, input): # module_id = id(module) # if module_id not in timing_info: # timing_info[module_id] = [] # timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__)) # def timing_post_hook(module, input, output): # module_id = id(module) # if module_id not in timing_info: # # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})") # return # timing_info[module_id].append((time.time(), 'post', module.__class__.__name__)) # # Collect vision token count (only from Vision Transformer module's post hook) # module_name = module.__class__.__name__ # if "vision" in module_name.lower() and "transformer" in module_name.lower(): # if isinstance(output, torch.Tensor): # token_info["vision_tokens"] = output.shape[0] # For visual features, usually (batch_size, num_tokens, hidden_dim) # elif hasattr(output, 'last_hidden_state'): # token_info["vision_tokens"] = output.last_hidden_state.shape[1] # def register_model_hooks(model): # registered_modules = [] # core_model = model # # print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}") # if hasattr(model, 'encoder') and model.encoder is not None: # # print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}") # # 使用从 'src' 路径导入的 Qwen2VLForConditionalGeneration 进行检查 # if isinstance(model.encoder, _Qwen2VLForConditionalGeneration_src): # # print_master("Detected MMEBModel structure, registering hooks on model.encoder's sub-modules.") # core_model = model.encoder # else: # print_master(f"WARNING: model.encoder is not an instance of _Qwen2VLForConditionalGeneration_src. Its type is {type(model.encoder)}. Hooks will be registered on top-level model if applicable.") # else: # print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.") # # Vision module # if hasattr(core_model, 'visual') and core_model.visual is not None: # vision_module = core_model.visual # vision_module.register_forward_pre_hook(timing_pre_hook) # vision_module.register_forward_hook(timing_post_hook) # registered_modules.append(vision_module) # print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).") # # Merger module (if inside visual) - it's part of the vision component # if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None: # merger_module = core_model.visual.merger # merger_module.register_forward_pre_hook(timing_pre_hook) # merger_module.register_forward_hook(timing_post_hook) # registered_modules.append(merger_module) # print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).") # # Language model body # if hasattr(core_model, 'model') and core_model.model is not None: # llm_main_module = core_model.model # llm_main_module.register_forward_pre_hook(timing_pre_hook) # llm_main_module.register_forward_hook(timing_post_hook) # registered_modules.append(llm_main_module) # print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).") # # LM Head # if hasattr(core_model, 'lm_head') and core_model.lm_head is not None: # lm_head_module = core_model.lm_head # lm_head_module.register_forward_pre_hook(timing_pre_hook) # lm_head_module.register_forward_hook(timing_post_hook) # registered_modules.append(lm_head_module) # print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}") # else: # print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).") # if not registered_modules: # print_master("Warning: No major modules found for hook registration. Check model architecture.") # return registered_modules # def pad_dataset_to_divisible(dataset, world_size): # num_samples = len(dataset) # if num_samples % world_size == 0: # return dataset, num_samples # num_to_add = world_size - (num_samples % world_size) # padded_size = num_samples + num_to_add # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) # padded_dataset = concatenate_datasets([dataset, padding_data]) # return padded_dataset, padded_size # def encode_embeddings( # model: MMEBModel, # loader: DataLoader, # training_args: TrainingArguments, # model_args: ModelArguments, # full_dataset: Dataset, # encode_side: str, # description: str = "Encoding" # ) -> tuple[np.ndarray, list, list]: # Added list to return type for batch_stats # """ # Encodes embeddings for a given dataset using the model, handling both standard and # late-interaction models in a DDP-safe manner. # """ # local_rank = dist.get_rank() if dist.is_initialized() else 0 # world_size = dist.get_world_size() if dist.is_initialized() else 1 # # Check if the model is a late-interaction type # is_late_interaction = (model_args.model_backbone == COLPALI) # local_embeds = [] # local_gt_infos = [] # local_max_len = 0 # # --- New: List to store statistics for each batch --- # batch_stats_list = [] # model.eval() # # Register hooks for the model once per encode_embeddings call # registered_hooks = register_model_hooks(model) # with torch.no_grad(): # for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0): # # --- Reset statistics for each inference pass --- # timing_info.clear() # token_info["vision_tokens"] = 0 # token_info["text_input_tokens"] = 0 # token_info["text_output_tokens"] = 0 # token_info["total_llm_input_tokens"] = 0 # inputs = batch_to_device(inputs, training_args.device) # current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1 # Determine actual batch size # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): # start_inference_time = time.time() # if encode_side == "qry": # output = model(qry=inputs) # reps = output["qry_reps"].detach() # local_gt_infos.extend(dataset_info) # else: # output = model(tgt=inputs) # reps = output["tgt_reps"].detach() # local_gt_infos.extend([info["cand_name"] for info in dataset_info]) # end_inference_time = time.time() # # --- Update total LLM input tokens after the model call --- # if 'input_ids' in inputs and inputs['input_ids'] is not None: # # `inputs['input_ids'].shape[1]` gives the sequence length, # # which is the number of tokens per item in the batch. # # To get total tokens for the batch, multiply by batch size. # token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1] # # Approximation for text_input_tokens # token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"] # token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"]) # Ensure not negative # # --- Collect and Store Batch Statistics --- # batch_inference_time = end_inference_time - start_inference_time # current_batch_stats = { # "batch_size": current_batch_size, # "total_inference_time_seconds": batch_inference_time, # "module_inference_times": {}, # "token_counts": { # "visual_tokens": token_info["vision_tokens"], # "language_input_tokens_raw": token_info["text_input_tokens"], # "llm_total_input_tokens": token_info["total_llm_input_tokens"], # "language_output_tokens": token_info["text_output_tokens"], # } # } # # Calculate and store module timings for the current batch # for module_obj in registered_hooks: # module_id = id(module_obj) # module_name = module_obj.__class__.__name__ # times = timing_info.get(module_id, []) # durations = [] # pre_times = {} # for t, event_type, _ in times: # if event_type == 'pre': # pre_times[module_id] = t # elif event_type == 'post' and module_id in pre_times: # duration = t - pre_times.pop(module_id) # durations.append(duration) # if durations: # current_batch_stats["module_inference_times"][module_name] = { # "total": sum(durations), # "count": len(durations), # "avg": sum(durations) / len(durations) # } # else: # current_batch_stats["module_inference_times"][module_name] = { # "total": 0.0, # "count": 0, # "avg": 0.0 # } # batch_stats_list.append(current_batch_stats) # Append the stats for this batch # # --- Print Inference Timing and Token Statistics per Batch (Optional, for debugging) --- # print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---") # print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds") # print_rank("--- Module Inference Timing Statistics ---") # for module_name, stats in current_batch_stats["module_inference_times"].items(): # print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s") # print_rank("--- Token Count Statistics ---") # print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}") # print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}") # print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}") # print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}") # if is_late_interaction and reps.dim() == 3: # local_max_len = max(local_max_len, reps.shape[1]) # local_embeds.append(reps) # if not local_embeds: # # Handle cases where a rank gets no data # return np.array([]), [], [] # Return empty list for batch_stats_list as well # # === DDP Synchronization and Padding for Late-Interaction Models === # if is_late_interaction: # if dist.is_initialized(): # # 1. Find the global maximum sequence length across all ranks # local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device) # dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX) # global_max_len = local_max_len_tensor.item() # else: # global_max_len = local_max_len # # 2. Pad all local embeddings to the global max length # padded_embeds = [] # for reps_batch in local_embeds: # if reps_batch.dim() == 3: # B, L, H = reps_batch.shape # padding_size = global_max_len - L # padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0) # padded_embeds.append(padded_batch) # else: # Should not happen if model is consistently late-interaction # padded_embeds.append(reps_batch) # embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous() # else: # Standard dense models # embeds_tensor = torch.cat(local_embeds, dim=0).contiguous() # # === Gather embeddings and keys from all ranks === # if dist.is_initialized() and full_dataset.num_rows >= world_size: # print_master(f"Gathering {encode_side} embeddings across all ranks...") # # Use the more efficient all_gather_into_tensor for tensors # output_shape = list(embeds_tensor.shape) # output_shape[0] = full_dataset.num_rows # embeds_tensor = embeds_tensor.to(training_args.device) # gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device) # dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor) # final_embeddings = gathered_embeds_tensor.cpu().float().numpy() # # Gather metadata, for which all_gather_object is appropriate # gathered_gt_infos = [None for _ in range(world_size)] # dist.all_gather_object(gathered_gt_infos, local_gt_infos) # all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys] # # --- New: Gather batch_stats_list from all ranks --- # gathered_batch_stats = [None for _ in range(world_size)] # dist.all_gather_object(gathered_batch_stats, batch_stats_list) # all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats] # else: # all_gt_infos = local_gt_infos # final_embeddings = embeds_tensor.cpu().float().numpy() # all_batch_stats = batch_stats_list # If not DDP, just use local list # return final_embeddings, all_gt_infos, all_batch_stats # def main(): # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) # local_rank = dist.get_rank() if dist.is_initialized() else 0 # world_size = dist.get_world_size() if dist.is_initialized() else 1 # # DEBUG PRINTS for Distributed Setup # print_master("Distributed init debug info:") # print_master(f"RANK: {os.environ.get('RANK')}") # print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}") # print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}") # print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}") # print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}") # if dist.is_initialized(): # print_rank(f"dist.get_rank(): {dist.get_rank()}") # print_rank(f"dist.get_world_size(): {dist.get_world_size()}") # for arg in sys.argv: # if arg.startswith("--local-rank="): # rank = arg.split("=")[1] # sys.argv.remove(arg) # sys.argv.append('--local_rank') # sys.argv.append(rank) # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) # model_args, data_args, training_args = parser.parse_args_into_dataclasses() # model_args: ModelArguments # data_args: DataArguments # training_args: TrainingArguments # os.makedirs(data_args.encode_output_path, exist_ok=True) # # --- Model Loading --- # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) # if not getattr(model_args, "model_backbone", None): # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) # setattr(model_args, 'model_backbone', model_backbone) # setattr(training_args, 'model_backbone', model_backbone) # print_master(f'Model Backbone: {model_args.model_backbone}') # # --- DDP-Safe Model Loading --- # # Step 1: Only the master process (rank 0) downloads the model. # if local_rank == 0: # processor = load_processor(model_args, data_args) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...") # # Step 2: All processes wait here. The non-master processes will pause # # until the master process (rank 0) finishes downloading and exits this barrier. # if torch.distributed.is_initialized(): # torch.distributed.barrier() # # Step 3: Now that the model is cached, the non-master processes load it from the local cache. # if local_rank != 0: # print_rank(f"Loading the model from cache...") # processor = load_processor(model_args, data_args) # time.sleep(random.randint(2 * local_rank, 3 * local_rank)) # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) # model.eval() # model = model.to(training_args.device, dtype=torch.bfloat16) # with open(data_args.dataset_config, 'r') as yaml_file: # dataset_configs = yaml.safe_load(yaml_file) # # --- Main Evaluation Loop --- # for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()): # # Initialize task-level statistics accumulators for QUERY # query_total_stats = { # "total_inference_time_seconds": 0.0, # "module_inference_times": { # "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0}, # "PatchMerger": {"total": 0.0, "count": 0}, # "Qwen2VLModel": {"total": 0.0, "count": 0}, # "Linear": {"total": 0.0, "count": 0}, # }, # "token_counts": { # "visual_tokens": 0, # "language_input_tokens_raw": 0, # "llm_total_input_tokens": 0, # "language_output_tokens": 0, # }, # "data_point_count": 0 # Number of image-text pairs processed # } # # Initialize task-level statistics accumulators for CANDIDATE # cand_total_stats = { # "total_inference_time_seconds": 0.0, # "module_inference_times": { # "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0}, # "PatchMerger": {"total": 0.0, "count": 0}, # "Qwen2VLModel": {"total": 0.0, "count": 0}, # "Linear": {"total": 0.0, "count": 0}, # }, # "token_counts": { # "visual_tokens": 0, # "language_input_tokens_raw": 0, # "llm_total_input_tokens": 0, # "language_output_tokens": 0, # }, # "data_point_count": 0 # Number of image-text pairs processed # } # if dist.is_initialized(): # dist.barrier() # print_master(f"\n--- Evaluating {dataset_name} ---") # query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry") # cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt") # dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") # # New: Define distinct paths for query and candidate inference statistics output # query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats.json") # cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats.json") # do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path) # do_cand = not os.path.exists(cand_embed_path) # if do_query or do_cand: # if data_args.data_basedir is not None: # # Construct full paths for data files if --data_basedir is provided # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: # if data_args.data_basedir and task_config.get(key): # task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) # eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # # Pad datasets to be divisible by world_size before splitting # if dist.is_initialized(): # padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size) # padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size) # eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size) # eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size) # else: # padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # # --- 1. Compute Query Embeddings --- # if do_query: # print_master("Encoding queries...") # eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") # eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers) # # Modified: capture batch_stats_list # query_embeds, gt_infos, qry_batch_stats = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}") # # Accumulate query statistics # for batch_stat in qry_batch_stats: # batch_size = batch_stat["batch_size"] # query_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"] # for module_name, module_stats in batch_stat["module_inference_times"].items(): # if module_name in query_total_stats["module_inference_times"]: # query_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"] # query_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"] # query_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size # query_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size # query_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size # query_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size # query_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items # query_embeds = query_embeds[:len(full_eval_qry_dataset)] # gt_infos = gt_infos[:len(full_eval_qry_dataset)] # if local_rank == 0: # with open(query_embed_path, 'wb') as f: # pickle.dump(query_embeds, f) # with open(dataset_info_path, 'w') as f: # for info in gt_infos: # f.write(json.dumps(info) + '\n') # print_master(f"Saved query embeddings to {query_embed_path}") # # Save query-specific inference statistics # if query_total_stats["data_point_count"] > 0: # final_query_stats = { # "task_name": dataset_name, # "encode_side": "query", # "data_point_count": query_total_stats["data_point_count"], # "inference_times": { # "total_inference_time_seconds": query_total_stats["total_inference_time_seconds"], # "avg_inference_time_per_item_seconds": query_total_stats["total_inference_time_seconds"] / query_total_stats["data_point_count"], # "module_average_times_per_call": {}, # "module_total_times_seconds": {}, # "module_calls_count": {}, # }, # "token_counts": { # "total_visual_tokens": query_total_stats["token_counts"]["visual_tokens"], # "avg_visual_tokens_per_item": query_total_stats["token_counts"]["visual_tokens"] / query_total_stats["data_point_count"], # "total_language_input_tokens_raw": query_total_stats["token_counts"]["language_input_tokens_raw"], # "avg_language_input_tokens_raw_per_item": query_total_stats["token_counts"]["language_input_tokens_raw"] / query_total_stats["data_point_count"], # "total_llm_total_input_tokens": query_total_stats["token_counts"]["llm_total_input_tokens"], # "avg_llm_total_input_tokens_per_item": query_total_stats["token_counts"]["llm_total_input_tokens"] / query_total_stats["data_point_count"], # "total_language_output_tokens": query_total_stats["token_counts"]["language_output_tokens"], # "avg_language_output_tokens_per_item": query_total_stats["token_counts"]["language_output_tokens"] / query_total_stats["data_point_count"], # } # } # for module_name, stats in query_total_stats["module_inference_times"].items(): # final_query_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"] # final_query_stats["inference_times"]["module_calls_count"][module_name] = stats["count"] # if stats["count"] > 0: # final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"] # else: # final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0 # with open(query_inference_stats_path, 'w', encoding='utf-8') as f: # json.dump(final_query_stats, f, ensure_ascii=False, indent=4) # print_master(f"Query inference statistics for {dataset_name} saved to: {query_inference_stats_path}") # else: # print_master(f"No query data processed for {dataset_name}, skipping query inference statistics output.") # if dist.is_initialized(): # dist.barrier() # # --- 2. Compute Candidate Embeddings --- # if do_cand: # print_master("Encoding candidates...") # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") # eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) # # Modified: capture batch_stats_list # cand_embeds, all_cand_ids, cand_batch_stats = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}") # # Accumulate candidate statistics (similar logic as query) # for batch_stat in cand_batch_stats: # batch_size = batch_stat["batch_size"] # cand_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"] # for module_name, module_stats in batch_stat["module_inference_times"].items(): # if module_name in cand_total_stats["module_inference_times"]: # cand_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"] # cand_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"] # cand_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size # cand_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size # cand_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size # cand_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size # cand_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items # cand_embeds = cand_embeds[:len(full_eval_cand_dataset)] # all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)] # if local_rank == 0: # cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)} # with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f) # print_master(f"Saved candidate embeddings to {cand_embed_path}") # # Save candidate-specific inference statistics # if cand_total_stats["data_point_count"] > 0: # final_cand_stats = { # "task_name": dataset_name, # "encode_side": "candidate", # "data_point_count": cand_total_stats["data_point_count"], # "inference_times": { # "total_inference_time_seconds": cand_total_stats["total_inference_time_seconds"], # "avg_inference_time_per_item_seconds": cand_total_stats["total_inference_time_seconds"] / cand_total_stats["data_point_count"], # "module_average_times_per_call": {}, # "module_total_times_seconds": {}, # "module_calls_count": {}, # }, # "token_counts": { # "total_visual_tokens": cand_total_stats["token_counts"]["visual_tokens"], # "avg_visual_tokens_per_item": cand_total_stats["token_counts"]["visual_tokens"] / cand_total_stats["data_point_count"], # "total_language_input_tokens_raw": cand_total_stats["token_counts"]["language_input_tokens_raw"], # "avg_language_input_tokens_raw_per_item": cand_total_stats["token_counts"]["language_input_tokens_raw"] / cand_total_stats["data_point_count"], # "total_llm_total_input_tokens": cand_total_stats["token_counts"]["llm_total_input_tokens"], # "avg_llm_total_input_tokens_per_item": cand_total_stats["token_counts"]["llm_total_input_tokens"] / cand_total_stats["data_point_count"], # "total_language_output_tokens": cand_total_stats["token_counts"]["language_output_tokens"], # "avg_language_output_tokens_per_item": cand_total_stats["token_counts"]["language_output_tokens"] / cand_total_stats["data_point_count"], # } # } # for module_name, stats in cand_total_stats["module_inference_times"].items(): # final_cand_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"] # final_cand_stats["inference_times"]["module_calls_count"][module_name] = stats["count"] # if stats["count"] > 0: # final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"] # else: # final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0 # with open(cand_inference_stats_path, 'w', encoding='utf-8') as f: # json.dump(final_cand_stats, f, ensure_ascii=False, indent=4) # print_master(f"Candidate inference statistics for {dataset_name} saved to: {cand_inference_stats_path}") # else: # print_master(f"No candidate data processed for {dataset_name}, skipping candidate inference statistics output.") # if dist.is_initialized(): # dist.barrier() # # --- 3. Compute Scores (on master rank only) --- # score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") # #################################################################################### # pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl") # score_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details.jsonl") # 新文件,存相似度分数 # def append_score_detail(score_detail_list, qid, ranked_indices, score_vector, cand_ids, labels): # """追加一个 query 的候选分数详情""" # score_detail_list.append({ # "qid": int(qid), # "cand_scores": [ # {"cand_id": str(cand_ids[i]), "score": float(score_vector[i])} # for i in ranked_indices # ], # "label": labels # }) # #################################################################################### # if local_rank == 0: # if os.path.exists(score_path): # try: # with open(score_path, "r") as f: # score_dict = json.load(f) # print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}") # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # print_master(formatted) # # No `continue` here, as we want to ensure other files are processed/generated # except Exception as e: # print_master(f"Failed to load score for {dataset_name}, proceeding to recompute. Error: {e}") # # Proceed with score computation if not loaded or failed to load # with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f) # with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f) # gt_infos = [json.loads(l) for l in open(dataset_info_path)] # pred_dicts = [] # score_detail_dicts = []################################### # rank_against_all_candidates = task_config.get("eval_type", "global") == "global" # # if rank_against_all_candidates: # # cand_keys = list(cand_embed_dict.keys()) # # cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) # # # Handle late-interaction scoring # # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # # qry_embed = torch.from_numpy(qry_embeds) # # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # # scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function # # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist() # # scores = scores.cpu().numpy() # # else: # Dense # # cosine_scores = np.dot(qry_embeds, cand_embeds.T) # # ranked_candids = np.argsort(-cosine_scores, axis=1) # ##################################################### # if rank_against_all_candidates: # cand_keys = list(cand_embed_dict.keys()) # cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) # if qry_embeds.ndim == 3: # Late-interaction # qry_embed_t = torch.from_numpy(qry_embeds) # cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds] # sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy() # [N_q, N_c] # else: # Dense # sim_matrix = np.dot(qry_embeds, cand_embeds.T) # [N_q, N_c] # ranked_all = np.argsort(-sim_matrix, axis=1) # ######################################################### # for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [cand_keys[i] for i in ranked_candid], # "label": rel_docids, # "rel_scores": rel_scores, # }) # ################################# 新增:详细相似度字典 # append_score_detail(score_detail_dicts, qid, ranked_indices, sim_matrix[qid], cand_keys, rel_docids) # ######################################## # # else: # # for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # # cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]]) # # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # # qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # # scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function # # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0] # # else: # # cosine_score = np.dot(qry_embed, cand_embeds.T) # # ranked_candids = np.argsort(-cosine_score) # # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # # assert rel_scores is None or len(rel_docids) == len(rel_scores) # # pred_dicts.append({ # # "prediction": [gt_info["cand_names"][i] for i in ranked_candids], # # "label": rel_docids, # # "rel_scores": rel_scores, # # }) # ####################################################################### # else: # 非全局 # for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # cand_ids_local = gt_info["cand_names"] # cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local]) # if qry_embeds.ndim == 3: # Late-interaction # qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # [1, Lq, H] # cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds] # sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0] # [N_c] # else: # Dense # sim_vec = np.dot(qry_embed, cand_embeds.T) # [N_c] # ranked_indices = np.argsort(-sim_vec) # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [cand_ids_local[i] for i in ranked_indices], # "label": rel_docids, # "rel_scores": rel_scores, # }) # # 新增:分数详情 # append_score_detail(score_detail_dicts, qid, ranked_indices, sim_vec, cand_ids_local, rel_docids) # ########################################## 保存预测和分数 # with open(score_detail_path, "w") as f: # 新增 # for detail in score_detail_dicts: # f.write(json.dumps(detail) + '\n') # print_master(f"Detailed score file saved to: {score_detail_path}") # metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] # metrics = RankingMetrics(metrics_to_report) # score_dict = metrics.evaluate(pred_dicts) # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # score_dict["num_pred"] = len(pred_dicts) # score_dict["num_data"] = len(gt_infos) # print_master(f"Score of {dataset_name}:") # print_master(formatted) # print_master(f"Outputting final score to: {score_path}") # with open(score_path, "w") as f: # json.dump(score_dict, f, indent=4) # with open(pred_path, "w") as f: # for pred in pred_dicts: # f.write(json.dumps(pred) + '\n') # #################################################################### # score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") # pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl") # metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] # metrics = RankingMetrics(metrics_to_report) # score_dict = metrics.evaluate(pred_dicts) # formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} # score_dict["num_pred"] = len(pred_dicts) # score_dict["num_data"] = len(gt_infos) # print_master(f"Score of {dataset_name}:") # print_master(formatted) # print_master(f"Outputting final score to: {score_path}") # with open(score_path, "w") as f: # json.dump(score_dict, f, indent=4) # with open(pred_path, "w") as f: # for pred in pred_dicts: # f.write(json.dumps(pred) + '\n') # if __name__ == '__main__': # main() ###################################################################################################### ####################################################################################################### #为了可视化把mask值也输出 ######################################################################################## #分query和cand进行统计 import datetime import logging import json import random import time import numpy as np import os import pickle import sys import torch import torch.distributed as dist import torch.nn.functional as F import yaml import transformers from torch.utils.data import DataLoader from tqdm import tqdm from transformers import HfArgumentParser, AutoConfig, AutoTokenizer from datasets import Dataset, concatenate_datasets from datasets.distributed import split_dataset_by_node from src.arguments import ModelArguments, DataArguments, TrainingArguments from src.data.collator.eval_collator import MultimodalEvalDataCollator from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset from src.eval_utils.metrics import RankingMetrics from src.model.model import MMEBModel from src.model.processor import get_backbone_name, load_processor, COLPALI from src.utils import batch_to_device, print_rank, print_master logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') logger = logging.getLogger(__name__) # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) --- timing_info = {} token_info = { "vision_tokens": 0, "text_input_tokens": 0, # Refers to the original text token count "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0. "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text) } # --- Hook Functions Definition --- def timing_pre_hook(module, input): module_id = id(module) if module_id not in timing_info: timing_info[module_id] = [] timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__)) def timing_post_hook(module, input, output): module_id = id(module) if module_id not in timing_info: # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})") return timing_info[module_id].append((time.time(), 'post', module.__class__.__name__)) # Collect vision token count (only from Vision Transformer module's post hook) module_name = module.__class__.__name__ if "vision" in module_name.lower() and "transformer" in module_name.lower(): if isinstance(output, torch.Tensor): token_info["vision_tokens"] = output.shape[0] # For visual features, usually (batch_size, num_tokens, hidden_dim) elif hasattr(output, 'last_hidden_state'): token_info["vision_tokens"] = output.last_hidden_state.shape[1] def register_model_hooks(model): registered_modules = [] core_model = model # print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}") if hasattr(model, 'encoder') and model.encoder is not None: print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}") else: print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.") # Vision module if hasattr(core_model, 'visual') and core_model.visual is not None: vision_module = core_model.visual vision_module.register_forward_pre_hook(timing_pre_hook) vision_module.register_forward_hook(timing_post_hook) registered_modules.append(vision_module) print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}") else: print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).") # Merger module (if inside visual) - it's part of the vision component if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None: merger_module = core_model.visual.merger merger_module.register_forward_pre_hook(timing_pre_hook) merger_module.register_forward_hook(timing_post_hook) registered_modules.append(merger_module) print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}") else: print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).") # Language model body if hasattr(core_model, 'model') and core_model.model is not None: llm_main_module = core_model.model llm_main_module.register_forward_pre_hook(timing_pre_hook) llm_main_module.register_forward_hook(timing_post_hook) registered_modules.append(llm_main_module) print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}") else: print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).") # LM Head if hasattr(core_model, 'lm_head') and core_model.lm_head is not None: lm_head_module = core_model.lm_head lm_head_module.register_forward_pre_hook(timing_pre_hook) lm_head_module.register_forward_hook(timing_post_hook) registered_modules.append(lm_head_module) print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}") else: print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).") if not registered_modules: print_master("Warning: No major modules found for hook registration. Check model architecture.") return registered_modules def pad_dataset_to_divisible(dataset, world_size): num_samples = len(dataset) if num_samples % world_size == 0: return dataset, num_samples num_to_add = world_size - (num_samples % world_size) padded_size = num_samples + num_to_add padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) padded_dataset = concatenate_datasets([dataset, padding_data]) return padded_dataset, padded_size def encode_embeddings( model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, model_args: ModelArguments, full_dataset: Dataset, encode_side: str, description: str = "Encoding" ) -> tuple[np.ndarray, list, list, list]: # CHANGED: + list for img_token_masks """ Encodes embeddings for a given dataset using the model, handling both standard and late-interaction models in a DDP-safe manner. Returns: - embeddings: np.ndarray - infos_or_ids: list - batch_stats_list: list - img_token_masks: list[None | list[bool]] # NEW """ local_rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 # Check if the model is a late-interaction type is_late_interaction = (model_args.model_backbone == COLPALI) local_embeds = [] local_gt_infos = [] local_max_len = 0 # --- New: List to store statistics for each batch --- batch_stats_list = [] # --- NEW: Collect image token masks locally --- local_img_token_masks = [] # 每个样本一个元素:None 或 [bool, ...] model.eval() # Register hooks for the model once per encode_embeddings call registered_hooks = register_model_hooks(model) # --- NEW: helpers to取mask并序列化 --- def _search_key(obj, key: str): # 递归搜索 dict/list/tuple,找到指定 key if isinstance(obj, dict): if key in obj: return obj[key] for v in obj.values(): r = _search_key(v, key) if r is not None: return r elif isinstance(obj, (list, tuple)): for v in obj: r = _search_key(v, key) if r is not None: return r return None def _to_serializable_mask_list(mask_list, batch_size: int): # 将模型返回的 mask(list/tensor/ndarray/None)转成 [None | list[bool]] * B if mask_list is None: return [None] * batch_size out = [] if isinstance(mask_list, (list, tuple)): for m in mask_list: if m is None: out.append(None) elif torch.is_tensor(m): out.append(m.detach().cpu().tolist()) elif isinstance(m, np.ndarray): out.append(m.tolist()) else: # already python list/bool out.append(m) elif torch.is_tensor(mask_list): # 若是 2D 张量(B, L),直接 tolist() -> list[list[bool/int]] out = mask_list.detach().cpu().tolist() elif isinstance(mask_list, np.ndarray): out = mask_list.tolist() else: # 未知类型,保守返回 None 占位 out = [None] * batch_size # 长度对齐 batch_size if isinstance(out, list): if len(out) < batch_size: out = out + [None] * (batch_size - len(out)) elif len(out) > batch_size: out = out[:batch_size] return out with torch.no_grad(): for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0): # --- Reset statistics for each inference pass --- timing_info.clear() token_info["vision_tokens"] = 0 token_info["text_input_tokens"] = 0 token_info["text_output_tokens"] = 0 token_info["total_llm_input_tokens"] = 0 inputs = batch_to_device(inputs, training_args.device) current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1 with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): start_inference_time = time.time() if encode_side == "qry": output = model(qry=inputs) # torch.set_printoptions(threshold=10000) # print('output:', output) # exit() reps = output["qry_reps"].detach() local_gt_infos.extend(dataset_info) else: output = model(tgt=inputs) reps = output["tgt_reps"].detach() local_gt_infos.extend([info["cand_name"] for info in dataset_info]) end_inference_time = time.time() # --- NEW: 提取并保存本 batch 的 image_token_bool_masks --- # 期望 MMEBModel 的 output 中直接或间接包含 'image_token_bool_masks' img_masks_raw = None if isinstance(output, dict): img_masks_raw = _search_key(output, "image_token_bool_masks") # 可选:若你在 MMEBModel 上挂了属性,也可以尝试读取 if img_masks_raw is None and hasattr(model, "image_token_bool_masks"): img_masks_raw = getattr(model, "image_token_bool_masks") img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size) local_img_token_masks.extend(img_masks_serializable) # --- Update total LLM input tokens after the model call --- if 'input_ids' in inputs and inputs['input_ids'] is not None: token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1] token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"] token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"]) # --- Collect and Store Batch Statistics --- batch_inference_time = end_inference_time - start_inference_time current_batch_stats = { "batch_size": current_batch_size, "total_inference_time_seconds": batch_inference_time, "module_inference_times": {}, "token_counts": { "visual_tokens": token_info["vision_tokens"], "language_input_tokens_raw": token_info["text_input_tokens"], "llm_total_input_tokens": token_info["total_llm_input_tokens"], "language_output_tokens": token_info["text_output_tokens"], } } # Calculate and store module timings for the current batch for module_obj in registered_hooks: module_id = id(module_obj) module_name = module_obj.__class__.__name__ times = timing_info.get(module_id, []) durations = [] pre_times = {} for t, event_type, _ in times: if event_type == 'pre': pre_times[module_id] = t elif event_type == 'post' and module_id in pre_times: duration = t - pre_times.pop(module_id) durations.append(duration) if durations: current_batch_stats["module_inference_times"][module_name] = { "total": sum(durations), "count": len(durations), "avg": sum(durations) / len(durations) } else: current_batch_stats["module_inference_times"][module_name] = { "total": 0.0, "count": 0, "avg": 0.0 } batch_stats_list.append(current_batch_stats) # --- Debug prints (optional) --- print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---") print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds") print_rank("--- Module Inference Timing Statistics ---") for module_name, stats in current_batch_stats["module_inference_times"].items(): print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s") print_rank("--- Token Count Statistics ---") print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}") print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}") print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}") print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}") if is_late_interaction and reps.dim() == 3: local_max_len = max(local_max_len, reps.shape[1]) local_embeds.append(reps) if not local_embeds: # Handle cases where a rank gets no data return np.array([]), [], [], [] # CHANGED: 4个返回值 # === DDP Synchronization and Padding for Late-Interaction Models === if is_late_interaction: if dist.is_initialized(): # 1: global max length local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device) dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX) global_max_len = local_max_len_tensor.item() else: global_max_len = local_max_len # 2: pad to global max length padded_embeds = [] for reps_batch in local_embeds: if reps_batch.dim() == 3: B, L, H = reps_batch.shape padding_size = global_max_len - L padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0) padded_embeds.append(padded_batch) else: padded_embeds.append(reps_batch) embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous() else: embeds_tensor = torch.cat(local_embeds, dim=0).contiguous() # === Gather embeddings and keys from all ranks === if dist.is_initialized() and full_dataset.num_rows >= world_size: print_master(f"Gathering {encode_side} embeddings across all ranks...") # tensor gather output_shape = list(embeds_tensor.shape) output_shape[0] = full_dataset.num_rows embeds_tensor = embeds_tensor.to(training_args.device) gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device) dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor) final_embeddings = gathered_embeds_tensor.cpu().float().numpy() # object gather for infos and stats gathered_gt_infos = [None for _ in range(world_size)] dist.all_gather_object(gathered_gt_infos, local_gt_infos) all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys] gathered_batch_stats = [None for _ in range(world_size)] dist.all_gather_object(gathered_batch_stats, batch_stats_list) all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats] # --- NEW: gather masks --- gathered_masks = [None for _ in range(world_size)] dist.all_gather_object(gathered_masks, local_img_token_masks) all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list] else: all_gt_infos = local_gt_infos final_embeddings = embeds_tensor.cpu().float().numpy() all_batch_stats = batch_stats_list all_img_token_masks = local_img_token_masks # NEW return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks # CHANGED def main(): if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) local_rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 # DEBUG PRINTS for Distributed Setup print_master("Distributed init debug info:") print_master(f"RANK: {os.environ.get('RANK')}") print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}") print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}") print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}") print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}") if dist.is_initialized(): print_rank(f"dist.get_rank(): {dist.get_rank()}") print_rank(f"dist.get_world_size(): {dist.get_world_size()}") for arg in sys.argv: if arg.startswith("--local-rank="): rank = arg.split("=")[1] sys.argv.remove(arg) sys.argv.append('--local_rank') sys.argv.append(rank) parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args: ModelArguments data_args: DataArguments training_args: TrainingArguments os.makedirs(data_args.encode_output_path, exist_ok=True) # --- Model Loading --- hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) if not getattr(model_args, "model_backbone", None): model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) setattr(model_args, 'model_backbone', model_backbone) setattr(training_args, 'model_backbone', model_backbone) print_master(f'Model Backbone: {model_args.model_backbone}') # --- DDP-Safe Model Loading --- # Step 1: Only the master process (rank 0) downloads the model. if local_rank == 0: processor = load_processor(model_args, data_args) model = MMEBModel.load(model_args, is_trainable=False, processor=processor) print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...") # Step 2: All processes wait here. The non-master processes will pause # until the master process (rank 0) finishes downloading and exits this barrier. if torch.distributed.is_initialized(): torch.distributed.barrier() # Step 3: Now that the model is cached, the non-master processes load it from the local cache. if local_rank != 0: print_rank(f"Loading the model from cache...") processor = load_processor(model_args, data_args) time.sleep(random.randint(2 * local_rank, 3 * local_rank)) model = MMEBModel.load(model_args, is_trainable=False, processor=processor) model.eval() model = model.to(training_args.device, dtype=torch.bfloat16) with open(data_args.dataset_config, 'r') as yaml_file: dataset_configs = yaml.safe_load(yaml_file) # --- Main Evaluation Loop --- for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()): # Initialize task-level statistics accumulators for QUERY query_total_stats = { "total_inference_time_seconds": 0.0, "module_inference_times": { "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0}, "PatchMerger": {"total": 0.0, "count": 0}, "Qwen2VLModel": {"total": 0.0, "count": 0}, "Linear": {"total": 0.0, "count": 0}, }, "token_counts": { "visual_tokens": 0, "language_input_tokens_raw": 0, "llm_total_input_tokens": 0, "language_output_tokens": 0, }, "data_point_count": 0 # Number of image-text pairs processed } # Initialize task-level statistics accumulators for CANDIDATE cand_total_stats = { "total_inference_time_seconds": 0.0, "module_inference_times": { "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0}, "PatchMerger": {"total": 0.0, "count": 0}, "Qwen2VLModel": {"total": 0.0, "count": 0}, "Linear": {"total": 0.0, "count": 0}, }, "token_counts": { "visual_tokens": 0, "language_input_tokens_raw": 0, "llm_total_input_tokens": 0, "language_output_tokens": 0, }, "data_point_count": 0 # Number of image-text pairs processed } if dist.is_initialized(): dist.barrier() print_master(f"\n--- Evaluating {dataset_name} ---") query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry") cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt") dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") # New: Define distinct paths for query and candidate inference statistics output query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats.json") cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats.json") do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path) do_cand = not os.path.exists(cand_embed_path) if do_query or do_cand: if data_args.data_basedir is not None: # Construct full paths for data files if --data_basedir is provided for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: if data_args.data_basedir and task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # Pad datasets to be divisible by world_size before splitting if dist.is_initialized(): padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size) padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size) eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size) eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size) else: padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # --- 1. Compute Query Embeddings --- if do_query: print_master("Encoding queries...") eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers) # Modified: capture batch_stats_list query_embeds, gt_infos, qry_batch_stats, qry_img_masks = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}") # Accumulate query statistics for batch_stat in qry_batch_stats: batch_size = batch_stat["batch_size"] query_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"] for module_name, module_stats in batch_stat["module_inference_times"].items(): if module_name in query_total_stats["module_inference_times"]: query_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"] query_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"] query_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size query_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size query_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size query_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size query_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items query_embeds = query_embeds[:len(full_eval_qry_dataset)] gt_infos = gt_infos[:len(full_eval_qry_dataset)] if local_rank == 0: with open(query_embed_path, 'wb') as f: pickle.dump(query_embeds, f) with open(dataset_info_path, 'w') as f: for info in gt_infos: f.write(json.dumps(info) + '\n') print_master(f"Saved query embeddings to {query_embed_path}") qry_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_img_token_masks.jsonl") with open(qry_img_masks_path, 'w', encoding='utf-8') as f: for i, m in enumerate(qry_img_masks[:len(full_eval_qry_dataset)]): f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n") print_master(f"Saved query image token masks to {qry_img_masks_path}") # Save query-specific inference statistics if query_total_stats["data_point_count"] > 0: final_query_stats = { "task_name": dataset_name, "encode_side": "query", "data_point_count": query_total_stats["data_point_count"], "inference_times": { "total_inference_time_seconds": query_total_stats["total_inference_time_seconds"], "avg_inference_time_per_item_seconds": query_total_stats["total_inference_time_seconds"] / query_total_stats["data_point_count"], "module_average_times_per_call": {}, "module_total_times_seconds": {}, "module_calls_count": {}, }, "token_counts": { "total_visual_tokens": query_total_stats["token_counts"]["visual_tokens"], "avg_visual_tokens_per_item": query_total_stats["token_counts"]["visual_tokens"] / query_total_stats["data_point_count"], "total_language_input_tokens_raw": query_total_stats["token_counts"]["language_input_tokens_raw"], "avg_language_input_tokens_raw_per_item": query_total_stats["token_counts"]["language_input_tokens_raw"] / query_total_stats["data_point_count"], "total_llm_total_input_tokens": query_total_stats["token_counts"]["llm_total_input_tokens"], "avg_llm_total_input_tokens_per_item": query_total_stats["token_counts"]["llm_total_input_tokens"] / query_total_stats["data_point_count"], "total_language_output_tokens": query_total_stats["token_counts"]["language_output_tokens"], "avg_language_output_tokens_per_item": query_total_stats["token_counts"]["language_output_tokens"] / query_total_stats["data_point_count"], } } for module_name, stats in query_total_stats["module_inference_times"].items(): final_query_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"] final_query_stats["inference_times"]["module_calls_count"][module_name] = stats["count"] if stats["count"] > 0: final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"] else: final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0 with open(query_inference_stats_path, 'w', encoding='utf-8') as f: json.dump(final_query_stats, f, ensure_ascii=False, indent=4) print_master(f"Query inference statistics for {dataset_name} saved to: {query_inference_stats_path}") else: print_master(f"No query data processed for {dataset_name}, skipping query inference statistics output.") if dist.is_initialized(): dist.barrier() # --- 2. Compute Candidate Embeddings --- if do_cand: print_master("Encoding candidates...") eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) # Modified: capture batch_stats_list cand_embeds, all_cand_ids, cand_batch_stats, cand_img_masks = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}") # Accumulate candidate statistics (similar logic as query) for batch_stat in cand_batch_stats: batch_size = batch_stat["batch_size"] cand_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"] for module_name, module_stats in batch_stat["module_inference_times"].items(): if module_name in cand_total_stats["module_inference_times"]: cand_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"] cand_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"] cand_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size cand_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size cand_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size cand_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size cand_total_stats["data_point_count"] += batch_size # Accumulate the number of processed items cand_embeds = cand_embeds[:len(full_eval_cand_dataset)] all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)] if local_rank == 0: cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)} with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f) print_master(f"Saved candidate embeddings to {cand_embed_path}") cand_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_img_token_masks.jsonl") with open(cand_img_masks_path, 'w', encoding='utf-8') as f: for cid, m in zip(all_cand_ids[:len(full_eval_cand_dataset)], cand_img_masks[:len(full_eval_cand_dataset)]): f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n") print_master(f"Saved candidate image token masks to {cand_img_masks_path}") # Save candidate-specific inference statistics if cand_total_stats["data_point_count"] > 0: final_cand_stats = { "task_name": dataset_name, "encode_side": "candidate", "data_point_count": cand_total_stats["data_point_count"], "inference_times": { "total_inference_time_seconds": cand_total_stats["total_inference_time_seconds"], "avg_inference_time_per_item_seconds": cand_total_stats["total_inference_time_seconds"] / cand_total_stats["data_point_count"], "module_average_times_per_call": {}, "module_total_times_seconds": {}, "module_calls_count": {}, }, "token_counts": { "total_visual_tokens": cand_total_stats["token_counts"]["visual_tokens"], "avg_visual_tokens_per_item": cand_total_stats["token_counts"]["visual_tokens"] / cand_total_stats["data_point_count"], "total_language_input_tokens_raw": cand_total_stats["token_counts"]["language_input_tokens_raw"], "avg_language_input_tokens_raw_per_item": cand_total_stats["token_counts"]["language_input_tokens_raw"] / cand_total_stats["data_point_count"], "total_llm_total_input_tokens": cand_total_stats["token_counts"]["llm_total_input_tokens"], "avg_llm_total_input_tokens_per_item": cand_total_stats["token_counts"]["llm_total_input_tokens"] / cand_total_stats["data_point_count"], "total_language_output_tokens": cand_total_stats["token_counts"]["language_output_tokens"], "avg_language_output_tokens_per_item": cand_total_stats["token_counts"]["language_output_tokens"] / cand_total_stats["data_point_count"], } } for module_name, stats in cand_total_stats["module_inference_times"].items(): final_cand_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"] final_cand_stats["inference_times"]["module_calls_count"][module_name] = stats["count"] if stats["count"] > 0: final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"] else: final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0 with open(cand_inference_stats_path, 'w', encoding='utf-8') as f: json.dump(final_cand_stats, f, ensure_ascii=False, indent=4) print_master(f"Candidate inference statistics for {dataset_name} saved to: {cand_inference_stats_path}") else: print_master(f"No candidate data processed for {dataset_name}, skipping candidate inference statistics output.") if dist.is_initialized(): dist.barrier() # --- 3. Compute Scores (on master rank only) --- score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") #################################################################################### pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl") score_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details.jsonl") # 新文件,存相似度分数 def append_score_detail(score_detail_list, qid, ranked_indices, score_vector, cand_ids, labels): """追加一个 query 的候选分数详情""" score_detail_list.append({ "qid": int(qid), "cand_scores": [ {"cand_id": str(cand_ids[i]), "score": float(score_vector[i])} for i in ranked_indices ], "label": labels }) #################################################################################### if local_rank == 0: if os.path.exists(score_path): try: with open(score_path, "r") as f: score_dict = json.load(f) print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}") formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} print_master(formatted) # No `continue` here, as we want to ensure other files are processed/generated except Exception as e: print_master(f"Failed to load score for {dataset_name}, proceeding to recompute. Error: {e}") # Proceed with score computation if not loaded or failed to load with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f) with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f) gt_infos = [json.loads(l) for l in open(dataset_info_path)] pred_dicts = [] score_detail_dicts = []################################### rank_against_all_candidates = task_config.get("eval_type", "global") == "global" # if rank_against_all_candidates: # cand_keys = list(cand_embed_dict.keys()) # cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) # # Handle late-interaction scoring # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # qry_embed = torch.from_numpy(qry_embeds) # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # scores = processor.score(qry_embed, cand_embeds, batch_size=64) # use ColPali score function # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist() # scores = scores.cpu().numpy() # else: # Dense # cosine_scores = np.dot(qry_embeds, cand_embeds.T) # ranked_candids = np.argsort(-cosine_scores, axis=1) ##################################################### if rank_against_all_candidates: cand_keys = list(cand_embed_dict.keys()) cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) if qry_embeds.ndim == 3: # Late-interaction qry_embed_t = torch.from_numpy(qry_embeds) cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds] sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy() # [N_q, N_c] else: # Dense sim_matrix = np.dot(qry_embeds, cand_embeds.T) # [N_q, N_c] ranked_candids = np.argsort(-sim_matrix, axis=1) ######################################################### for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"): rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None assert rel_scores is None or len(rel_docids) == len(rel_scores) pred_dicts.append({ "prediction": [cand_keys[i] for i in ranked_candid], "label": rel_docids, "rel_scores": rel_scores, }) ################################# 新增:详细相似度字典 append_score_detail(score_detail_dicts, qid, ranked_candid, sim_matrix[qid], cand_keys, rel_docids) ######################################## # else: # for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"): # cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]]) # if qry_embeds.ndim == 3: # Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H] # qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds] # scores = processor.score(qry_embed, cand_embeds, batch_size=1024) # use ColPali score function # ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0] # else: # cosine_score = np.dot(qry_embed, cand_embeds.T) # ranked_candids = np.argsort(-cosine_score) # rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] # rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None # assert rel_scores is None or len(rel_docids) == len(rel_scores) # pred_dicts.append({ # "prediction": [gt_info["cand_names"][i] for i in ranked_candids], # "label": rel_docids, # "rel_scores": rel_scores, # }) ####################################################################### else: # 非全局 for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"): cand_ids_local = gt_info["cand_names"] cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local]) if qry_embeds.ndim == 3: # Late-interaction qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # [1, Lq, H] cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds] sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0] # [N_c] else: # Dense sim_vec = np.dot(qry_embed, cand_embeds.T) # [N_c] ranked_indices = np.argsort(-sim_vec) rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None assert rel_scores is None or len(rel_docids) == len(rel_scores) pred_dicts.append({ "prediction": [cand_ids_local[i] for i in ranked_indices], "label": rel_docids, "rel_scores": rel_scores, }) # 新增:分数详情 append_score_detail(score_detail_dicts, qid, ranked_indices, sim_vec, cand_ids_local, rel_docids) ########################################## 保存预测和分数 with open(score_detail_path, "w") as f: # 新增 for detail in score_detail_dicts: f.write(json.dumps(detail) + '\n') print_master(f"Detailed score file saved to: {score_detail_path}") metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] metrics = RankingMetrics(metrics_to_report) score_dict = metrics.evaluate(pred_dicts) formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} score_dict["num_pred"] = len(pred_dicts) score_dict["num_data"] = len(gt_infos) print_master(f"Score of {dataset_name}:") print_master(formatted) print_master(f"Outputting final score to: {score_path}") with open(score_path, "w") as f: json.dump(score_dict, f, indent=4) with open(pred_path, "w") as f: for pred in pred_dicts: f.write(json.dumps(pred) + '\n') #################################################################### score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json") pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl") metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] metrics = RankingMetrics(metrics_to_report) score_dict = metrics.evaluate(pred_dicts) formatted = {k: f"{v:.4f}" for k, v in score_dict.items()} score_dict["num_pred"] = len(pred_dicts) score_dict["num_data"] = len(gt_infos) print_master(f"Score of {dataset_name}:") print_master(formatted) print_master(f"Outputting final score to: {score_path}") with open(score_path, "w") as f: json.dump(score_dict, f, indent=4) with open(pred_path, "w") as f: for pred in pred_dicts: f.write(json.dumps(pred) + '\n') if __name__ == '__main__': main()