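"""Distributed evaluation script for multimodal retrieval with vision compression.

For every dataset listed in the YAML config, the script encodes query and candidate
embeddings with an MMEBModel (including late-interaction / ColPali-style backbones),
records per-module inference timing and token-count statistics through forward hooks,
and computes ranking metrics on rank 0.
"""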
import datetime
import json
import logging
import os
import pickle
import random
import sys
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import transformers
import yaml
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer

from src.arguments_vision_compression import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_vision_compression import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master

logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)

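# Per-module wall-clock timing is collected via forward pre/post hooks that append
# (timestamp, event, module class name) tuples to `timing_info`, keyed by the module's
# id(). `token_info` holds the visual / textual token counts observed for the current
# batch; both are reset at the start of every batch inside encode_embeddings().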
timing_info = {}
token_info = {
    "vision_tokens": 0,
    "text_input_tokens": 0,
    "text_output_tokens": 0,
    "total_llm_input_tokens": 0,
}


def timing_pre_hook(module, input):
    """Record the time at which a hooked module starts its forward pass."""
    module_id = id(module)
    if module_id not in timing_info:
        timing_info[module_id] = []
    timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))


def timing_post_hook(module, input, output):
    """Record the time at which a hooked module finishes its forward pass.

    For vision transformer modules, also record the number of vision tokens produced.
    """
    module_id = id(module)
    if module_id not in timing_info:
        return
    timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))

    module_name = module.__class__.__name__
    if "vision" in module_name.lower() and "transformer" in module_name.lower():
        out = output
        if isinstance(out, (tuple, list)) and len(out) > 0:
            out = out[0]
        if torch.is_tensor(out):
            if out.dim() == 2:
                token_info["vision_tokens"] = out.shape[0]
            elif out.dim() == 3:
                token_info["vision_tokens"] = out.shape[1]
        elif hasattr(out, "last_hidden_state") and torch.is_tensor(out.last_hidden_state):
            token_info["vision_tokens"] = out.last_hidden_state.shape[1]

def register_model_hooks(model):
    """Register timing hooks on the vision tower, patch merger, LLM trunk, and LM head.

    Returns the list of modules that received hooks. Hook handles are not kept, so
    repeated calls register duplicate hooks on the same modules.
    """
    registered_modules = []
    core_model = model

    if hasattr(model, 'encoder') and model.encoder is not None:
        print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}")
    else:
        print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.")

    if hasattr(core_model, 'visual') and core_model.visual is not None:
        vision_module = core_model.visual
        vision_module.register_forward_pre_hook(timing_pre_hook)
        vision_module.register_forward_hook(timing_post_hook)
        registered_modules.append(vision_module)
        print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")

    if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
        merger_module = core_model.visual.merger
        merger_module.register_forward_pre_hook(timing_pre_hook)
        merger_module.register_forward_hook(timing_post_hook)
        registered_modules.append(merger_module)
        print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")

    if hasattr(core_model, 'model') and core_model.model is not None:
        llm_main_module = core_model.model
        llm_main_module.register_forward_pre_hook(timing_pre_hook)
        llm_main_module.register_forward_hook(timing_post_hook)
        registered_modules.append(llm_main_module)
        print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")

    if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
        lm_head_module = core_model.lm_head
        lm_head_module.register_forward_pre_hook(timing_pre_hook)
        lm_head_module.register_forward_hook(timing_post_hook)
        registered_modules.append(lm_head_module)
        print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")

    if not registered_modules:
        print_master("Warning: No major modules found for hook registration. Check model architecture.")
    return registered_modules

def pad_dataset_to_divisible(dataset, world_size):
    """Pad a dataset with wrapped-around copies so its length is divisible by world_size."""
    num_samples = len(dataset)
    if num_samples % world_size == 0:
        return dataset, num_samples

    num_to_add = world_size - (num_samples % world_size)
    padded_size = num_samples + num_to_add

    padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
    padded_dataset = concatenate_datasets([dataset, padding_data])
    return padded_dataset, padded_size

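# Example: with 10 samples and world_size=4, two padding rows (copies of samples 0 and 1)
# are appended, giving a padded size of 12; the padded rows are dropped again in main()
# by truncating the gathered results to the original dataset length.
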
def encode_embeddings(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    encode_side: str,
    description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list]:
    """
    Encodes embeddings for a given dataset using the model, handling both standard and
    late-interaction models in a DDP-safe manner.

    Returns:
        - embeddings: np.ndarray
        - infos_or_ids: list
        - batch_stats_list: list
        - img_token_masks: list[None | list[bool]]  # NEW
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    is_late_interaction = (model_args.model_backbone == COLPALI)

    local_embeds = []
    local_gt_infos = []
    local_max_len = 0

    # Per-batch timing/token statistics collected on this rank.
    batch_stats_list = []

    # Per-sample image-token boolean masks (None when the model does not expose them).
    local_img_token_masks = []

    model.eval()

    registered_hooks = register_model_hooks(model)

    def _search_key(obj, key: str):
        # Recursively search nested dicts/lists/tuples for the first occurrence of `key`.
        if isinstance(obj, dict):
            if key in obj:
                return obj[key]
            for v in obj.values():
                r = _search_key(v, key)
                if r is not None:
                    return r
        elif isinstance(obj, (list, tuple)):
            for v in obj:
                r = _search_key(v, key)
                if r is not None:
                    return r
        return None

    def _to_serializable_mask_list(mask_list, batch_size: int):
        # Convert per-sample masks into JSON-serializable lists, padded/truncated to batch_size.
        if mask_list is None:
            return [None] * batch_size

        out = []
        if isinstance(mask_list, (list, tuple)):
            for m in mask_list:
                if m is None:
                    out.append(None)
                elif torch.is_tensor(m):
                    out.append(m.detach().cpu().tolist())
                elif isinstance(m, np.ndarray):
                    out.append(m.tolist())
                else:
                    out.append(m)
        elif torch.is_tensor(mask_list):
            out = mask_list.detach().cpu().tolist()
        elif isinstance(mask_list, np.ndarray):
            out = mask_list.tolist()
        else:
            out = [None] * batch_size

        if isinstance(out, list):
            if len(out) < batch_size:
                out = out + [None] * (batch_size - len(out))
            elif len(out) > batch_size:
                out = out[:batch_size]
        return out

    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
            # Reset the hook-populated statistics before each batch.
            timing_info.clear()
            token_info["vision_tokens"] = 0
            token_info["text_input_tokens"] = 0
            token_info["text_output_tokens"] = 0
            token_info["total_llm_input_tokens"] = 0

            inputs = batch_to_device(inputs, training_args.device)
            current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1

            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                start_inference_time = time.time()
                if encode_side == "qry":
                    output = model(qry=inputs)
                    reps = output["qry_reps"].detach()
                    local_gt_infos.extend(dataset_info)
                else:
                    output = model(tgt=inputs)
                    reps = output["tgt_reps"].detach()
                    local_gt_infos.extend([info["cand_name"] for info in dataset_info])
                end_inference_time = time.time()

            # Collect the per-sample image-token masks, if the model exposes them.
            img_masks_raw = None
            if isinstance(output, dict):
                img_masks_raw = _search_key(output, "image_token_bool_masks")
            if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
                img_masks_raw = getattr(model, "image_token_bool_masks")

            img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
            local_img_token_masks.extend(img_masks_serializable)

            if 'input_ids' in inputs and inputs['input_ids'] is not None:
                token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
                token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
                token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])

            batch_inference_time = end_inference_time - start_inference_time

            current_batch_stats = {
                "batch_size": current_batch_size,
                "total_inference_time_seconds": batch_inference_time,
                "module_inference_times": {},
                "token_counts": {
                    "visual_tokens": token_info["vision_tokens"],
                    "language_input_tokens_raw": token_info["text_input_tokens"],
                    "llm_total_input_tokens": token_info["total_llm_input_tokens"],
                    "language_output_tokens": token_info["text_output_tokens"],
                }
            }

            # Pair up the pre/post hook events recorded for each hooked module to get durations.
            for module_obj in registered_hooks:
                module_id = id(module_obj)
                module_name = module_obj.__class__.__name__
                times = timing_info.get(module_id, [])
                durations = []
                pre_times = {}
                for t, event_type, _ in times:
                    if event_type == 'pre':
                        pre_times[module_id] = t
                    elif event_type == 'post' and module_id in pre_times:
                        duration = t - pre_times.pop(module_id)
                        durations.append(duration)

                if durations:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": sum(durations),
                        "count": len(durations),
                        "avg": sum(durations) / len(durations)
                    }
                else:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": 0.0,
                        "count": 0,
                        "avg": 0.0
                    }

            batch_stats_list.append(current_batch_stats)

            print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
            print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
            print_rank("--- Module Inference Timing Statistics ---")
            for module_name, stats in current_batch_stats["module_inference_times"].items():
                print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
            print_rank("--- Token Count Statistics ---")
            print_rank(f"**Visual tokens**: {current_batch_stats['token_counts']['visual_tokens']}")
            print_rank(f"**Language input tokens (raw text only)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
            print_rank(f"**Total LLM input tokens (vision + formatted text)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
            print_rank(f"**Language output tokens**: {current_batch_stats['token_counts']['language_output_tokens']}")

            if is_late_interaction and reps.dim() == 3:
                local_max_len = max(local_max_len, reps.shape[1])

            local_embeds.append(reps)

    if not local_embeds:
        return np.array([]), [], [], []

    if is_late_interaction:
        if dist.is_initialized():
            local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
            dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
            global_max_len = local_max_len_tensor.item()
        else:
            global_max_len = local_max_len

        padded_embeds = []
        for reps_batch in local_embeds:
            if reps_batch.dim() == 3:
                B, L, H = reps_batch.shape
                padding_size = global_max_len - L
                padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
                padded_embeds.append(padded_batch)
            else:
                padded_embeds.append(reps_batch)

        embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
    else:
        embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()

    if dist.is_initialized() and full_dataset.num_rows >= world_size:
        print_master(f"Gathering {encode_side} embeddings across all ranks...")

        output_shape = list(embeds_tensor.shape)
        output_shape[0] = full_dataset.num_rows
        embeds_tensor = embeds_tensor.to(training_args.device)
        gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
        dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
        final_embeddings = gathered_embeds_tensor.cpu().float().numpy()

        gathered_gt_infos = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_gt_infos, local_gt_infos)
        all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]

        gathered_batch_stats = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_batch_stats, batch_stats_list)
        all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]

        gathered_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_masks, local_img_token_masks)
        all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]
    else:
        all_gt_infos = local_gt_infos
        final_embeddings = embeds_tensor.cpu().float().numpy()
        all_batch_stats = batch_stats_list
        all_img_token_masks = local_img_token_masks

    return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks

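# main() drives the evaluation end to end: it initializes torch.distributed when launched
# through torchrun, parses the model/data/training arguments, loads the model (rank 0 first
# so the other ranks can reuse the local cache), then iterates over the datasets in the YAML
# config, encoding query and candidate embeddings and scoring them on rank 0.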
def main():
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    print_master("Distributed init debug info:")
    print_master(f"RANK: {os.environ.get('RANK')}")
    print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
    print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
    print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
    print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
    if dist.is_initialized():
        print_rank(f"dist.get_rank(): {dist.get_rank()}")
        print_rank(f"dist.get_world_size(): {dist.get_world_size()}")

    # Normalize torchrun's "--local-rank=N" into "--local_rank N" so HfArgumentParser accepts it.
    # Iterate over a copy because sys.argv is modified inside the loop.
    for arg in list(sys.argv):
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if not hasattr(model_args, "vision_compression") or model_args.vision_compression is None:
        model_args.vision_compression = "token_pooling"
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    os.makedirs(data_args.encode_output_path, exist_ok=True)

    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)
    print_master(f'Model Backbone: {model_args.model_backbone}')

    # Rank 0 loads (and, if needed, downloads) the model first; the other ranks wait at the
    # barrier and then load from the local cache.
    if local_rank == 0:
        print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    if local_rank != 0:
        print_rank("Loading the model from cache...")
        processor = load_processor(model_args, data_args)
        # Stagger the non-zero ranks slightly before loading.
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)

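    # For each dataset the following artifacts are written under data_args.encode_output_path:
    #   {name}_qry, {name}_tgt                      pickled query / candidate embeddings
    #   {name}_info.jsonl                           per-query ground-truth info
    #   {name}_qry_img_token_masks.jsonl, {name}_cand_img_token_masks.jsonl
    #   {name}_qry_inference_stats.json, {name}_cand_inference_stats.json
    #   {name}_score.json, {name}_pred.jsonl, {name}_score_details.jsonl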
    for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):

        query_total_stats = {
            "total_inference_time_seconds": 0.0,
            "module_inference_times": {
                "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0},
                "PatchMerger": {"total": 0.0, "count": 0},
                "Qwen2VLModel": {"total": 0.0, "count": 0},
                "Linear": {"total": 0.0, "count": 0},
            },
            "token_counts": {
                "visual_tokens": 0,
                "language_input_tokens_raw": 0,
                "llm_total_input_tokens": 0,
                "language_output_tokens": 0,
            },
            "data_point_count": 0
        }

        cand_total_stats = {
            "total_inference_time_seconds": 0.0,
            "module_inference_times": {
                "Qwen2VisionTransformerPretrainedModel": {"total": 0.0, "count": 0},
                "PatchMerger": {"total": 0.0, "count": 0},
                "Qwen2VLModel": {"total": 0.0, "count": 0},
                "Linear": {"total": 0.0, "count": 0},
            },
            "token_counts": {
                "visual_tokens": 0,
                "language_input_tokens_raw": 0,
                "llm_total_input_tokens": 0,
                "language_output_tokens": 0,
            },
            "data_point_count": 0
        }

        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")

        query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry")
        cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt")
        dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")

        query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats.json")
        cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats.json")

        do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
        do_cand = not os.path.exists(cand_embed_path)

        if do_query or do_cand:
            if data_args.data_basedir is not None:
                # Resolve relative data roots in the task config against data_basedir.
                for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                    if data_args.data_basedir and task_config.get(key):
                        task_config[key] = os.path.join(data_args.data_basedir, task_config[key])

            full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
            full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
            eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset

            if dist.is_initialized():
                padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
                padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
                eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
                eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
            else:
                padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset

        if do_query:
            print_master("Encoding queries...")
            eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
            eval_qry_loader = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers)

            query_embeds, gt_infos, qry_batch_stats, qry_img_masks = encode_embeddings(model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries for {dataset_name}")

            # Aggregate the per-batch statistics into dataset-level totals for the query side.
            for batch_stat in qry_batch_stats:
                batch_size = batch_stat["batch_size"]
                query_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"]
                for module_name, module_stats in batch_stat["module_inference_times"].items():
                    if module_name in query_total_stats["module_inference_times"]:
                        query_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"]
                        query_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"]

                query_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size
                query_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size
                query_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size
                query_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size

                query_total_stats["data_point_count"] += batch_size

            # Drop the rows that were added by pad_dataset_to_divisible.
            query_embeds = query_embeds[:len(full_eval_qry_dataset)]
            gt_infos = gt_infos[:len(full_eval_qry_dataset)]
            if local_rank == 0:
                with open(query_embed_path, 'wb') as f:
                    pickle.dump(query_embeds, f)
                with open(dataset_info_path, 'w') as f:
                    for info in gt_infos:
                        f.write(json.dumps(info) + '\n')
                print_master(f"Saved query embeddings to {query_embed_path}")

                qry_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_img_token_masks.jsonl")
                with open(qry_img_masks_path, 'w', encoding='utf-8') as f:
                    for i, m in enumerate(qry_img_masks[:len(full_eval_qry_dataset)]):
                        f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n")
                print_master(f"Saved query image token masks to {qry_img_masks_path}")

                if query_total_stats["data_point_count"] > 0:
                    final_query_stats = {
                        "task_name": dataset_name,
                        "encode_side": "query",
                        "data_point_count": query_total_stats["data_point_count"],
                        "inference_times": {
                            "total_inference_time_seconds": query_total_stats["total_inference_time_seconds"],
                            "avg_inference_time_per_item_seconds": query_total_stats["total_inference_time_seconds"] / query_total_stats["data_point_count"],
                            "module_average_times_per_call": {},
                            "module_total_times_seconds": {},
                            "module_calls_count": {},
                        },
                        "token_counts": {
                            "total_visual_tokens": query_total_stats["token_counts"]["visual_tokens"],
                            "avg_visual_tokens_per_item": query_total_stats["token_counts"]["visual_tokens"] / query_total_stats["data_point_count"],
                            "total_language_input_tokens_raw": query_total_stats["token_counts"]["language_input_tokens_raw"],
                            "avg_language_input_tokens_raw_per_item": query_total_stats["token_counts"]["language_input_tokens_raw"] / query_total_stats["data_point_count"],
                            "total_llm_total_input_tokens": query_total_stats["token_counts"]["llm_total_input_tokens"],
                            "avg_llm_total_input_tokens_per_item": query_total_stats["token_counts"]["llm_total_input_tokens"] / query_total_stats["data_point_count"],
                            "total_language_output_tokens": query_total_stats["token_counts"]["language_output_tokens"],
                            "avg_language_output_tokens_per_item": query_total_stats["token_counts"]["language_output_tokens"] / query_total_stats["data_point_count"],
                        }
                    }
                    for module_name, stats in query_total_stats["module_inference_times"].items():
                        final_query_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"]
                        final_query_stats["inference_times"]["module_calls_count"][module_name] = stats["count"]
                        if stats["count"] > 0:
                            final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"]
                        else:
                            final_query_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0

                    with open(query_inference_stats_path, 'w', encoding='utf-8') as f:
                        json.dump(final_query_stats, f, ensure_ascii=False, indent=4)
                    print_master(f"Query inference statistics for {dataset_name} saved to: {query_inference_stats_path}")
                else:
                    print_master(f"No query data processed for {dataset_name}, skipping query inference statistics output.")

        if dist.is_initialized():
            dist.barrier()

        if do_cand:
            print_master("Encoding candidates...")
            eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
            eval_cand_loader = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)

            cand_embeds, all_cand_ids, cand_batch_stats, cand_img_masks = encode_embeddings(model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates for {dataset_name}")

            # Aggregate the per-batch statistics into dataset-level totals for the candidate side.
            for batch_stat in cand_batch_stats:
                batch_size = batch_stat["batch_size"]
                cand_total_stats["total_inference_time_seconds"] += batch_stat["total_inference_time_seconds"]
                for module_name, module_stats in batch_stat["module_inference_times"].items():
                    if module_name in cand_total_stats["module_inference_times"]:
                        cand_total_stats["module_inference_times"][module_name]["total"] += module_stats["total"]
                        cand_total_stats["module_inference_times"][module_name]["count"] += module_stats["count"]

                cand_total_stats["token_counts"]["visual_tokens"] += batch_stat["token_counts"]["visual_tokens"] * batch_size
                cand_total_stats["token_counts"]["language_input_tokens_raw"] += batch_stat["token_counts"]["language_input_tokens_raw"] * batch_size
                cand_total_stats["token_counts"]["llm_total_input_tokens"] += batch_stat["token_counts"]["llm_total_input_tokens"] * batch_size
                cand_total_stats["token_counts"]["language_output_tokens"] += batch_stat["token_counts"]["language_output_tokens"] * batch_size

                cand_total_stats["data_point_count"] += batch_size

            # Drop the rows that were added by pad_dataset_to_divisible.
            cand_embeds = cand_embeds[:len(full_eval_cand_dataset)]
            all_cand_ids = all_cand_ids[:len(full_eval_cand_dataset)]

            if local_rank == 0:
                cand_embed_dict = {cand_id: embed for cand_id, embed in zip(all_cand_ids, cand_embeds)}
                with open(cand_embed_path, 'wb') as f:
                    pickle.dump(cand_embed_dict, f)
                print_master(f"Saved candidate embeddings to {cand_embed_path}")

                cand_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_img_token_masks.jsonl")
                with open(cand_img_masks_path, 'w', encoding='utf-8') as f:
                    for cid, m in zip(all_cand_ids[:len(full_eval_cand_dataset)], cand_img_masks[:len(full_eval_cand_dataset)]):
                        f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n")
                print_master(f"Saved candidate image token masks to {cand_img_masks_path}")

                if cand_total_stats["data_point_count"] > 0:
                    final_cand_stats = {
                        "task_name": dataset_name,
                        "encode_side": "candidate",
                        "data_point_count": cand_total_stats["data_point_count"],
                        "inference_times": {
                            "total_inference_time_seconds": cand_total_stats["total_inference_time_seconds"],
                            "avg_inference_time_per_item_seconds": cand_total_stats["total_inference_time_seconds"] / cand_total_stats["data_point_count"],
                            "module_average_times_per_call": {},
                            "module_total_times_seconds": {},
                            "module_calls_count": {},
                        },
                        "token_counts": {
                            "total_visual_tokens": cand_total_stats["token_counts"]["visual_tokens"],
                            "avg_visual_tokens_per_item": cand_total_stats["token_counts"]["visual_tokens"] / cand_total_stats["data_point_count"],
                            "total_language_input_tokens_raw": cand_total_stats["token_counts"]["language_input_tokens_raw"],
                            "avg_language_input_tokens_raw_per_item": cand_total_stats["token_counts"]["language_input_tokens_raw"] / cand_total_stats["data_point_count"],
                            "total_llm_total_input_tokens": cand_total_stats["token_counts"]["llm_total_input_tokens"],
                            "avg_llm_total_input_tokens_per_item": cand_total_stats["token_counts"]["llm_total_input_tokens"] / cand_total_stats["data_point_count"],
                            "total_language_output_tokens": cand_total_stats["token_counts"]["language_output_tokens"],
                            "avg_language_output_tokens_per_item": cand_total_stats["token_counts"]["language_output_tokens"] / cand_total_stats["data_point_count"],
                        }
                    }
                    for module_name, stats in cand_total_stats["module_inference_times"].items():
                        final_cand_stats["inference_times"]["module_total_times_seconds"][module_name] = stats["total"]
                        final_cand_stats["inference_times"]["module_calls_count"][module_name] = stats["count"]
                        if stats["count"] > 0:
                            final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = stats["total"] / stats["count"]
                        else:
                            final_cand_stats["inference_times"]["module_average_times_per_call"][module_name] = 0.0

                    with open(cand_inference_stats_path, 'w', encoding='utf-8') as f:
                        json.dump(final_cand_stats, f, ensure_ascii=False, indent=4)
                    print_master(f"Candidate inference statistics for {dataset_name} saved to: {cand_inference_stats_path}")
                else:
                    print_master(f"No candidate data processed for {dataset_name}, skipping candidate inference statistics output.")

        if dist.is_initialized():
            dist.barrier()

        score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score.json")
        pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred.jsonl")
        score_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details.jsonl")

        def append_score_detail(score_detail_list, qid, ranked_indices, score_vector, cand_ids, labels):
            """Append the per-candidate score details for one query."""
            score_detail_list.append({
                "qid": int(qid),
                "cand_scores": [
                    {"cand_id": str(cand_ids[i]), "score": float(score_vector[i])}
                    for i in ranked_indices
                ],
                "label": labels
            })

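        # Each line of {dataset_name}_score_details.jsonl has the form (values illustrative):
        #   {"qid": 0, "cand_scores": [{"cand_id": "doc_42", "score": 0.73}, ...], "label": ["doc_42"]}
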
        if local_rank == 0:
            if os.path.exists(score_path):
                try:
                    with open(score_path, "r") as f:
                        score_dict = json.load(f)
                    print_master(f"Score of {dataset_name} (loaded from previous run): {score_path}")
                    formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
                    print_master(formatted)
                    # Scores already exist for this dataset, so skip recomputation.
                    continue
                except Exception as e:
                    print_master(f"Failed to load score for {dataset_name}, proceeding to recompute. Error: {e}")

            with open(query_embed_path, 'rb') as f:
                qry_embeds = pickle.load(f)
            with open(cand_embed_path, 'rb') as f:
                cand_embed_dict = pickle.load(f)
            with open(dataset_info_path, 'r') as f:
                gt_infos = [json.loads(line) for line in f]
            pred_dicts = []
            score_detail_dicts = []

            # "global" evaluation ranks every query against the full candidate pool;
            # otherwise each query is ranked only against its own candidate list.
            rank_against_all_candidates = task_config.get("eval_type", "global") == "global"

            if rank_against_all_candidates:
                cand_keys = list(cand_embed_dict.keys())
                cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])

                if qry_embeds.ndim == 3:
                    qry_embed_t = torch.from_numpy(qry_embeds)
                    cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
                    sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy()
                else:
                    sim_matrix = np.dot(qry_embeds, cand_embeds.T)

                ranked_candids = np.argsort(-sim_matrix, axis=1)

                for qid, (ranked_candid, gt_info) in tqdm(enumerate(zip(ranked_candids, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
                    rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
                    rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
                    assert rel_scores is None or len(rel_docids) == len(rel_scores)
                    pred_dicts.append({
                        "prediction": [cand_keys[i] for i in ranked_candid],
                        "label": rel_docids,
                        "rel_scores": rel_scores,
                    })

                    append_score_detail(score_detail_dicts, qid, ranked_candid, sim_matrix[qid], cand_keys, rel_docids)

            else:
                for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), desc=f"Calculating scores for {dataset_name}"):
                    cand_ids_local = gt_info["cand_names"]
                    cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local])

                    if qry_embeds.ndim == 3:
                        qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0)
                        cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
                        sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0]
                    else:
                        sim_vec = np.dot(qry_embed, cand_embeds.T)

                    ranked_indices = np.argsort(-sim_vec)
                    rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
                    rel_scores = gt_info["rel_scores"] if "rel_scores" in gt_info else None
                    assert rel_scores is None or len(rel_docids) == len(rel_scores)

                    pred_dicts.append({
                        "prediction": [cand_ids_local[i] for i in ranked_indices],
                        "label": rel_docids,
                        "rel_scores": rel_scores,
                    })

                    append_score_detail(score_detail_dicts, qid, ranked_indices, sim_vec, cand_ids_local, rel_docids)

            with open(score_detail_path, "w") as f:
                for detail in score_detail_dicts:
                    f.write(json.dumps(detail) + '\n')
            print_master(f"Detailed score file saved to: {score_detail_path}")

            metrics_to_report = task_config["metrics"] if task_config.get("metrics", None) is not None else ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
            metrics = RankingMetrics(metrics_to_report)
            score_dict = metrics.evaluate(pred_dicts)
            formatted = {k: f"{v:.4f}" for k, v in score_dict.items()}
            score_dict["num_pred"] = len(pred_dicts)
            score_dict["num_data"] = len(gt_infos)
            print_master(f"Score of {dataset_name}:")
            print_master(formatted)
            print_master(f"Outputting final score to: {score_path}")
            with open(score_path, "w") as f:
                json.dump(score_dict, f, indent=4)
            with open(pred_path, "w") as f:
                for pred in pred_dicts:
                    f.write(json.dumps(pred) + '\n')

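
# Example launch (illustrative only; the exact flags are defined by the dataclasses in
# src.arguments_vision_compression and by the HF TrainingArguments):
#   torchrun --nproc_per_node=8 <this_script.py> \
#       --model_name <hf_model_or_checkpoint> \
#       --dataset_config <datasets.yaml> \
#       --encode_output_path <output_dir> \
#       --per_device_eval_batch_size 8 --dataloader_num_workers 4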
if __name__ == '__main__':
    main()