import datetime
import json
import logging
import math
import os
import pickle
import random
import sys
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node

from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_analysis import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master

logging.basicConfig(level=logging.INFO,
                    format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)


def pad_dataset_to_divisible(dataset, world_size):
    """Pad a dataset by cycling its own rows so its length is divisible by world_size."""
    num_samples = len(dataset)
    if num_samples % world_size == 0:
        return dataset, num_samples
    num_to_add = world_size - (num_samples % world_size)
    padded_size = num_samples + num_to_add
    padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
    padded_dataset = concatenate_datasets([dataset, padding_data])
    return padded_dataset, padded_size


def encode_embeddings(
        model: MMEBModel,
        loader: DataLoader,
        training_args: TrainingArguments,
        model_args: ModelArguments,
        full_dataset: Dataset,
        encode_side: str,
        description: str = "Encoding",
) -> tuple[np.ndarray, list]:
    """
    Encode embeddings for a given dataset, handling both standard dense and
    late-interaction models in a DDP-safe manner.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    # Check whether the model is a late-interaction (multi-vector) type
    is_late_interaction = (model_args.model_backbone == COLPALI)

    local_embeds = []
    local_gt_infos = []
    local_max_len = 0

    model.eval()
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})",
                                         disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                # Encode the query or the candidate side, as selected by `encode_side`
                if encode_side == "qry":
                    output = model(qry=inputs)
                    reps = output["qry_reps"].detach()
                    local_gt_infos.extend(dataset_info)  # retain all information per query
                else:
                    output = model(tgt=inputs)
                    reps = output["tgt_reps"].detach()
                    local_gt_infos.extend([info["cand_name"] for info in dataset_info])  # retain ground-truth labels
            if is_late_interaction and reps.dim() == 3:
                local_max_len = max(local_max_len, reps.shape[1])
            local_embeds.append(reps)

    if not local_embeds:  # Handle cases where a rank gets no data
        return np.array([]), []

    # === DDP synchronization and padding for late-interaction models ===
    if is_late_interaction:
        if dist.is_initialized():
            # 1. Find the global maximum sequence length across all ranks
            local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
            dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
            global_max_len = local_max_len_tensor.item()
        else:
            global_max_len = local_max_len
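        # NOTE: `dist.all_gather_into_tensor` (used below) requires every rank to contribute
        # a tensor of identical shape. Late-interaction models emit [B, L, H] token-level
        # embeddings whose length L varies per batch, so each rank zero-pads its tensors up
        # to the globally agreed `global_max_len` before the gather.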
        # 2. Pad all local embeddings to the global max length
        padded_embeds = []
        for reps_batch in local_embeds:
            if reps_batch.dim() == 3:
                B, L, H = reps_batch.shape
                padding_size = global_max_len - L
                padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
                padded_embeds.append(padded_batch)
            else:
                # Should not happen if the model is consistently late-interaction
                padded_embeds.append(reps_batch)
        embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
    else:
        # Standard dense models
        embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()

    # === Gather embeddings and keys from all ranks ===
    if dist.is_initialized() and full_dataset.num_rows >= world_size:
        print_master(f"Gathering {encode_side} embeddings across all ranks...")
        # Use the more efficient all_gather_into_tensor for tensors
        output_shape = list(embeds_tensor.shape)
        output_shape[0] = full_dataset.num_rows
        embeds_tensor = embeds_tensor.to(training_args.device)
        gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype,
                                             device=training_args.device)
        dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
        final_embeddings = gathered_embeds_tensor.cpu().float().numpy()

        # Gather metadata, for which all_gather_object is appropriate
        gathered_gt_infos = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_gt_infos, local_gt_infos)
        all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
    else:
        all_gt_infos = local_gt_infos
        final_embeddings = embeds_tensor.cpu().float().numpy()

    return final_embeddings, all_gt_infos


def linear_cka(matrix1, matrix2):
    """
    Compute linear CKA between two matrices matrix1 (N, D1) and matrix2 (N, D2).
    """
    # Center the matrices
    matrix1 = matrix1 - matrix1.mean(dim=0, keepdim=True)
    matrix2 = matrix2 - matrix2.mean(dim=0, keepdim=True)

    # Compute Gram matrices
    gram1 = torch.mm(matrix1, matrix1.t())
    gram2 = torch.mm(matrix2, matrix2.t())

    # Compute (biased, linear-kernel) HSIC terms
    hsic_cross = torch.sum(gram1 * gram2)
    hsic_1 = torch.sum(gram1 * gram1)
    hsic_2 = torch.sum(gram2 * gram2)

    # Normalize: CKA = HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y))
    cka = hsic_cross / (torch.sqrt(hsic_1) * torch.sqrt(hsic_2))
    return cka.item()
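

# Quick sanity check for `linear_cka` (illustrative only, not part of the pipeline):
# linear CKA is invariant to isotropic scaling and orthogonal rotation of either
# representation, so a rotated-and-scaled copy of X should score 1.0 against X:
#
#     X = torch.randn(64, 128)
#     Q, _ = torch.linalg.qr(torch.randn(128, 128))  # random orthogonal matrix
#     assert abs(linear_cka(X, 3.0 * (X @ Q)) - 1.0) < 1e-4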


def run_layer_analysis(model, dataset, processor, model_args, data_args, training_args,
                       output_dir, num_samples=64):
    """
    Run layer-wise CKA and similarity analysis on a subset of data.
    """
    print_master(f"\n[Analysis] Starting Layer-wise Analysis on {num_samples} samples...")

    # 1. Prepare a small subset
    subset_indices = list(range(min(len(dataset), num_samples)))
    subset = dataset.select(subset_indices)

    # Measuring query-candidate alignment would require (query, positive candidate) pairs,
    # but the eval collator processes the "qry" and "cand" sides separately and `dataset`
    # may hold only queries plus metadata, so building the matching candidate batch is
    # non-trivial. We therefore restrict the analysis to the QUERY encoder's layer-wise
    # structure (CKA) and leave pairwise alignment to future work.
    qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")

    # For simplicity, take a single batch containing all `num_samples` queries
    loader = DataLoader(subset, batch_size=num_samples, collate_fn=qry_collator)
    try:
        qry_batch, qry_info = next(iter(loader))
    except StopIteration:
        return

    qry_batch = batch_to_device(qry_batch, training_args.device)
    with torch.no_grad():
        with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
            # Shape: [Layers, Batch, Dim]
            all_layers_reps = model.encode_all_layers(qry_batch)

    num_layers, bsz, dim = all_layers_reps.shape
    final_layer_reps = all_layers_reps[-1]  # [Batch, Dim]

    cka_scores = []
    cos_sim_with_final = []

    # 2. Compute per-layer metrics
    for i in range(num_layers):
        current_layer_reps = all_layers_reps[i]

        # Metric A: CKA (structural similarity to the final layer).
        # Cast to float32 for numerical stability.
        cka = linear_cka(current_layer_reps.float(), final_layer_reps.float())
        cka_scores.append(cka)

        # Metric B: average cosine similarity with the final layer (representation drift),
        # i.e. how much individual vectors change direction across layers.
        cos_sim = F.cosine_similarity(current_layer_reps.float(),
                                      final_layer_reps.float(), dim=-1).mean().item()
        cos_sim_with_final.append(cos_sim)

    # 3. Save results
    results = {
        "layers": list(range(1, num_layers + 1)),
        "cka_scores": cka_scores,
        "cos_sim_with_final": cos_sim_with_final,
    }
    save_path = os.path.join(output_dir, "layer_analysis_results.json")
    with open(save_path, "w") as f:
        json.dump(results, f, indent=4)
    print_master(f"[Analysis] CKA and Similarity results saved to {save_path}")
    print_master(f"[Analysis] Layer 12 CKA: {cka_scores[min(11, num_layers - 1)]:.4f}")
    return results
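

# A minimal sketch of how the saved JSON could be plotted offline (assumes matplotlib is
# installed; illustrative only, not invoked anywhere in this script):
#
#     import matplotlib.pyplot as plt
#     with open("layer_analysis_results.json") as f:
#         res = json.load(f)
#     plt.plot(res["layers"], res["cka_scores"], label="CKA vs. final layer")
#     plt.plot(res["layers"], res["cos_sim_with_final"], label="cosine sim vs. final layer")
#     plt.xlabel("layer"); plt.ylabel("similarity"); plt.legend()
#     plt.savefig("layer_analysis.png")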


def main():
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    # Debug prints for the distributed setup
    print_master("Distributed init debug info:")
    print_master(f"RANK: {os.environ.get('RANK')}")
    print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
    print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
    print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
    print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
    if dist.is_initialized():
        print_rank(f"dist.get_rank(): {dist.get_rank()}")
        print_rank(f"dist.get_world_size(): {dist.get_world_size()}")

    # Normalize the launcher's `--local-rank=N` into the `--local_rank N` form expected by
    # HfArgumentParser (iterate over a copy so removing elements is safe)
    for arg in list(sys.argv):
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    os.makedirs(data_args.encode_output_path, exist_ok=True)

    # --- Model Loading ---
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)
    print_master(f'Model Backbone: {model_args.model_backbone}')

    # --- DDP-Safe Model Loading ---
    # Step 1: Only the master process (rank 0) downloads the model.
    if local_rank == 0:
        print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    # Step 2: All processes wait here. The non-master processes pause until the master
    # process (rank 0) finishes downloading and exits this barrier.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Step 3: Now that the model is cached, the non-master processes load it from the local cache.
    if local_rank != 0:
        print_rank("Loading the model from cache...")
        processor = load_processor(model_args, data_args)
        # Stagger cache reads slightly to avoid contention
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    # model.set_inference_layers(qry_layers=24, tgt_layers=24)

    # === Layer-wise analysis across all configured datasets (master rank only) ===
    if local_rank == 0:
        with open(data_args.dataset_config, 'r') as yaml_file:
            analysis_configs = yaml.safe_load(yaml_file)

        print_master(f"\n[Global Analysis] Starting Layer-wise Analysis across {len(analysis_configs)} datasets...")

        all_datasets_cka = []
        all_datasets_cos = []
        dataset_names_log = []
        layers_log = None

        # Loop through ALL datasets defined in the yaml
        for d_name, d_config in analysis_configs.items():
            print_master(f" > Analyzing dataset: {d_name} ...")

            # Prepend --data_basedir to relative data paths if provided
            if data_args.data_basedir is not None:
                for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                    if d_config.get(key):
                        d_config[key] = os.path.join(data_args.data_basedir, d_config[key])

            # Instantiate the dataset; catch errors so one bad dataset does not
            # crash the whole analysis
            try:
                analysis_dataset, _ = AutoEvalPairDataset.instantiate(
                    model_args=model_args, data_args=data_args, **d_config)

                # 128 samples per task keeps the cost modest while giving a stable CKA
                # estimate. Note that run_layer_analysis rewrites layer_analysis_results.json
                # on each call, so only the aggregated results saved below persist.
                res = run_layer_analysis(
                    model,
                    analysis_dataset,
                    processor,
                    model_args,
                    data_args,
                    training_args,
                    data_args.encode_output_path,
                    num_samples=128,
                )
                if res:
                    all_datasets_cka.append(res["cka_scores"])
                    all_datasets_cos.append(res["cos_sim_with_final"])
                    dataset_names_log.append(d_name)
                    layers_log = res["layers"]  # assume all datasets share the same layer count
            except Exception as e:
                print_master(f" [Warning] Failed to analyze {d_name}: {e}")

        # Average across datasets
        if all_datasets_cka:
            avg_cka = np.mean(np.array(all_datasets_cka), axis=0).tolist()
            avg_cos = np.mean(np.array(all_datasets_cos), axis=0).tolist()
            std_cka = np.std(np.array(all_datasets_cka), axis=0).tolist()

            # Save global results
            global_results = {
                "layers": layers_log,
                "analyzed_datasets": dataset_names_log,
                "avg_cka_scores": avg_cka,
                "std_cka_scores": std_cka,
                "avg_cos_sim": avg_cos,
                "individual_cka": all_datasets_cka,  # raw per-dataset scores, e.g. for plotting distributions
            }
            save_path = os.path.join(data_args.encode_output_path, "global_layer_analysis_avg.json")
            with open(save_path, "w") as f:
                json.dump(global_results, f, indent=4)
            print_master(f"\n[Global Analysis] Completed! Averaged results saved to {save_path}")
            idx = min(11, len(avg_cka) - 1)
            print_master(f"[Global Analysis] Avg Layer 12 CKA: {avg_cka[idx]:.4f} (Std: {std_cka[idx]:.4f})")
        else:
            print_master("[Global Analysis] No datasets were successfully analyzed.")

    if dist.is_initialized():
        dist.barrier()
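    # The dataset config YAML maps each dataset name to a task config consumed by the loop
    # below. An illustrative (hypothetical) entry, using only keys this script reads:
    #
    #     my_retrieval_task:
    #       data_path: my_task/queries.jsonl     # joined with --data_basedir if set
    #       image_root: my_task/images
    #       eval_type: global                    # "global": rank against the full candidate pool
    #       metrics: [hit, ndcg, mrr]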
    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)

    # --- Main Evaluation Loop ---
    for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
        # 0. Load dataset
        if dist.is_initialized():
            dist.barrier()
        print_master(f"--- Evaluating {dataset_name} ---")
        query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry")
        cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt")
        dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
        do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
        do_cand = not os.path.exists(cand_embed_path)

        if do_query or do_cand:
            if data_args.data_basedir is not None:
                # Construct full paths for data files if --data_basedir is provided
                for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                    if task_config.get(key):
                        task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
            full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(
                model_args=model_args, data_args=data_args, **task_config)
            full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
            eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset

            # Pad datasets to be divisible by world_size before splitting
            if dist.is_initialized():
                padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
                padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
                eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
                eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
            else:
                padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
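
        # Padding + `split_dataset_by_node` gives every rank an equally sized shard, so the
        # `all_gather_into_tensor` calls in `encode_embeddings` see matching tensor shapes
        # on all ranks; the duplicated padding rows are trimmed again right after each gather.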
        # --- 1. Compute Query Embeddings ---
        if do_query:
            print_master("Encoding queries...")
            eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
            eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size,
                                         collate_fn=eval_qry_collator,
                                         num_workers=training_args.dataloader_num_workers)
            query_embeds, gt_infos = encode_embeddings(model, eval_qry_loader, training_args, model_args,
                                                       padded_qry_dataset, encode_side="qry",
                                                       description=f"Queries for {dataset_name}")
            # With world_size > 1, trim the padded data points
            query_embeds = query_embeds[:len(full_eval_qry_dataset)]
            gt_infos = gt_infos[:len(full_eval_qry_dataset)]

            if local_rank == 0:
                with open(query_embed_path, 'wb') as f:
                    pickle.dump(query_embeds, f)
                with open(dataset_info_path, 'w') as f:
                    for info in gt_infos:
                        f.write(json.dumps(info) + '\n')
                print_master(f"Saved query embeddings to {query_embed_path}")
            if dist.is_initialized():
                dist.barrier()

        # --- 2. Compute Candidate Embeddings ---
        if do_cand:
            print_master("Encoding candidates...")
            eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
            eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size,
                                          collate_fn=eval_cand_collator,
                                          num_workers=training_args.dataloader_num_workers)
            cand_embeds, all_cand_ids = encode_embeddings(model, eval_cand_loader, training_args, model_args,
                                                          padded_cand_dataset, encode_side="cand",
                                                          description=f"Candidates for {dataset_name}")
            # With world_size > 1, trim the padded data points
            cand_embeds = cand_embeds[:len(full_eval_cand_dataset)]
            all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)]

            if local_rank == 0:
                cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)}
                with open(cand_embed_path, 'wb') as f:
                    pickle.dump(cand_embed_dict, f)
                print_master(f"Saved candidate embeddings to {cand_embed_path}")
            if dist.is_initialized():
                dist.barrier()
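
        # Scoring note: the dense branch below uses raw dot products (np.dot) as cosine
        # similarities, which is only equivalent if the model emits L2-normalized
        # embeddings; verify this when plugging in a new backbone. The late-interaction
        # branch instead delegates to the processor's `score` function over per-token
        # embeddings.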
        # --- 3. Compute Scores (on master rank only) ---
        if local_rank == 0:
            score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
            if os.path.exists(score_path):
                try:
                    with open(score_path, "r") as f:
                        score_dict = json.load(f)
                    print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}")
                    formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
                    print_master(formatted)
                    continue
                except Exception as e:
                    print_master(f"Failed to load cached score for {dataset_name} ({e}), recomputing...")

            with open(query_embed_path, 'rb') as f:
                qry_embeds = pickle.load(f)
            with open(cand_embed_path, 'rb') as f:
                cand_embed_dict = pickle.load(f)
            with open(dataset_info_path) as f:
                gt_infos = [json.loads(l) for l in f]

            pred_dicts = []
            rank_against_all_candidates = task_config.get("eval_type", "global") == "global"
            if rank_against_all_candidates:
                cand_keys = list(cand_embed_dict.keys())
                cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
                if qry_embeds.ndim == 3:
                    # Late-interaction scoring. Query: [N_q, L_q, H] | Candidate: [N_c, L_c, H]
                    qry_embed = torch.from_numpy(qry_embeds)
                    cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
                    scores = processor.score(qry_embed, cand_embeds, batch_size=64)  # use the ColPali score function
                    ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()
                else:
                    # Dense scoring
                    cosine_scores = np.dot(qry_embeds, cand_embeds.T)
                    ranked_candids = np.argsort(-cosine_scores, axis=1)
                for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)),
                                                          desc=f"Calculating scores for {dataset_name}"):
                    rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
                    rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
                    assert rel_scores is None or len(rel_docids) == len(rel_scores)
                    pred_dicts.append({
                        "prediction": [cand_keys[i] for i in ranked_candid],
                        "label": rel_docids,
                        "rel_scores": rel_scores,
                    })
            else:
                # Rank each query only against its own candidate list
                for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)),
                                                      desc=f"Calculating scores for {dataset_name}"):
                    cand_embeds = np.stack([cand_embed_dict[key] for key in gt_info["cand_names"]])
                    if qry_embeds.ndim == 3:
                        # Late-interaction scoring. Query: [1, L_q, H] | Candidate: [N_c, L_c, H]
                        qry_embed = torch.from_numpy(np.array(qry_embed)).unsqueeze(0)
                        cand_embeds = [torch.from_numpy(np.array(t)) for t in cand_embeds]
                        scores = processor.score(qry_embed, cand_embeds, batch_size=1024)  # use the ColPali score function
                        ranked_candids = torch.argsort(-scores, dim=1).cpu().numpy().tolist()[0]
                    else:
                        cosine_score = np.dot(qry_embed, cand_embeds.T)
                        ranked_candids = np.argsort(-cosine_score)
                    rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
                    rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
                    assert rel_scores is None or len(rel_docids) == len(rel_scores)
                    pred_dicts.append({
                        "prediction": [gt_info["cand_names"][i] for i in ranked_candids],
                        "label": rel_docids,
                        "rel_scores": rel_scores,
                    })
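
            # Each record in `pred_dicts` is what RankingMetrics.evaluate consumes; an
            # illustrative (hypothetical) entry:
            #
            #     {"prediction": ["cand_42", "cand_7", ...],  # candidate names, ranked best-first
            #      "label": ["cand_7"],                       # ground-truth candidate name(s)
            #      "rel_scores": None}                        # optional graded relevance scores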
            pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl")
            metrics_to_report = (task_config["metrics"]
                                 if task_config.get("metrics", None) is not None
                                 else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"])
            metrics = RankingMetrics(metrics_to_report)
            score_dict = metrics.evaluate(pred_dicts)
            formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
            score_dict["num_pred"] = len(pred_dicts)
            score_dict["num_data"] = len(gt_infos)
            print_master(f"Score of {dataset_name}:")
            print_master(formatted)
            print_master(f"Outputting final score to: {score_path}")
            with open(score_path, "w") as f:
                json.dump(score_dict, f, indent=4)
            with open(pred_path, "w") as f:
                for pred in pred_dicts:
                    f.write(json.dumps(pred) + '\n')


if __name__ == "__main__":
    main()
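
# Launch sketch (illustrative; flag names follow the dataclass fields used above and the
# script filename is a placeholder, adjust both to your setup):
#
#     torchrun --nproc_per_node=8 <this_script>.py \
#         --model_name <hf_model_id> \
#         --dataset_config configs/eval.yaml \
#         --encode_output_path outputs/ \
#         --per_device_eval_batch_size 8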