# src/model/model_layer_prune_add_mlp.py
from typing import Dict
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel

from src.arguments import ModelArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION,
    PHI3V, E5_V, INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI,
    VLM_IMAGE_TOKENS, get_backbone_name, print_master, backbone2model,
)
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration
from transformers import modeling_utils

# Compatibility shim for transformers versions where ALL_PARALLEL_STYLES is unset or None.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class MMEBModel(nn.Module):
TRANSFORMER_CLS = AutoModelForCausalLM
def __init__(self,
encoder: PreTrainedModel,
pooling: str = 'last',
normalize: bool = False,
temperature: float = 0.02,
):
super().__init__()
self.config = encoder.config
self.encoder = encoder
self.pooling = pooling
self.normalize = normalize
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.is_ddp = dist.is_initialized()
if self.is_ddp:
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()
        self.layer_indices = [20, -1]
        self.dual_layer_idx = 20  # intermediate layer used for the auxiliary contrastive view
        self.dual_alpha = 0.05    # weight on the layer-20 CE term; the last-layer term gets 1 - alpha
        # Projection head (D -> D) applied to the layer-20 pooled representation.
        hidden_size = getattr(self.config, "hidden_size", None)
        if hidden_size is None:
            # Some backbones expose the size under a different config field; fall back to a default.
            hidden_size = getattr(self.encoder.config, "hidden_size", 1024)
        self.proj20 = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
        )
        # Two-stage switch: 1 = layer-20 MLP loss only; 2 = layer-20 MLP + last layer, full model trained.
        self.training_stage = 1
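    # Stage-2 objective computed in forward() below (a sketch): with pooled views
    # q20, qL for the query and t20, tL for the target, in-batch positives on the
    # diagonal, and tau = self.temperature,
    #     loss = alpha * CE(q20 @ t20.T / tau) + (1 - alpha) * CE(qL @ tL.T / tau)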
def _is_qwen2_series(self):
return getattr(self, "model_backbone", None) in {
QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION
}
    def _squeeze_mm_inputs(self, inp: Dict) -> Dict:
        # Copy first to avoid mutating the caller's dict in place.
        x = dict(inp)
        if "pixel_values" in x and isinstance(x["pixel_values"], torch.Tensor) and x["pixel_values"].dim() == 5:
            # [B, 1, 3, H, W] -> [B, 3, H, W]
            x["pixel_values"] = x["pixel_values"].squeeze(1)
        if "image_sizes" in x and isinstance(x["image_sizes"], torch.Tensor) and x["image_sizes"].dim() >= 3:
            # [B, 1, 2] or [B, 1, ...] -> squeeze the singleton second dim
            x["image_sizes"] = x["image_sizes"].squeeze(1)
        return x
    # Set the training stage and optionally freeze/unfreeze parameters.
    def set_training_stage(self, stage: int, freeze_encoder: bool = True, verbose: bool = True):
        """
        stage=1: use only the layer-20 (+MLP) loss and train only the MLP (encoder frozen by default).
        stage=2: use the layer-20 (+MLP) and last-layer losses and train the whole model.
        Note: call this before creating the optimizer, or re-create the optimizer after switching stages.
        """
        assert stage in (1, 2), "stage must be 1 or 2"
        self.training_stage = stage
        if freeze_encoder:
            if stage == 1:
                # Freeze everything except the projection head.
                for p in self.parameters():
                    p.requires_grad = False
                for p in self.proj20.parameters():
                    p.requires_grad = True
                if verbose:
                    print("[MMEB] Stage 1: encoder frozen, training proj20 only")
            else:
                # Unfreeze everything.
                for p in self.parameters():
                    p.requires_grad = True
                if verbose:
                    print("[MMEB] Stage 2: training the full model (proj20 + encoder)")
    def _encode_20_raw(self, input):
        """
        Pool the layer-20 hidden states (no projection), reshaping inputs
        where the backbone requires it.
        """
        mb = getattr(self, "model_backbone", None)
        idx20 = int(getattr(self, "dual_layer_idx", 20))
        # 1) Backbones whose image inputs need squeezing: LLaVA-NeXT and the Qwen2 family.
        if mb in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION}:
            inp = self._squeeze_mm_inputs(input)
            out = self.encoder(**inp, return_dict=True, output_hidden_states=True)
            hs = getattr(out, "hidden_states", None)
            if hs is None:
                # Some implementations use a different field name; try it as a fallback.
                if hasattr(out, "text_hidden_states"):
                    hs = out.text_hidden_states
                else:
                    raise RuntimeError("hidden_states is None; ensure output_hidden_states=True and trust_remote_code=True.")
            # hidden_states[0] is the embedding output; entries 1..L are the per-layer outputs.
            L = len(hs) - 1
            layer_idx = idx20 if idx20 < 0 else (max(1, min(idx20, L)) if L >= 1 else -1)
            return self._pooling(hs[layer_idx], inp["attention_mask"])
        # 2) Backbones without accessible intermediate layers: fall back to the last layer.
        if mb in {INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI}:
            return self.encode_input(input)  # [B, D]
        # 3) Default HF-style models that expose hidden_states.
        out = self.encoder(**input, return_dict=True, output_hidden_states=True)
        hs = out.hidden_states
        L = len(hs) - 1
        layer_idx = idx20 if idx20 < 0 else (max(1, min(idx20, L)) if L >= 1 else -1)
        return self._pooling(hs[layer_idx], input["attention_mask"])
    # Layer-20 representation + MLP projection + optional L2 normalization.
    def _encode_20_proj(self, input):
        rep20 = self._encode_20_raw(input)  # [B, D]
        rep20 = self.proj20(rep20)          # [B, D]
        if self.normalize:
            rep20 = F.normalize(rep20, p=2, dim=-1)
        return rep20
    def _encode_query_dual(self, input):
        """
        Return [B, 2, D]: the query's pooled layer-20 (through the MLP) and
        last-layer vectors. Backbones without hidden_states fall back to two
        copies of the last layer (the first copy passed through the MLP).
        """
        mb = getattr(self, "model_backbone", None)

        def norm(x):
            return F.normalize(x, p=2, dim=-1) if self.normalize else x

        # Generic branch for backbones that expose hidden_states (default HF).
        if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI, LLAVA_NEXT] and not self._is_qwen2_series():
            out = self.encoder(**input, return_dict=True, output_hidden_states=True)
            hs = out.hidden_states  # [emb, layer1, ..., layerL]
            idx20 = self.dual_layer_idx
            if idx20 < 0:
                idx20 = len(hs) + idx20
            idx20 = max(1, min(idx20, len(hs) - 1))
            rep20 = self._pooling(hs[idx20], input['attention_mask'])
            replast = self._pooling(hs[-1], input['attention_mask'])
            rep20 = self.proj20(rep20)
            rep20, replast = norm(rep20), norm(replast)
            return torch.stack([rep20, replast], dim=1)  # [B, 2, D]
        # LLaVA-NeXT and the Qwen2 family: squeeze image dims first, then forward.
        if mb == LLAVA_NEXT or self._is_qwen2_series():
            inp = self._squeeze_mm_inputs(input)
            out = self.encoder(**inp, return_dict=True, output_hidden_states=True)
            hs = out.hidden_states
            idx20 = self.dual_layer_idx
            if idx20 < 0:
                idx20 = len(hs) + idx20
            idx20 = max(1, min(idx20, len(hs) - 1))
            rep20 = self._pooling(hs[idx20], inp['attention_mask'])
            replast = self._pooling(hs[-1], inp['attention_mask'])
            rep20 = self.proj20(rep20)
            rep20, replast = norm(rep20), norm(replast)
            return torch.stack([rep20, replast], dim=1)
        # Fallback: two copies of the last layer; the first goes through the MLP
        # so the interface stays consistent.
        last = self.encode_input(input)  # [B, D]
        rep20 = self.proj20(last)
        if self.normalize:
            rep20 = F.normalize(rep20, p=2, dim=-1)
            last = F.normalize(last, p=2, dim=-1)
        return torch.stack([rep20, last], dim=1)
    def _encode_target_dual(self, input):
        """
        Return [B, 2, D]: the candidate's pooled layer-20 (through the MLP) and
        last-layer vectors. Backbones without hidden_states fall back to two
        copies of the last layer (the first copy passed through the MLP).
        """
        mb = getattr(self, "model_backbone", None)

        def norm(x):
            return F.normalize(x, p=2, dim=-1) if self.normalize else x

        if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI, LLAVA_NEXT] and not self._is_qwen2_series():
            out = self.encoder(**input, return_dict=True, output_hidden_states=True)
            hs = out.hidden_states
            idx20 = self.dual_layer_idx
            if idx20 < 0:
                idx20 = len(hs) + idx20
            idx20 = max(1, min(idx20, len(hs) - 1))
            rep20 = self._pooling(hs[idx20], input['attention_mask'])
            replast = self._pooling(hs[-1], input['attention_mask'])
            rep20 = self.proj20(rep20)
            rep20, replast = norm(rep20), norm(replast)
            return torch.stack([rep20, replast], dim=1)
        if mb == LLAVA_NEXT or self._is_qwen2_series():
            inp = self._squeeze_mm_inputs(input)
            out = self.encoder(**inp, return_dict=True, output_hidden_states=True)
            hs = out.hidden_states
            idx20 = self.dual_layer_idx
            if idx20 < 0:
                idx20 = len(hs) + idx20
            idx20 = max(1, min(idx20, len(hs) - 1))
            rep20 = self._pooling(hs[idx20], inp['attention_mask'])
            replast = self._pooling(hs[-1], inp['attention_mask'])
            rep20 = self.proj20(rep20)
            rep20, replast = norm(rep20), norm(replast)
            return torch.stack([rep20, replast], dim=1)
        # Fallback: two copies of the last layer; the first goes through the MLP
        # so the interface stays consistent.
        last = self.encode_input(input)  # [B, D]
        rep20 = self.proj20(last)
        if self.normalize:
            rep20 = F.normalize(rep20, p=2, dim=-1)
            last = F.normalize(last, p=2, dim=-1)
        return torch.stack([rep20, last], dim=1)
    def encode_input(self, input, layer_indices=None):
        mb = getattr(self, "model_backbone", None)
        if mb == INTERNVIDEO2:
            if "input_ids" in input.keys():
                # Text side.
                text_output = self.encoder.get_text_encoder()(
                    input["input_ids"],
                    attention_mask=input["attention_mask"],
                    return_dict=True,
                    mode="text",
                )
                text_embeds = text_output.last_hidden_state
                pooled_text_embeds = text_embeds[:, 0]
                pooled_output = self.encoder.text_proj(pooled_text_embeds)
                pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
                return pooled_output
            else:
                # Vision side.
                _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
                vfeat = self.encoder.vision_proj(vfeat)
                vfeat /= vfeat.norm(dim=-1, keepdim=True)
                return vfeat
        elif mb in [GME, LamRA, LamRA_QWEN2_5]:
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]
            images = []
            for imgs in input['images']:
                # If multiple images (frames) are given, keep only the middle one.
                if isinstance(imgs, list):
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list)
                    images.append(imgs)
                else:
                    images.append(imgs)
            pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
            return pooled_output
        elif mb == COLPALI:
            pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
            return pooled_output
        elif mb == LLAVA_NEXT or self._is_qwen2_series():
            inp = self._squeeze_mm_inputs(input)
            out = self.encoder(**inp, return_dict=True, output_hidden_states=True)
            h_last = out.hidden_states[-1]
            pooled_output = self._pooling(h_last, inp['attention_mask'])
            return pooled_output
        else:
            # Default HF-style models that expose hidden_states.
            out = self.encoder(**input, return_dict=True, output_hidden_states=True)
            hs_list = out.hidden_states
            if layer_indices is None or isinstance(layer_indices, int):
                h = hs_list[-1] if layer_indices is None else hs_list[layer_indices]
                return self._pooling(h, input['attention_mask'])
            else:
                reps_list = []
                for idx in layer_indices:
                    reps_list.append(self._pooling(hs_list[idx], input['attention_mask']))
                return torch.stack(reps_list, dim=1)  # [B, V, D]
    def _pooling(self, last_hidden_state, attention_mask):
        if self.pooling in ('last', 'eos'):
            left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
            batch_size = last_hidden_state.shape[0]
            if left_padding:
                # With left padding, the last position is always a real token.
                reps = last_hidden_state[torch.arange(batch_size), -1, :]
            else:
                # With right padding, find the index of the last non-pad token per row.
                eos_indices = attention_mask.sum(dim=1) - 1
                reps = last_hidden_state[
                    torch.arange(batch_size, device=last_hidden_state.device), eos_indices]
        else:
            raise NotImplementedError
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps
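    # Worked example for the right-padding branch (illustrative only):
    #   attention_mask = [[1, 1, 1, 0],
    #                     [1, 1, 0, 0]]  ->  eos_indices = [2, 1]
    # so each row pools the hidden state of its last real token; with left
    # padding, every row simply pools position -1.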
@classmethod
def build(cls, model_args: ModelArguments, **kwargs):
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
variant = getattr(config, "backbone_variant", None)
if variant == "layerprune":
model_backbone = "QWEN2_VL_LayerPrune"
else:
model_backbone = get_backbone_name(hf_config=config)
print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')
# Loading the base model
if model_backbone == PHI3V:
config._attn_implementation = "eager"
config.padding_side = "right"
config.use_cache = False
base_model = Phi3VForCausalLM.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone == LLAVA_NEXT:
config.use_cache = False
config.padding_side = "left"
base_model = LlavaNextForConditionalGeneration.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL, "QWEN2_VL_LayerPrune"]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
from .utils import parse_layer_type
lm_qwen_layer = 28
vis_qwen_layer = 32
lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
lm_skip_layer=lm_skip_layer,
vis_skip_layer=vis_skip_layer,
)
else:
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name, **kwargs, config=config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
trust_remote_code=True)
        if model_args.lora:
            print_master(f'Adding a LoRA adapter to {model_args.model_name}')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=model_args.lora_target_modules.split(','),
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            encoder = get_peft_model(base_model, lora_config)
        else:
            encoder = base_model
        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )
        return model
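    # A minimal build sketch (hypothetical argument values; the field names are
    # the ones read above from src.arguments.ModelArguments):
    #
    #   args = ModelArguments(model_name="Qwen/Qwen2-VL-2B-Instruct",
    #                         pooling="last", normalize=True, temperature=0.02)
    #   model = MMEBModel.build(args)
    #   model.set_training_stage(1)  # stage 1: train proj20 only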
@classmethod
def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
# Loading the base model
model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config._attn_implementation = "flash_attention_2"
config.vision_config._attn_implementation = "flash_attention_2"
base_model = backbone2model[model_args.model_backbone].from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=config
)
elif model_args.model_backbone == PHI3V:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
config.padding_side = "right"
base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
torch_dtype=torch.bfloat16, trust_remote_code=True)
base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            print_master(f'Loading backbone [{model_args.model_backbone}] from src/model/vlm_backbone/internvideo2/')
            config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/",
                                                trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                "src/model/vlm_backbone/internvideo2/", config=config, trust_remote_code=True)
elif model_args.model_backbone == GME:
base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA:
base_model = LamRAQwen2VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA_QWEN2_5:
base_model = LamRAQwen25VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == COLPALI:
base_model = ColPali.from_pretrained(model_args.model_name)
setattr(base_model, 'config', config)
else:
# Loading external base model from HF
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_name_or_path, **kwargs, config=config,
torch_dtype=torch.bfloat16,
trust_remote_code=True)
# Building the model on top of the base
if model_args.lora:
print_master(f'Loading LoRA from {model_name_or_path}')
lora_config = LoraConfig.from_pretrained(model_name_or_path)
lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
if not is_trainable:
lora_model = lora_model.merge_and_unload()
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
model.model_backbone = model_args.model_backbone
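        # Restore extra (non-encoder) weights, e.g. the proj20 head, if they
        # were saved next to the checkpoint (see save() below).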
try:
ckpt_dir = model_args.checkpoint_path or model_args.model_name
extra_candidates = [
os.path.join(ckpt_dir, "mmeb_extra.pt"),
os.path.join(ckpt_dir, "mmeb_extra.bin"),
os.path.join(ckpt_dir, "extra_heads.pt"),
os.path.join(ckpt_dir, "proj20.pt"),
]
extra_path = next((p for p in extra_candidates if os.path.isfile(p)), None)
if extra_path:
extra_sd = torch.load(extra_path, map_location="cpu")
missing, unexpected = model.load_state_dict(extra_sd, strict=False)
print_master(f"Loaded extra heads from {extra_path}. "
f"missing={len(missing)}, unexpected={len(unexpected)}")
except Exception as e:
print_master(f"[WARN] Failed to load extra heads: {e}")
return model
    def save(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)
        # Also persist the non-encoder weights (the proj20 head), since load()
        # looks for a "mmeb_extra.pt" file next to the checkpoint; without this,
        # a stage-1-trained projection head would be lost on reload.
        extra_sd = {k: v for k, v in self.state_dict().items() if k.startswith("proj20.")}
        if extra_sd:
            torch.save(extra_sd, os.path.join(output_dir, "mmeb_extra.pt"))
    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
        # GradCache path: only one side is given; return representations.
        if qry is not None and tgt is None:
            if self.training_stage == 1:
                # Stage 1: layer-20 MLP only, returns [B, D].
                qry_reps = self._encode_20_proj(qry)
            else:
                # Stage 2: returns [B, 2, D] (layer-20 MLP + last layer).
                qry_reps = self._encode_query_dual(qry)
            return {"qry_reps": qry_reps, "tgt_reps": None}
        if tgt is not None and qry is None:
            if self.training_stage == 1:
                tgt_reps = self._encode_20_proj(tgt)      # [B, D]
            else:
                tgt_reps = self._encode_target_dual(tgt)  # [B, 2, D]
            return {"qry_reps": None, "tgt_reps": tgt_reps}
        # Non-GradCache path: both sides are given; compute the loss directly.
        if qry is not None and tgt is not None:
            if self.training_stage == 1:
                # Stage 1: single-view InfoNCE on the layer-20 MLP representations.
                q = self._encode_20_proj(qry)  # [B, D]
                t = self._encode_20_proj(tgt)  # [B, D]
                if self.is_ddp:
                    q = self._dist_gather_tensor(q)
                    t = self._dist_gather_tensor(t)
                logits = torch.matmul(q, t.transpose(0, 1)) / self.temperature
                target = torch.arange(logits.size(0), device=logits.device, dtype=torch.long)
                loss = self.cross_entropy(logits, target)
                if self.is_ddp:
                    loss = loss * self.world_size
                return loss
            else:
                # Stage 2: matched views (20<->20 and last<->last), weighted sum.
                q = self._encode_query_dual(qry)   # [B, 2, D]
                t = self._encode_target_dual(tgt)  # [B, 2, D]
                if self.is_ddp:
                    q = self._dist_gather_tensor(q)
                    t = self._dist_gather_tensor(t)
                B = q.size(0)
                labels = torch.arange(B, device=q.device, dtype=torch.long)
                alpha = getattr(self, "dual_alpha", 0.2)
                # View 0: layer 20; view 1: last layer.
                loss20 = self.cross_entropy((q[:, 0, :] @ t[:, 0, :].T) / self.temperature, labels)
                lossL = self.cross_entropy((q[:, 1, :] @ t[:, 1, :].T) / self.temperature, labels)
                loss = alpha * loss20 + (1.0 - alpha) * lossL
                if self.is_ddp:
                    loss = loss * self.world_size
                return loss
        return {"qry_reps": None, "tgt_reps": None}
    def _dist_gather_tensor(self, t: Tensor):
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        # all_gather returns tensors detached from the autograd graph; put the
        # local shard back so gradients still flow through this rank's reps.
        all_tensors[self.process_rank] = t
        return torch.cat(all_tensors, dim=0)
def compute_similarity(self, q_reps, p_reps):
return torch.matmul(q_reps, p_reps.transpose(0, 1))
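
# A minimal end-to-end sketch (hypothetical batches; shapes follow forward() above):
#
#   model = MMEBModel.build(model_args)           # model_args: src.arguments.ModelArguments
#   model.set_training_stage(2)                   # dual-view training
#   loss = model(qry=qry_batch, tgt=tgt_batch)    # alpha * CE_20 + (1 - alpha) * CE_last
#
#   # GradCache-style usage: encode one side at a time.
#   reps = model(qry=qry_batch)["qry_reps"]       # [B, 2, D] in stage 2, [B, D] in stage 1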