from typing import Dict

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from transformers import modeling_utils
from peft import LoraConfig, get_peft_model, PeftModel

from src.arguments import ModelArguments, TrainingArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION,
    PHI3V, E5_V, INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI, VLM_IMAGE_TOKENS,
    get_backbone_name, print_master, backbone2model,
)
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration

# Some transformers builds leave ALL_PARALLEL_STYLES unset, which breaks from_pretrained;
# fall back to the standard style list in that case.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
# Helper for computing layer-wise metrics over the encoder's hidden states.
def analyze_layer_metrics(outputs, input_ids, image_token_id=151655, video_token_id=151656):
    # Move the inputs to CPU to avoid CUDA indexing issues, and detach to cut gradients.
    input_ids_cpu = input_ids[0].detach().cpu()  # use the first sample of the batch
    stats = {
        "text_sim": [],
        "attn_dist": [],
        "text_influence": []
    }
    # hidden_states is a tuple of GPU tensors, one entry per layer
    hidden_states = outputs.hidden_states

    # Basic sanity check
    seq_len = len(input_ids_cpu)
    if seq_len == 0:
        return stats

    # --- Build the masks on CPU ---
    # Everything below runs on CPU, so it cannot trigger a CUDA assert.
    is_vision = (input_ids_cpu == image_token_id) | (input_ids_cpu == video_token_id)
    is_text = ~is_vision
    # Exclude the EOS token (assumed to be the last position)
    is_text[seq_len - 1] = False

    # Iterate over the layers (excluding the last one)
    for layer_hidden in hidden_states[:-1]:
        # Move the current layer's hidden state to CPU.
        # layer_hidden is (Batch, Seq, Dim); take [0] -> (Seq, Dim)
        h = layer_hidden[0].detach().cpu().float()
        h = F.normalize(h, p=2, dim=-1)
        # Keep text tokens only
        text_h = h[is_text]  # (Num_Text, D)
        if text_h.shape[0] > 1:
            # Cosine similarity matrix (computed on CPU)
            sim_matrix = torch.matmul(text_h, text_h.T)
            # Exclude the diagonal
            n_text = text_h.shape[0]
            mask = ~torch.eye(n_text, dtype=torch.bool)
            avg_sim = sim_matrix[mask].mean().item() if mask.sum() > 0 else 1.0
        else:
            avg_sim = 1.0
        stats["text_sim"].append(avg_sim)
    return stats
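

# Illustrative only: a minimal sketch (not part of the original pipeline) of how the
# per-layer statistics returned by analyze_layer_metrics might be aggregated for logging.
# The helper name and the plain mean over layers are assumptions.
def summarize_layer_metrics(stats):
    # Average each metric across layers; empty metric lists map to None so callers can skip them.
    return {
        key: (sum(values) / len(values) if values else None)
        for key, values in stats.items()
    }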
class MMEBModel(nn.Module):
TRANSFORMER_CLS = AutoModelForCausalLM
def __init__(self,
encoder: PreTrainedModel,
pooling: str = 'last',
normalize: bool = False,
temperature: float = 0.02,
):
super().__init__()
self.config = encoder.config
self.encoder = encoder
self.pooling = pooling
self.normalize = normalize
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.is_ddp = dist.is_initialized()
if self.is_ddp:
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()
def encode_input(self, input, return_analysis=False):
        # Always request hidden states (return_dict=True, output_hidden_states=True)
        # so the layer-wise analysis below can run when needed.
        output_kwargs = {
            "return_dict": True,
            "output_hidden_states": True,
        }
if getattr(self, "model_backbone", None) == INTERNVIDEO2:
if "input_ids" in input.keys():
# text side
text_output = self.encoder.get_text_encoder()(
input["input_ids"],
attention_mask=input["attention_mask"],
return_dict=True,
mode="text",
)
text_embeds = text_output.last_hidden_state
pooled_text_embeds = text_embeds[:, 0]
pooled_output = self.encoder.text_proj(pooled_text_embeds)
pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
return pooled_output
else:
_, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
vfeat = self.encoder.vision_proj(vfeat)
vfeat /= vfeat.norm(dim=-1, keepdim=True)
return vfeat
elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]:
# pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]  # video queries should not contain the image token, but strip it defensively
            images = []
            for imgs in input['images']:
                # if multiple images/frames are given, keep only the middle one
                if isinstance(imgs, list):
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list)  # the middle element must be a single image, not a nested list
                images.append(imgs)
pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
return pooled_output
elif getattr(self, "model_backbone", None) == COLPALI:
pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
return pooled_output
elif getattr(self, "model_backbone", None) == LLAVA_NEXT:
input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
hidden_states = hidden_states.hidden_states[-1]
pooled_output = self._pooling(hidden_states, input['attention_mask'])
return pooled_output
elif getattr(self, "model_backbone", None) in [QWEN2_VL, QWEN2_5_VL, QWEN2_5_VL_TOKENSELECTION]:
            # Qwen2-VL / Qwen2.5-VL branch: pass output_kwargs so hidden states are returned.
            outputs = self.encoder(**input, **output_kwargs)
            # Pooled embedding (same pooling logic as the other branches)
            last_hidden_state = outputs.hidden_states[-1]
            pooled_output = self._pooling(last_hidden_state, input['attention_mask'])
            # --- optional layer-wise analysis ---
            analysis_stats = None
            if return_analysis and "input_ids" in input:
                # Fall back to Qwen2-VL's default image/video token ids if the config lacks them.
                img_id = self.config.image_token_id if hasattr(self.config, 'image_token_id') else 151655
                vid_id = self.config.video_token_id if hasattr(self.config, 'video_token_id') else 151656
                analysis_stats = analyze_layer_metrics(
                    outputs,
                    input["input_ids"],
                    image_token_id=img_id,
                    video_token_id=vid_id
                )
            if return_analysis:
                return pooled_output, analysis_stats
            return pooled_output
else:
hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
hidden_states = hidden_states.hidden_states[-1]
pooled_output = self._pooling(hidden_states, input['attention_mask'])
return pooled_output
def _pooling(self, last_hidden_state, attention_mask):
if self.pooling == 'last' or self.pooling == 'eos':
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
batch_size = last_hidden_state.shape[0]
if left_padding:
# Get the vectors at the last position
reps = last_hidden_state[torch.arange(batch_size), -1, :]
else:
                # Right padding: the last real token of each sequence is at index sum(mask) - 1
                eos_indices = attention_mask.sum(dim=1) - 1
                # Gather the hidden state at that position for every sequence in the batch
                reps = last_hidden_state[
                    torch.arange(batch_size, device=last_hidden_state.device), eos_indices]
else:
raise NotImplementedError
if self.normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
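
    # Worked example (illustrative, right-padding case): for attention_mask
    # [[1, 1, 1, 0], [1, 1, 0, 0]], eos_indices is [2, 1], so `reps` takes
    # last_hidden_state[0, 2] and last_hidden_state[1, 1], i.e. the final real
    # token of each sequence before padding.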
@classmethod
def build(cls, model_args: ModelArguments, **kwargs):
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
model_backbone = get_backbone_name(hf_config=config)
print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')
# Loading the base model
if model_backbone == PHI3V:
config._attn_implementation = "eager"
config.padding_side = "right"
config.use_cache = False
base_model = Phi3VForCausalLM.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone == LLAVA_NEXT:
config.use_cache = False
config.padding_side = "left"
base_model = LlavaNextForConditionalGeneration.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
            base_model.config.DART_config = dict(
                enabled=True,
                K=2,
                reduction_ratio=0.75,
                pivot_image_token=0,
                pivot_text_token=0,
                keep_pivot=True,
                sim="cosine",
            )
            base_model.config.use_cache = False  # strongly recommended for retrieval workloads
elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
from .utils import parse_layer_type
lm_qwen_layer = 28
vis_qwen_layer = 32
lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
lm_skip_layer=lm_skip_layer,
vis_skip_layer=vis_skip_layer,
)
else:
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name, **kwargs, config=config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
trust_remote_code=True)
if model_args.lora:
            print_master(f'Attaching new LoRA adapters (r={model_args.lora_r}, alpha={model_args.lora_alpha}) to the base model')
lora_config = LoraConfig(
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
target_modules=model_args.lora_target_modules.split(','),
lora_dropout=model_args.lora_dropout,
init_lora_weights="gaussian",
use_dora=True,
inference_mode=False
)
lora_model = get_peft_model(base_model, lora_config)
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
return model
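
    # Usage note (illustrative): `build` is the training-time entry point (it attaches
    # fresh LoRA/DoRA adapters when model_args.lora is set), while `load` below restores
    # a trained checkpoint, e.g. model = MMEBModel.build(model_args).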
@classmethod
def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
# Loading the base model
model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            # Do not force the 'eager' attention implementation here; keep
            # flash_attention_2 for both the language and vision towers so the model runs stably.
            config._attn_implementation = "flash_attention_2"
            config.vision_config._attn_implementation = "flash_attention_2"
base_model = backbone2model[model_args.model_backbone].from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=config
)
base_model.config.DART_config = dict(
enabled=True,
K=2,
reduction_ratio=0.75,
pivot_image_token=0,
pivot_text_token=0,
keep_pivot=True,
sim="cosine",
)
            base_model.config.use_cache = False  # strongly recommended for retrieval workloads
elif model_args.model_backbone == PHI3V:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
config.padding_side = "right"
base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
torch_dtype=torch.bfloat16, trust_remote_code=True)
base_model.padding_side = "right"
elif model_args.model_backbone == INTERNVIDEO2:
            print_master(f'Loading backbone [{model_args.model_backbone}] from src/model/vlm_backbone/internvideo2/')
config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/",
trust_remote_code=True)
base_model = backbone2model[model_args.model_backbone].from_pretrained("src/model/vlm_backbone/internvideo2/", config=config,
trust_remote_code=True)
elif model_args.model_backbone == GME:
base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA:
base_model = LamRAQwen2VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA_QWEN2_5:
base_model = LamRAQwen25VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == COLPALI:
base_model = ColPali.from_pretrained(model_args.model_name)
setattr(base_model, 'config', config)
else:
# Loading external base model from HF
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_name_or_path, **kwargs, config=config,
torch_dtype=torch.bfloat16,
trust_remote_code=True)
# Building the model on top of the base
if model_args.lora:
print_master(f'Loading LoRA from {model_name_or_path}')
lora_config = LoraConfig.from_pretrained(model_name_or_path)
lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
if not is_trainable:
lora_model = lora_model.merge_and_unload()
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
model.model_backbone = model_args.model_backbone
return model
def save(self, output_dir: str):
self.encoder.save_pretrained(output_dir)
def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, return_analysis=False, *args, **kwargs):
qry_out = self.encode_input(qry, return_analysis=return_analysis) if qry else None
tgt_out = self.encode_input(tgt, return_analysis=return_analysis) if tgt else None
        # Unpack the analysis results when they were requested
        qry_reps, qry_stats = (qry_out[0], qry_out[1]) if (qry and return_analysis) else (qry_out, None)
        tgt_reps, tgt_stats = (tgt_out[0], tgt_out[1]) if (tgt and return_analysis) else (tgt_out, None)
        # In eval mode only the reps are usually returned; to expose the layer-wise
        # stats as well, attach them to the output dict below.
if return_analysis:
return {
"qry_reps": qry_reps,
"tgt_reps": tgt_reps,
"qry_stats": qry_stats,
"tgt_stats": tgt_stats
}
if qry_reps is None or tgt_reps is None:
return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}
if self.is_ddp:
all_qry_reps = self._dist_gather_tensor(qry_reps)
all_tgt_reps = self._dist_gather_tensor(tgt_reps)
else:
all_qry_reps = qry_reps
all_tgt_reps = tgt_reps
scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
scores = scores.view(all_qry_reps.size(0), -1)
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        # Each query's positive sits every (num_tgt // num_qry) columns when there are
        # multiple candidates per query.
        target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
loss = self.cross_entropy(scores / self.temperature, target)
if self.is_ddp:
loss = loss * self.world_size
return loss
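
    # Worked example (illustrative): with 2 queries and 4 gathered targets
    # (2 candidates per query), `scores` has shape (2, 4) and `target` becomes
    # [0, 2], i.e. each query's positive sits at column i * (num_tgt // num_qry).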
def _dist_gather_tensor(self, t: Tensor):
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        # all_gather returns detached copies; put the local tensor back so this rank's
        # gradients still flow through its own representations.
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors
def compute_similarity(self, q_reps, p_reps):
return torch.matmul(q_reps, p_reps.transpose(0, 1))
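

# Minimal usage sketch (illustrative only; the model name, argument values, and batch
# construction below are assumptions, not this repo's actual training/eval entry points):
#
#     from src.arguments import ModelArguments
#     model_args = ModelArguments(model_name="Qwen/Qwen2-VL-2B-Instruct", pooling="last", normalize=True)
#     model = MMEBModel.load(model_args, is_trainable=False).to("cuda", dtype=torch.bfloat16)
#     out = model(qry=qry_batch, tgt=tgt_batch, return_analysis=True)  # batches come from the matching processor/collator
#     print(summarize_layer_metrics(out["qry_stats"]))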