from typing import Dict import torch, os import torch.distributed as dist from torch import nn, Tensor import torch.nn.functional as F # 如果文件顶部没引入的话 from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig from peft import LoraConfig, get_peft_model, PeftModel from src.model.processor import QWEN2_5_VL_TOKENSELECTION from src.arguments import ModelArguments, TrainingArguments from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, \ backbone2model, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V from src.arguments import ModelArguments from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, INTERNVIDEO2, \ QWEN2_VL_TOKENSELECTION, backbone2model, GME, VLM_IMAGE_TOKENS, LamRA, LamRA_QWEN2_5, COLPALI from src.model.baseline_backbone.colpali import ColPali from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration from transformers import modeling_utils if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None: modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", 'rowwise'] class MMEBModel(nn.Module): TRANSFORMER_CLS = AutoModelForCausalLM def __init__(self, encoder: PreTrainedModel, pooling: str = 'last', normalize: bool = False, temperature: float = 0.02, ): super().__init__() self.config = encoder.config self.encoder = encoder self.pooling = pooling self.normalize = normalize self.temperature = temperature self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') self.is_ddp = dist.is_initialized() if self.is_ddp: self.process_rank = dist.get_rank() self.world_size = dist.get_world_size() self.layer_indices = [20, -1] self.dual_layer_idx = 20 # query 的第20层 self.dual_alpha = 0.05 # 两个 CE 的加权系数 # [新增] 20层投影头(D->D),用于20层对比表示 hidden_size = getattr(self.config, "hidden_size", None) if hidden_size is None: # 某些backbone字段名不同,尽量兜底 hidden_size = getattr(self.encoder.config, "hidden_size", 1024) self.proj20 = nn.Sequential( nn.LayerNorm(hidden_size), nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Dropout(0.1), nn.Linear(hidden_size, hidden_size), ) # [新增] 两阶段训练开关:1=仅20层MLP;2=20层MLP+最后一层,训练全模型 self.training_stage = 1 def _is_qwen2_series(self): return getattr(self, "model_backbone", None) in { QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION } def _squeeze_mm_inputs(self, inp: Dict) -> Dict: # 避免原地修改 x = dict(inp) if "pixel_values" in x and isinstance(x["pixel_values"], torch.Tensor) and x["pixel_values"].dim() == 5: # [B, 1, 3, H, W] -> [B, 3, H, W] x["pixel_values"] = x["pixel_values"].squeeze(1) if "image_sizes" in x and isinstance(x["image_sizes"], torch.Tensor) and x["image_sizes"].dim() >= 3: # [B, 1, 2] or [B, 1, ...] -> squeeze 第2维 x["image_sizes"] = x["image_sizes"].squeeze(1) return x # [新增] 设置训练阶段,并可选冻结/解冻 def set_training_stage(self, stage: int, freeze_encoder: bool = True, verbose: bool = True): """ stage=1: 仅使用20层(加MLP)的loss,并仅训练MLP(默认冻结encoder) stage=2: 使用20层(加MLP)+最后一层的loss,训练整个模型 注意:需在创建优化器之前调用,或在阶段切换后重新创建优化器。 """ assert stage in (1, 2), "stage 只能为 1 或 2" self.training_stage = stage if freeze_encoder: if stage == 1: # 冻结除投影头外的参数 for p in self.parameters(): p.requires_grad = False for p in self.proj20.parameters(): p.requires_grad = True if verbose: print("[MMEB] Stage 1: 冻结 encoder,仅训练 proj20") else: # 解冻全部 for p in self.parameters(): p.requires_grad = True if verbose: print("[MMEB] Stage 2: 训练全模型(含 proj20 和 encoder)") # [新增] 仅取第20层的池化向量(不投影) # def _encode_20_raw(self, input): # mb = getattr(self, "model_backbone", None) # # 支持 hidden_states 的默认分支 # if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI, LLAVA_NEXT]: # out = self.encoder(**input, return_dict=True, output_hidden_states=True) # hs = out.hidden_states # [emb, layer1, ..., layerL] # idx20 = self.dual_layer_idx # if idx20 < 0: # idx20 = len(hs) + idx20 # idx20 = max(1, min(idx20, len(hs) - 1)) # rep20 = self._pooling(hs[idx20], input['attention_mask']) # return rep20 # if mb == LLAVA_NEXT: # input = dict(input) # 避免原地改 # input['pixel_values'] = input['pixel_values'].squeeze(dim=1) # input['image_sizes'] = input['image_sizes'].squeeze(dim=1) # out = self.encoder(**input, return_dict=True, output_hidden_states=True) # hs = out.hidden_states # idx20 = self.dual_layer_idx # if idx20 < 0: # idx20 = len(hs) + idx20 # idx20 = max(1, min(idx20, len(hs) - 1)) # rep20 = self._pooling(hs[idx20], input['attention_mask']) # return rep20 # # 其他backbone不支持中间层:回退(用最后一层代替20层) # last = self.encode_input(input) # [B, D] # return last def _encode_20_raw(self, input): """ 取第20层的池化向量(不投影)。对需要的骨干做必要的输入整形。 """ mb = getattr(self, "model_backbone", None) idx20 = int(getattr(self, "dual_layer_idx", 20)) # 1) 需要特殊处理图像维度的骨干:LLAVA_NEXT + QWEN2 系列 if mb in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION}: inp = dict(input) # 避免原地改 if "pixel_values" in inp and isinstance(inp["pixel_values"], torch.Tensor) and inp["pixel_values"].dim() == 5: # [B, 1, 3, H, W] -> [B, 3, H, W] inp["pixel_values"] = inp["pixel_values"].squeeze(1) if "image_sizes" in inp and isinstance(inp["image_sizes"], torch.Tensor) and inp["image_sizes"].dim() >= 3: # [B, 1, 2] or [B, 1, 2, ...] -> squeeze 第2维 inp["image_sizes"] = inp["image_sizes"].squeeze(1) out = self.encoder(**inp, return_dict=True, output_hidden_states=True) hs = getattr(out, "hidden_states", None) if hs is None: # 某些实现可能有不同字段名,兜底尝试(通常不会走到) if hasattr(out, "text_hidden_states"): hs = out.text_hidden_states else: raise RuntimeError("hidden_states is None; ensure output_hidden_states=True and trust_remote_code=True.") # hidden_states[0] 是 embeddings,1..L 是每层输出 L = len(hs) - 1 if idx20 < 0: layer_idx = idx20 else: layer_idx = max(1, min(idx20, L)) if L >= 1 else -1 rep20 = self._pooling(hs[layer_idx], inp["attention_mask"]) return rep20 # 2) 不支持中间层的骨干:退化为“最后一层” if mb in {INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI}: last = self.encode_input(input) # [B, D] return last # 3) 其他(默认 HF CausalLM 等,支持 hidden_states) out = self.encoder(**input, return_dict=True, output_hidden_states=True) hs = out.hidden_states L = len(hs) - 1 if idx20 < 0: layer_idx = idx20 else: layer_idx = max(1, min(idx20, L)) if L >= 1 else -1 rep20 = self._pooling(hs[layer_idx], input["attention_mask"]) return rep20 # [新增] 20层 + MLP 投影 + 可选归一化 def _encode_20_proj(self, input): rep20 = self._encode_20_raw(input) # [B, D] rep20 = self.proj20(rep20) # [B, D] if self.normalize: rep20 = F.normalize(rep20, p=2, dim=-1) return rep20 # def _encode_query_dual(self, input): # """ # 返回 [B, 2, D]: 第20层与最后一层的池化向量。 # 对不支持 hidden_states 的 backbone,回退为两份相同的最后一层。 # """ # mb = getattr(self, "model_backbone", None) # def norm(x): # return F.normalize(x, p=2, dim=-1) if self.normalize else x # # 支持 hidden_states 的通用分支(默认HF) # if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI, LLAVA_NEXT]: # out = self.encoder(**input, return_dict=True, output_hidden_states=True) # hs = out.hidden_states # [emb, layer1, ..., layerL] # idx20 = self.dual_layer_idx # if idx20 < 0: # idx20 = len(hs) + idx20 # 允许负索引 # idx20 = max(1, min(idx20, len(hs) - 1)) # 1..L # rep20 = self._pooling(hs[idx20], input['attention_mask']) # replast = self._pooling(hs[-1], input['attention_mask']) # rep20 = self.proj20(rep20) # [修改] 20层经过投影头 # rep20, replast = norm(rep20), norm(replast) # reps = torch.stack([rep20, replast], dim=1) # [B, 2, D] # return reps # # [修改] LLAVA_NEXT 分支:支持两层 # if mb == LLAVA_NEXT: # input = dict(input) # 避免原地修改 # input['pixel_values'] = input['pixel_values'].squeeze(dim=1) # input['image_sizes'] = input['image_sizes'].squeeze(dim=1) # out = self.encoder(**input, return_dict=True, output_hidden_states=True) # hs = out.hidden_states # idx20 = self.dual_layer_idx # if idx20 < 0: # idx20 = len(hs) + idx20 # idx20 = max(1, min(idx20, len(hs) - 1)) # rep20 = self._pooling(hs[idx20], input['attention_mask']) # replast = self._pooling(hs[-1], input['attention_mask']) # rep20 = self.proj20(rep20) # [修改] 20层经过投影头 # if self.normalize: # rep20 = F.normalize(rep20, p=2, dim=-1) # replast = F.normalize(replast, p=2, dim=-1) # reps = torch.stack([rep20, replast], dim=1) # return reps # # 回退:两份最后一层,其中第一份过 MLP 以保持接口一致 # last = self.encode_input(input) # [B, D] # rep20 = self.proj20(last) # if self.normalize: # rep20 = F.normalize(rep20, p=2, dim=-1) # last = F.normalize(last, p=2, dim=-1) # return torch.stack([rep20, last], dim=1) def _encode_query_dual(self, input): """ 返回 [B, 2, D]: 第20层(过MLP) 与 最后一层 的池化向量。 对不支持 hidden_states 的 backbone,回退为两份相同的最后一层(第一份过MLP)。 """ mb = getattr(self, "model_backbone", None) def norm(x): return F.normalize(x, p=2, dim=-1) if self.normalize else x # 支持 hidden_states 的通用分支(默认HF) if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI, LLAVA_NEXT] and not self._is_qwen2_series(): out = self.encoder(**input, return_dict=True, output_hidden_states=True) hs = out.hidden_states # [emb, layer1, ..., layerL] idx20 = self.dual_layer_idx if idx20 < 0: idx20 = len(hs) + idx20 idx20 = max(1, min(idx20, len(hs) - 1)) rep20 = self._pooling(hs[idx20], input['attention_mask']) replast = self._pooling(hs[-1], input['attention_mask']) rep20 = self.proj20(rep20) rep20, replast = norm(rep20), norm(replast) return torch.stack([rep20, replast], dim=1) # [B, 2, D] # LLAVA_NEXT + QWEN2 系列:先 squeeze 再前向 if mb == LLAVA_NEXT or self._is_qwen2_series(): inp = self._squeeze_mm_inputs(input) out = self.encoder(**inp, return_dict=True, output_hidden_states=True) hs = out.hidden_states idx20 = self.dual_layer_idx if idx20 < 0: idx20 = len(hs) + idx20 idx20 = max(1, min(idx20, len(hs) - 1)) rep20 = self._pooling(hs[idx20], inp['attention_mask']) replast = self._pooling(hs[-1], inp['attention_mask']) rep20 = self.proj20(rep20) rep20, replast = norm(rep20), norm(replast) return torch.stack([rep20, replast], dim=1) # 回退:两份最后一层,其中第一份过 MLP 以保持接口一致 last = self.encode_input(input) # [B, D] rep20 = self.proj20(last) if self.normalize: rep20 = F.normalize(rep20, p=2, dim=-1) last = F.normalize(last, p=2, dim=-1) return torch.stack([rep20, last], dim=1) # def _encode_target_dual(self, input): # """ # 返回 [B, 2, D]: cand 的第20层与最后一层池化向量。 # 对不支持 hidden_states 的 backbone,回退为两份相同的最后一层。 # """ # mb = getattr(self, "model_backbone", None) # def norm(x): # return F.normalize(x, p=2, dim=-1) if self.normalize else x # if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI, LLAVA_NEXT]: # out = self.encoder(**input, return_dict=True, output_hidden_states=True) # hs = out.hidden_states # idx20 = self.dual_layer_idx # if idx20 < 0: # idx20 = len(hs) + idx20 # idx20 = max(1, min(idx20, len(hs) - 1)) # rep20 = self._pooling(hs[idx20], input['attention_mask']) # replast = self._pooling(hs[-1], input['attention_mask']) # rep20 = self.proj20(rep20) # [修改] # rep20, replast = norm(rep20), norm(replast) # reps = torch.stack([rep20, replast], dim=1) # return reps # if mb == LLAVA_NEXT: # input = dict(input) # input['pixel_values'] = input['pixel_values'].squeeze(dim=1) # input['image_sizes'] = input['image_sizes'].squeeze(dim=1) # out = self.encoder(**input, return_dict=True, output_hidden_states=True) # hs = out.hidden_states # idx20 = self.dual_layer_idx # if idx20 < 0: # idx20 = len(hs) + idx20 # idx20 = max(1, min(idx20, len(hs) - 1)) # rep20 = self._pooling(hs[idx20], input['attention_mask']) # replast = self._pooling(hs[-1], input['attention_mask']) # rep20 = self.proj20(rep20) # [修改] # if self.normalize: # rep20 = F.normalize(rep20, p=2, dim=-1) # replast = F.normalize(replast, p=2, dim=-1) # reps = torch.stack([rep20, replast], dim=1) # return reps # last = self.encode_input(input) # [B, D] # rep20 = self.proj20(last) # if self.normalize: # rep20 = F.normalize(rep20, p=2, dim=-1) # last = F.normalize(last, p=2, dim=-1) # return torch.stack([rep20, last], dim=1) def _encode_target_dual(self, input): """ 返回 [B, 2, D]: cand 的第20层(过MLP) 与 最后一层 池化向量。 对不支持 hidden_states 的 backbone,回退为两份相同的最后一层(第一份过MLP)。 """ mb = getattr(self, "model_backbone", None) def norm(x): return F.normalize(x, p=2, dim=-1) if self.normalize else x if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI, LLAVA_NEXT] and not self._is_qwen2_series(): out = self.encoder(**input, return_dict=True, output_hidden_states=True) hs = out.hidden_states idx20 = self.dual_layer_idx if idx20 < 0: idx20 = len(hs) + idx20 idx20 = max(1, min(idx20, len(hs) - 1)) rep20 = self._pooling(hs[idx20], input['attention_mask']) replast = self._pooling(hs[-1], input['attention_mask']) rep20 = self.proj20(rep20) rep20, replast = norm(rep20), norm(replast) return torch.stack([rep20, replast], dim=1) if mb == LLAVA_NEXT or self._is_qwen2_series(): inp = self._squeeze_mm_inputs(input) out = self.encoder(**inp, return_dict=True, output_hidden_states=True) hs = out.hidden_states idx20 = self.dual_layer_idx if idx20 < 0: idx20 = len(hs) + idx20 idx20 = max(1, min(idx20, len(hs) - 1)) rep20 = self._pooling(hs[idx20], inp['attention_mask']) replast = self._pooling(hs[-1], inp['attention_mask']) rep20 = self.proj20(rep20) rep20, replast = norm(rep20), norm(replast) return torch.stack([rep20, replast], dim=1) last = self.encode_input(input) # [B, D] rep20 = self.proj20(last) if self.normalize: rep20 = F.normalize(rep20, p=2, dim=-1) last = F.normalize(last, p=2, dim=-1) return torch.stack([rep20, last], dim=1) # def encode_input(self, input): # def encode_input(self, input, layer_indices=None): # if getattr(self, "model_backbone", None) == INTERNVIDEO2: # if "input_ids" in input.keys(): # # text side # text_output = self.encoder.get_text_encoder()( # input["input_ids"], # attention_mask=input["attention_mask"], # return_dict=True, # mode="text", # ) # text_embeds = text_output.last_hidden_state # pooled_text_embeds = text_embeds[:, 0] # pooled_output = self.encoder.text_proj(pooled_text_embeds) # pooled_output /= pooled_output.norm(dim=-1, keepdim=True) # return pooled_output # else: # _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True) # vfeat = self.encoder.vision_proj(vfeat) # vfeat /= vfeat.norm(dim=-1, keepdim=True) # return vfeat # elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]: # # pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) # texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] # we are actually passing video queries so this should not happen # images = [] # for imgs in input['images']: # # if multi images are given, select the middle frame only # if isinstance(imgs, list): # imgs = imgs[len(imgs) // 2] # assert not isinstance(imgs, list) # make sure we have extracted the middle frame and it is no longer a list # images.append(imgs) # else: # images.append(imgs) # pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images) # return pooled_output # elif getattr(self, "model_backbone", None) == COLPALI: # pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) # return pooled_output # elif getattr(self, "model_backbone", None) == LLAVA_NEXT: # input['pixel_values'] = input['pixel_values'].squeeze(dim=1) # input['image_sizes'] = input['image_sizes'].squeeze(dim=1) # hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True) # hidden_states = hidden_states.hidden_states[-1] # pooled_output = self._pooling(hidden_states, input['attention_mask']) # return pooled_output # else: # # hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True) # # # hidden_states = self.encoder(**input, compression_rate=compression_rate, return_dict=True, output_hidden_states=True) # # hidden_states = hidden_states.hidden_states[-1] # # pooled_output = self._pooling(hidden_states, input['attention_mask']) # # return pooled_output # # 默认HF模型:支持 hidden_states # out = self.encoder(**input, return_dict=True, output_hidden_states=True) # hs_list = out.hidden_states # if layer_indices is None or isinstance(layer_indices, int): # h = hs_list[-1] if layer_indices is None else hs_list[layer_indices] # reps = self._pooling(h, input['attention_mask']) # return reps # else: # reps_list = [] # for idx in layer_indices: # h = hs_list[idx] # r = self._pooling(h, input['attention_mask']) # reps_list.append(r) # reps = torch.stack(reps_list, dim=1) # [B, L, D] # return reps def encode_input(self, input, layer_indices=None): mb = getattr(self, "model_backbone", None) if mb == INTERNVIDEO2: if "input_ids" in input.keys(): text_output = self.encoder.get_text_encoder()( input["input_ids"], attention_mask=input["attention_mask"], return_dict=True, mode="text", ) text_embeds = text_output.last_hidden_state pooled_text_embeds = text_embeds[:, 0] pooled_output = self.encoder.text_proj(pooled_text_embeds) pooled_output /= pooled_output.norm(dim=-1, keepdim=True) return pooled_output else: _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True) vfeat = self.encoder.vision_proj(vfeat) vfeat /= vfeat.norm(dim=-1, keepdim=True) return vfeat elif mb in [GME, LamRA, LamRA_QWEN2_5]: texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] images = [] for imgs in input['images']: if isinstance(imgs, list): imgs = imgs[len(imgs) // 2] assert not isinstance(imgs, list) images.append(imgs) else: images.append(imgs) pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images) return pooled_output elif mb == COLPALI: pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) return pooled_output elif mb == LLAVA_NEXT or self._is_qwen2_series(): inp = self._squeeze_mm_inputs(input) out = self.encoder(**inp, return_dict=True, output_hidden_states=True) h_last = out.hidden_states[-1] pooled_output = self._pooling(h_last, inp['attention_mask']) return pooled_output else: out = self.encoder(**input, return_dict=True, output_hidden_states=True) hs_list = out.hidden_states if layer_indices is None or isinstance(layer_indices, int): h = hs_list[-1] if layer_indices is None else hs_list[layer_indices] reps = self._pooling(h, input['attention_mask']) return reps else: reps_list = [] for idx in layer_indices: h = hs_list[idx] r = self._pooling(h, input['attention_mask']) reps_list.append(r) return torch.stack(reps_list, dim=1) # [B, V, D] def _pooling(self, last_hidden_state, attention_mask): if self.pooling == 'last' or self.pooling == 'eos': left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) batch_size = last_hidden_state.shape[0] if left_padding: # Get the vectors at the last position reps = last_hidden_state[torch.arange(batch_size), -1, :] else: # Calculate last 1 position in the original tensor eos_indices = attention_mask.sum(dim=1) - 1 # Get the vectors at the last 1 position of each attention mask reps = last_hidden_state[ torch.arange(batch_size, device=last_hidden_state.device), eos_indices] else: raise NotImplementedError if self.normalize: reps = torch.nn.functional.normalize(reps, p=2, dim=-1) return reps @classmethod def build(cls, model_args: ModelArguments, **kwargs): config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) variant = getattr(config, "backbone_variant", None) if variant == "layerprune": model_backbone = "QWEN2_VL_LayerPrune" else: model_backbone = get_backbone_name(hf_config=config) print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}') # Loading the base model if model_backbone == PHI3V: config._attn_implementation = "eager" config.padding_side = "right" config.use_cache = False base_model = Phi3VForCausalLM.from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, ) elif model_backbone == LLAVA_NEXT: config.use_cache = False config.padding_side = "left" base_model = LlavaNextForConditionalGeneration.from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, ) elif model_backbone in [QWEN2_VL, QWEN2_5_VL]: config._attn_implementation = "flash_attention_2" config.padding_side = "left" config.use_cache = False base_model = backbone2model[model_backbone].from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, ) elif model_backbone in ["QWEN2_VL_LayerPrune"]: config._attn_implementation = "flash_attention_2" config.padding_side = "left" config.use_cache = False base_model = backbone2model[model_backbone].from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, ) elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]: config._attn_implementation = "flash_attention_2" config.padding_side = "left" config.use_cache = False from .utils import parse_layer_type lm_qwen_layer = 28 vis_qwen_layer = 32 lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer) vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer) base_model = backbone2model[model_backbone].from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, lm_skip_layer=lm_skip_layer, vis_skip_layer=vis_skip_layer, ) else: config.use_cache = False base_model = cls.TRANSFORMER_CLS.from_pretrained( model_args.model_name, **kwargs, config=config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, trust_remote_code=True) if model_args.lora: print_master(f'Loading lora adapter from {base_model}') lora_config = LoraConfig( r=model_args.lora_r, lora_alpha=model_args.lora_alpha, target_modules=model_args.lora_target_modules.split(','), lora_dropout=model_args.lora_dropout, init_lora_weights="gaussian", use_dora=True, inference_mode=False ) lora_model = get_peft_model(base_model, lora_config) model = cls( encoder=lora_model, pooling=model_args.pooling, normalize=model_args.normalize, temperature=model_args.temperature ) else: model = cls( encoder=base_model, pooling=model_args.pooling, normalize=model_args.normalize, temperature=model_args.temperature ) return model @classmethod def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs): # Loading the base model model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) if not hasattr(model_args, "model_backbone") or not model_args.model_backbone: model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type) setattr(model_args, 'model_backbone', model_backbone) print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}') if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}: config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) config._attn_implementation = "flash_attention_2" config.vision_config._attn_implementation = "flash_attention_2" base_model = backbone2model[model_args.model_backbone].from_pretrained( model_args.model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, config=config ) elif model_args.model_backbone == PHI3V: config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) config.use_cache = False config.padding_side = "right" base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True) base_model.padding_side = "right" elif model_args.model_backbone == INTERNVIDEO2: print_master(f'Loading backbone [{model_args.model_backbone}] from {"src/model/vlm_backbone/internvideo2/"}') config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/", trust_remote_code=True) base_model = backbone2model[model_args.model_backbone].from_pretrained("src/model/vlm_backbone/internvideo2/", config=config, trust_remote_code=True) elif model_args.model_backbone == GME: base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor']) setattr(base_model, 'config', config) elif model_args.model_backbone == LamRA: base_model = LamRAQwen2VL(model_args.model_name) setattr(base_model, 'config', config) elif model_args.model_backbone == LamRA_QWEN2_5: base_model = LamRAQwen25VL(model_args.model_name) setattr(base_model, 'config', config) elif model_args.model_backbone == COLPALI: base_model = ColPali.from_pretrained(model_args.model_name) setattr(base_model, 'config', config) else: # Loading external base model from HF config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) config.use_cache = False base_model = cls.TRANSFORMER_CLS.from_pretrained( model_name_or_path, **kwargs, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True) # Building the model on top of the base if model_args.lora: print_master(f'Loading LoRA from {model_name_or_path}') lora_config = LoraConfig.from_pretrained(model_name_or_path) lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable) lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable) if not is_trainable: lora_model = lora_model.merge_and_unload() model = cls( encoder=lora_model, pooling=model_args.pooling, normalize=model_args.normalize, temperature=model_args.temperature ) else: model = cls( encoder=base_model, pooling=model_args.pooling, normalize=model_args.normalize, temperature=model_args.temperature ) model.model_backbone = model_args.model_backbone try: ckpt_dir = model_args.checkpoint_path or model_args.model_name extra_candidates = [ os.path.join(ckpt_dir, "mmeb_extra.pt"), os.path.join(ckpt_dir, "mmeb_extra.bin"), os.path.join(ckpt_dir, "extra_heads.pt"), os.path.join(ckpt_dir, "proj20.pt"), ] extra_path = next((p for p in extra_candidates if os.path.isfile(p)), None) if extra_path: extra_sd = torch.load(extra_path, map_location="cpu") missing, unexpected = model.load_state_dict(extra_sd, strict=False) print_master(f"Loaded extra heads from {extra_path}. " f"missing={len(missing)}, unexpected={len(unexpected)}") except Exception as e: print_master(f"[WARN] Failed to load extra heads: {e}") return model def save(self, output_dir: str): self.encoder.save_pretrained(output_dir) def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs): # GradCache:只给一侧,返回表示 if qry is not None and tgt is None: if self.training_stage == 1: # [修改] 阶段1:仅20层MLP,返回 [B, D] qry_reps = self._encode_20_proj(qry) else: # [修改] 阶段2:返回 [B, 2, D](20层MLP + 最后一层) qry_reps = self._encode_query_dual(qry) return {"qry_reps": qry_reps, "tgt_reps": None} if tgt is not None and qry is None: if self.training_stage == 1: tgt_reps = self._encode_20_proj(tgt) # [B, D] else: tgt_reps = self._encode_target_dual(tgt) # [B, 2, D] return {"qry_reps": None, "tgt_reps": tgt_reps} # 非 GradCache:两侧同时给,直接算损失 if qry is not None and tgt is not None: if self.training_stage == 1: # [修改] 階段1:仅20层MLP的单视角 InfoNCE q = self._encode_20_proj(qry) # [B, D] t = self._encode_20_proj(tgt) # [B, D] if self.is_ddp: q = self._dist_gather_tensor(q) t = self._dist_gather_tensor(t) logits = torch.matmul(q, t.transpose(0, 1)) / self.temperature target = torch.arange(logits.size(0), device=logits.device, dtype=torch.long) loss = self.cross_entropy(logits, target) if self.is_ddp: loss = loss * self.world_size return loss else: # [修改] 階段2:匹配视角(20↔20 + last↔last),加权求和 q = self._encode_query_dual(qry) # [B, 2, D] t = self._encode_target_dual(tgt) # [B, 2, D] if self.is_ddp: q = self._dist_gather_tensor(q) t = self._dist_gather_tensor(t) B = q.size(0) labels = torch.arange(B, device=q.device, dtype=torch.long) alpha = getattr(self, "dual_alpha", 0.2) # v=0: 20层;v=1: 最后一层 loss20 = self.cross_entropy((q[:, 0, :] @ t[:, 0, :].T) / self.temperature, labels) lossL = self.cross_entropy((q[:, 1, :] @ t[:, 1, :].T) / self.temperature, labels) loss = alpha * loss20 + (1.0 - alpha) * lossL if self.is_ddp: loss = loss * self.world_size return loss return {"qry_reps": None, "tgt_reps": None} def _dist_gather_tensor(self, t: Tensor): t = t.contiguous() all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] dist.all_gather(all_tensors, t) all_tensors[self.process_rank] = t all_tensors = torch.cat(all_tensors, dim=0) return all_tensors def compute_similarity(self, q_reps, p_reps): return torch.matmul(q_reps, p_reps.transpose(0, 1))