from typing import Dict

import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel

from src.model.processor import (
    COLPALI,
    E5_V,
    GME,
    INTERNVIDEO2,
    LLAVA_NEXT,
    PHI3V,
    QWEN2_5_VL,
    QWEN2_5_VL_TOKENSELECTION,
    QWEN2_VL,
    QWEN2_VL_TOKENSELECTION,
    VLM_IMAGE_TOKENS,
    LamRA,
    LamRA_QWEN2_5,
    backbone2model,
    get_backbone_name,
    print_master,
)
from src.arguments import ModelArguments, TrainingArguments
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration

from transformers import modeling_utils

# Compatibility shim: some transformers versions lack (or leave as None) the
# ALL_PARALLEL_STYLES registry; provide a default list so model loading works.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", 'rowwise']

import torch.nn.functional as F
import numpy as np


def _default_dart_config():
    """Return a fresh copy of the DART settings attached to Qwen model configs.

    The dict is rebuilt on every call so the ``build`` and ``load`` paths never
    share one mutable object. NOTE(review): the consumer of ``DART_config`` is
    not visible in this file — presumably a token-reduction step inside the
    backbone; confirm the meaning of each key there.
    """
    return dict(
        enabled=True,
        K=2,
        reduction_ratio=0.75,
        pivot_image_token=0,
        pivot_text_token=0,
        keep_pivot=True,
        sim="cosine",
    )


def analyze_layer_metrics(outputs, input_ids, image_token_id=151655, video_token_id=151656,
                          attention_mask=None):
    """Compute per-layer diagnostics for the FIRST sample of a batch.

    For every entry of ``outputs.hidden_states`` except the last one, the
    hidden states of the sample's text tokens are L2-normalised. Text tokens are
    those that are neither image nor video tokens, minus the final pooled
    position. Their mean pairwise cosine similarity (self-pairs excluded) is
    appended to ``stats["text_sim"]``.

    All work runs on detached CPU copies. No CUDA kernels are launched and no
    autograd graph is kept.

    Args:
        outputs: model output exposing ``hidden_states`` as a sequence of
            tensors indexed ``[batch, position, ...]`` (assumed
            ``(batch, seq, dim)`` — TODO confirm for every backbone).
        input_ids: token ids indexed ``[batch, position]``; only row 0 is used.
        image_token_id: id of image placeholder tokens. The default is a
            fallback, presumably the Qwen2-VL value.
        video_token_id: id of video placeholder tokens. The default is a
            fallback, presumably the Qwen2-VL value.
        attention_mask: optional mask aligned with ``input_ids``. When given,
            padded positions (mask == 0) are not counted as text, and the last
            *non-padded* position is the one excluded. When omitted, the last
            position of the row is excluded (the original behaviour).

    Returns:
        A dict with keys "text_sim", "attn_dist" and "text_influence".
        - "text_sim" gets one float per analysed layer. It is 1.0 when fewer
          than two text tokens remain.
        - "attn_dist" and "text_influence" are placeholders and stay empty.
        - All lists are empty when the sequence is empty or the output carries
          no hidden states.
    """
    stats = {
        "text_sim": [],
        "attn_dist": [],
        "text_influence": [],
    }

    # Work on a CPU copy of the first sample so boolean indexing never touches CUDA.
    input_ids_cpu = input_ids[0].detach().cpu()
    hidden_states = getattr(outputs, "hidden_states", None)
    seq_len = len(input_ids_cpu)
    if seq_len == 0 or hidden_states is None:
        return stats

    is_vision = (input_ids_cpu == image_token_id) | (input_ids_cpu == video_token_id)
    is_text = ~is_vision

    if attention_mask is not None:
        # Drop padding so pad-token hidden states do not pollute the similarity.
        valid = attention_mask[0].detach().cpu().bool()
        is_text &= valid
        valid_positions = valid.nonzero(as_tuple=True)[0]
        last_pos = int(valid_positions[-1]) if valid_positions.numel() > 0 else seq_len - 1
    else:
        last_pos = seq_len - 1
    # The final real position (the one used for pooling) is not counted as text.
    is_text[last_pos] = False

    for layer_hidden in hidden_states[:-1]:
        h = F.normalize(layer_hidden[0].detach().cpu().float(), p=2, dim=-1)
        text_h = h[is_text]
        n_text = text_h.shape[0]
        if n_text > 1:
            sim_matrix = text_h @ text_h.T
            # Mean of the off-diagonal entries: total sum minus the self-similarities.
            off_diag_sum = sim_matrix.sum() - sim_matrix.diagonal().sum()
            avg_sim = (off_diag_sum / (n_text * (n_text - 1))).item()
        else:
            avg_sim = 1.0
        stats["text_sim"].append(avg_sim)

    return stats


class MMEBModel(nn.Module):
    """Embedding wrapper around a (multimodal) backbone, trained with an InfoNCE-style loss.

    ``encode_input`` dispatches on ``self.model_backbone``. That attribute is
    set by ``load``; when it is absent, the generic last-hidden-state pooling
    path is used.
    """

    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'last',
                 normalize: bool = False,
                 temperature: float = 0.02,
                 ):
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def encode_input(self, input, return_analysis=False):
        """Encode one side (query or target) of a batch into embeddings.

        Args:
            input: the processed batch dict for the current backbone. Expected
                keys depend on the branch taken (e.g. ``input_ids``,
                ``attention_mask``, ``pixel_values``, ``texts``, ``images``).
            return_analysis: only honoured by the Qwen2-VL family branch. There
                it returns ``(pooled_output, stats)``, where ``stats`` comes
                from ``analyze_layer_metrics`` or is None without ``input_ids``.
                Every other branch returns the embedding alone regardless of
                this flag.

        Returns:
            The pooled embedding (tensor, or the raw encoder output for
            ColPali). Under ``return_analysis`` on Qwen backbones, an
            ``(embedding, stats)`` tuple.

        Note:
            The LLaVA-NeXT branch squeezes ``pixel_values`` / ``image_sizes``
            in place in ``input``.
        """
        backbone = getattr(self, "model_backbone", None)

        if backbone == INTERNVIDEO2:
            if "input_ids" in input.keys():
                # Text side: first-token embedding -> text projection -> unit norm.
                text_output = self.encoder.get_text_encoder()(
                    input["input_ids"],
                    attention_mask=input["attention_mask"],
                    return_dict=True,
                    mode="text",
                )
                text_embeds = text_output.last_hidden_state
                pooled_text_embeds = text_embeds[:, 0]
                pooled_output = self.encoder.text_proj(pooled_text_embeds)
                pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
                return pooled_output
            # Vision side: vision feature -> vision projection -> unit norm.
            _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
            vfeat = self.encoder.vision_proj(vfeat)
            vfeat /= vfeat.norm(dim=-1, keepdim=True)
            return vfeat

        if backbone in [GME, LamRA, LamRA_QWEN2_5]:
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]
            # Video queries are passed in practice, so a list of frames may arrive per sample.
            images = []
            for imgs in input['images']:
                if isinstance(imgs, list):
                    # Multi-image input: keep only the middle frame.
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list)  # the middle frame must be a single image
                images.append(imgs)
            return self.encoder.get_fused_embeddings(texts=texts, images=images)

        if backbone == COLPALI:
            return self.encoder(**input, return_dict=True, output_hidden_states=True)

        if backbone == LLAVA_NEXT:
            input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
            input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
            outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)
            return self._pooling(outputs.hidden_states[-1], input['attention_mask'])

        if backbone in [QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
            # Hidden states of every layer are always requested: the last one is
            # pooled, and the rest feed the optional analysis.
            outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)
            pooled_output = self._pooling(outputs.hidden_states[-1], input['attention_mask'])

            if not return_analysis:
                return pooled_output

            analysis_stats = None
            if "input_ids" in input:
                # Fallback ids are used when the config does not define them.
                img_id = getattr(self.config, 'image_token_id', 151655)
                vid_id = getattr(self.config, 'video_token_id', 151656)
                analysis_stats = analyze_layer_metrics(
                    outputs,
                    input["input_ids"],
                    image_token_id=img_id,
                    video_token_id=vid_id,
                    attention_mask=input.get("attention_mask"),
                )
            return pooled_output, analysis_stats

        # Generic path: pool the last hidden layer.
        outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)
        return self._pooling(outputs.hidden_states[-1], input['attention_mask'])

    def _pooling(self, last_hidden_state, attention_mask):
        """Select one vector per sequence (the last real token) and optionally L2-normalise it.

        Args:
            last_hidden_state: tensor indexed ``[batch, position, ...]``.
            attention_mask: tensor indexed ``[batch, position]``; 1 marks real tokens.

        Raises:
            NotImplementedError: for pooling modes other than 'last' / 'eos'.
        """
        if self.pooling not in ('last', 'eos'):
            raise NotImplementedError(f"Unsupported pooling: {self.pooling}")

        batch_size = last_hidden_state.shape[0]
        # If every row's final position is a real token, treat the batch as left-padded.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[torch.arange(batch_size), -1, :]
        else:
            # Right padding: the last real token sits at (number of real tokens - 1).
            eos_indices = attention_mask.sum(dim=1) - 1
            reps = last_hidden_state[
                torch.arange(batch_size, device=last_hidden_state.device), eos_indices]

        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    @classmethod
    def build(cls, model_args: ModelArguments, **kwargs):
        """Create a fresh model (optionally LoRA/DoRA-wrapped) from ``model_args.model_name``.

        NOTE(review): unlike ``load``, this does not set ``model.model_backbone``,
        so ``encode_input`` uses its generic path unless a caller sets it.
        """
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        model_backbone = get_backbone_name(hf_config=config)
        print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')

        if model_backbone == PHI3V:
            config._attn_implementation = "eager"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone == LLAVA_NEXT:
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
            # Configure the freshly loaded backbone. `model` does not exist yet at
            # this point; writing to it raised UnboundLocalError.
            base_model.config.DART_config = _default_dart_config()
            base_model.config.use_cache = False  # strongly recommended for retrieval
        elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False

            from .utils import parse_layer_type
            # Layer counts handed to parse_layer_type; presumably the LM / vision
            # depth of the targeted Qwen checkpoint — confirm for other model sizes.
            lm_qwen_layer = 28
            vis_qwen_layer = 32
            lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
            vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)

            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                lm_skip_layer=lm_skip_layer,
                vis_skip_layer=vis_skip_layer,
            )
        else:
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        if model_args.lora:
            print_master(f'Loading lora adapter from {base_model}')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=model_args.lora_target_modules.split(','),
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            encoder = get_peft_model(base_model, lora_config)
        else:
            encoder = base_model

        return cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )

    @classmethod
    def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
        """Load a model, optionally applying a LoRA adapter from ``checkpoint_path``.

        Args:
            model_args: model arguments. When ``model_backbone`` is unset, it is
                inferred from the config and written back onto ``model_args``.
            is_trainable: forwarded to PEFT. When False, the adapter is merged
                into the base weights.
            **kwargs: extra options for some backbones (e.g. ``processor`` for GME).

        Returns:
            An ``MMEBModel`` with ``model_backbone`` set.
        """
        model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
            model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
            setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')

        if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION,
                                         QWEN2_5_VL_TOKENSELECTION, E5_V}:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            # Deliberately keep flash_attention_2; forcing "eager" attention was
            # tried and removed in favour of stable flash attention.
            config._attn_implementation = "flash_attention_2"
            vision_config = getattr(config, "vision_config", None)
            if vision_config is not None:
                vision_config._attn_implementation = "flash_attention_2"
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                config=config
            )
            base_model.config.DART_config = _default_dart_config()
            base_model.config.use_cache = False  # strongly recommended for retrieval
        elif model_args.model_backbone == PHI3V:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
                                                          torch_dtype=torch.bfloat16, trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            print_master(f'Loading backbone [{model_args.model_backbone}] from {"src/model/vlm_backbone/internvideo2/"}')
            config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/",
                                                trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                "src/model/vlm_backbone/internvideo2/", config=config, trust_remote_code=True)
        elif model_args.model_backbone == GME:
            base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA:
            base_model = LamRAQwen2VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA_QWEN2_5:
            base_model = LamRAQwen25VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == COLPALI:
            base_model = ColPali.from_pretrained(model_args.model_name)
            setattr(base_model, 'config', config)
        else:
            # Loading an external base model from the HF hub.
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_name_or_path, **kwargs, config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        # Building the model on top of the base.
        if model_args.lora:
            print_master(f'Loading LoRA from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)
            lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config,
                                                   is_trainable=is_trainable)
            lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
            if not is_trainable:
                lora_model = lora_model.merge_and_unload()
            encoder = lora_model
        else:
            encoder = base_model

        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )
        model.model_backbone = model_args.model_backbone
        return model

    def save(self, output_dir: str):
        """Save the wrapped encoder (or adapter, for PEFT models) via ``save_pretrained``."""
        self.encoder.save_pretrained(output_dir)

    @staticmethod
    def _split_analysis(encoded):
        """Split an ``encode_input(..., return_analysis=True)`` result into ``(reps, stats)``.

        Only the Qwen2-VL family branch returns a ``(reps, stats)`` tuple. Other
        branches return the representation alone, which is paired with None
        here rather than indexed. Indexing a tensor would silently take its
        first row.
        """
        if isinstance(encoded, tuple) and len(encoded) == 2:
            return encoded
        return encoded, None

    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None,
                return_analysis=False, *args, **kwargs):
        """Encode queries/targets and return the contrastive loss, or the raw embeddings.

        Returns:
            - With ``return_analysis``: a dict with keys "qry_reps", "tgt_reps",
              "qry_stats", "tgt_stats".
            - Otherwise, if either side is missing: a dict with keys "qry_reps"
              and "tgt_reps".
            - Otherwise: the scalar cross-entropy loss over the in-batch (and,
              under DDP, cross-rank) similarity matrix.
        """
        qry_out = self.encode_input(qry, return_analysis=return_analysis) if qry else None
        tgt_out = self.encode_input(tgt, return_analysis=return_analysis) if tgt else None

        if return_analysis:
            # Evaluation-style call: expose the per-layer stats alongside the embeddings.
            qry_reps, qry_stats = self._split_analysis(qry_out)
            tgt_reps, tgt_stats = self._split_analysis(tgt_out)
            return {
                "qry_reps": qry_reps,
                "tgt_reps": tgt_reps,
                "qry_stats": qry_stats,
                "tgt_stats": tgt_stats
            }

        qry_reps, tgt_reps = qry_out, tgt_out
        if qry_reps is None or tgt_reps is None:
            return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}

        if self.is_ddp:
            all_qry_reps = self._dist_gather_tensor(qry_reps)
            all_tgt_reps = self._dist_gather_tensor(tgt_reps)
        else:
            all_qry_reps = qry_reps
            all_tgt_reps = tgt_reps

        scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
        scores = scores.view(all_qry_reps.size(0), -1)
        # Query i's positive is at column i * (targets per query). With one
        # target per query that is simply column i.
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
        if self.is_ddp:
            # Presumably counteracts DDP's averaging of gradients across ranks.
            loss = loss * self.world_size
        return loss

    def _dist_gather_tensor(self, t: Tensor):
        """All-gather ``t`` from every rank and concatenate along dim 0.

        ``torch.distributed.all_gather`` is not autograd-aware. The local slot
        is therefore replaced by ``t`` itself, so gradients still flow through
        this rank's own shard.
        """
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        all_tensors[self.process_rank] = t
        return torch.cat(all_tensors, dim=0)

    def compute_similarity(self, q_reps, p_reps):
        """Return the dot-product similarity matrix ``q_reps @ p_reps.T``."""
        return torch.matmul(q_reps, p_reps.transpose(0, 1))