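"""Embedding wrapper around multimodal LLM backbones.

MMEBModel selects and loads a vision-language backbone (Phi-3-V, LLaVA-Next,
Qwen2-VL / Qwen2.5-VL with optional TokenPooling, VisionZip, or TokenSelection
vision compression, plus several baseline encoders), pools the final hidden
states into one embedding per input, and trains with an in-batch contrastive
loss, optionally with LoRA adapters and cross-device gathering under DDP.
"""
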
from typing import Dict

import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel

from src.arguments import ModelArguments, TrainingArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, PHI3V, E5_V, INTERNVIDEO2, GME, COLPALI,
    LamRA, LamRA_QWEN2_5, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION,
    VLM_IMAGE_TOKENS, backbone2model, get_backbone_name, print_master,
)

# Import both the TokenPooling and the VisionZip variants of the Qwen2-VL / Qwen2.5-VL backbones.
from src.model.vlm_backbone.qwen2_vl_token_pooling.modeling_qwen2_vl import (
    Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationTokenPooling,
)
from src.model.vlm_backbone.qwen2_vl_visionzip.modeling_qwen2_vl import (
    Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationVisionZip,
)
from src.model.vlm_backbone.qwen2_5_vl_token_pooling.modeling_qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration as Qwen2_5VLForConditionalGenerationTokenPooling,
)
from src.model.vlm_backbone.qwen2_5_vl_visionzip.modeling_qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration as Qwen2_5VLForConditionalGenerationVisionZip,
)

from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration

# Compatibility shim: some transformers versions do not define ALL_PARALLEL_STYLES.
from transformers import modeling_utils
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]


def _ensure_pad_token_id_on_model(base_model):
    """
    Ensure base_model.config.pad_token_id is a valid int.
    Fallback order: config.pad_token_id -> config.eos_token_id -> 0.
    Also sync generation_config.pad_token_id if present.
""" pad_id = getattr(base_model.config, "pad_token_id", None) if pad_id is None: pad_id = getattr(base_model.config, "eos_token_id", None) if pad_id is None: pad_id = 0 base_model.config.pad_token_id = pad_id gen_cfg = getattr(base_model, "generation_config", None) if gen_cfg is not None and getattr(gen_cfg, "pad_token_id", None) is None: gen_cfg.pad_token_id = base_model.config.pad_token_id class MMEBModel(nn.Module): TRANSFORMER_CLS = AutoModelForCausalLM def __init__(self, encoder: PreTrainedModel, pooling: str = 'last', normalize: bool = False, temperature: float = 0.02, ): super().__init__() self.config = encoder.config self.encoder = encoder self.pooling = pooling self.normalize = normalize self.temperature = temperature self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') self.is_ddp = dist.is_initialized() if self.is_ddp: self.process_rank = dist.get_rank() self.world_size = dist.get_world_size() @property def device(self): try: return next(self.parameters()).device except StopIteration: return torch.device("cuda" if torch.cuda.is_available() else "cpu") def encode_input(self, input): if getattr(self, "model_backbone", None) == INTERNVIDEO2: if "input_ids" in input.keys(): # text side text_output = self.encoder.get_text_encoder()( input["input_ids"], attention_mask=input["attention_mask"], return_dict=True, mode="text", ) text_embeds = text_output.last_hidden_state pooled_text_embeds = text_embeds[:, 0] pooled_output = self.encoder.text_proj(pooled_text_embeds) pooled_output /= pooled_output.norm(dim=-1, keepdim=True) return pooled_output else: _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True) vfeat = self.encoder.vision_proj(vfeat) vfeat /= vfeat.norm(dim=-1, keepdim=True) return vfeat elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]: # pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] # we are actually passing video queries so this should not happen images = [] for imgs in input['images']: # if multi images are given, select the middle frame only if isinstance(imgs, list): imgs = imgs[len(imgs) // 2] assert not isinstance(imgs, list) # make sure we have extracted the middle frame and it is no longer a list images.append(imgs) else: images.append(imgs) pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images) return pooled_output elif getattr(self, "model_backbone", None) == COLPALI: pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) return pooled_output elif getattr(self, "model_backbone", None) == LLAVA_NEXT: input['pixel_values'] = input['pixel_values'].squeeze(dim=1) input['image_sizes'] = input['image_sizes'].squeeze(dim=1) hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True) hidden_states = hidden_states.hidden_states[-1] pooled_output = self._pooling(hidden_states, input['attention_mask']) return pooled_output else: outputs = self.encoder(**input, return_dict=True, output_hidden_states=True) last_hidden = outputs.hidden_states[-1] # [B, L', D](VisionZip 后 L' 可能变短) # 优先使用模型 forward 返回的 post attention_mask;没有则回退到输入 mask post_mask = getattr(outputs, "attention_mask", None) src_mask = input.get("attention_mask", None) use_mask = post_mask if (post_mask is not None) else src_mask pooled_output = self._pooling(last_hidden, use_mask) return pooled_output def _pooling(self, last_hidden_state, attention_mask): """ 健壮的 eos pooling: 
        - If attention_mask is missing or its length does not match last_hidden_state, fall back to
          the last position of each sample (valid under the default left padding).
        - Otherwise take mask.sum(dim=1) - 1 as the last valid position, clamped to stay in bounds.
        """
        if self.pooling in ('last', 'eos'):
            B, L, D = last_hidden_state.shape
            device = last_hidden_state.device
            # Fallback: no mask, or mask length does not match the hidden states.
            if (attention_mask is None) or (attention_mask.shape[1] != L):
                reps = last_hidden_state[:, -1, :]
            else:
                # Compute each row's valid length (>= 1) and convert it to a valid index in [0, L-1].
                # Note: attention_mask may be float/bfloat16, so cast to long before summing.
                valid_len = attention_mask.to(torch.long).sum(dim=1)   # [B]
                eos_idx = (valid_len - 1).clamp(min=0, max=L - 1)      # [B]
                reps = last_hidden_state[torch.arange(B, device=device), eos_idx, :]
        else:
            raise NotImplementedError
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    @classmethod
    def build(cls, model_args: ModelArguments, **kwargs):
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        model_backbone = get_backbone_name(hf_config=config)
        print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')
        base_model = None  # <-- ensure defined before branches

        # Loading the base model
        if model_backbone == PHI3V:
            config._attn_implementation = "eager"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone == LLAVA_NEXT:
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            mode = getattr(model_args, "vision_compression", "token_pooling")
            # ========= Qwen2-VL =========
            if model_backbone == QWEN2_VL:
                if mode == "token_pooling":
                    BaseCls = Qwen2VLForConditionalGenerationTokenPooling
                    print_master("[VisionCompression] Qwen2-VL using TokenPooling modeling")
                elif mode == "visionzip":
                    BaseCls = Qwen2VLForConditionalGenerationVisionZip
                    print_master("[VisionCompression] Qwen2-VL using VisionZip modeling")
                else:  # "none" or unknown mode
                    BaseCls = backbone2model[model_backbone]
                    print_master(f"[VisionCompression] Qwen2-VL using vanilla backbone (mode={mode})")
            # ========= Qwen2.5-VL =========
            elif model_backbone == QWEN2_5_VL:
                if mode == "token_pooling":
                    BaseCls = Qwen2_5VLForConditionalGenerationTokenPooling
                    print_master("[VisionCompression] Qwen2.5-VL using TokenPooling modeling")
                elif mode == "visionzip":
                    BaseCls = Qwen2_5VLForConditionalGenerationVisionZip
                    print_master("[VisionCompression] Qwen2.5-VL using VisionZip modeling")
                else:
                    BaseCls = backbone2model[model_backbone]
                    print_master(f"[VisionCompression] Qwen2.5-VL using vanilla backbone (mode={mode})")
            # =============================
            base_model = BaseCls.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            from .utils import parse_layer_type
            lm_qwen_layer = 28
            vis_qwen_layer = 32
            lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
            vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                lm_skip_layer=lm_skip_layer,
                vis_skip_layer=vis_skip_layer,
            )
        else:
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )

        # <-- call after base_model is assigned
        _ensure_pad_token_id_on_model(base_model)

        # Build MMEBModel
        if model_args.lora:
            print_master('Adding LoRA adapters to the base model')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=model_args.lora_target_modules.split(','),
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False,
            )
            lora_model = get_peft_model(base_model, lora_config)
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature,
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature,
            )
        return model

    @classmethod
    def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
        # Loading the base model
        model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
            model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
            setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
        base_model = None  # <-- ensure defined before branches

        if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION,
                                         QWEN2_5_VL_TOKENSELECTION, E5_V}:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config._attn_implementation = "flash_attention_2"
            if hasattr(config, "vision_config") and config.vision_config is not None:
                config.vision_config._attn_implementation = "flash_attention_2"
            mode = getattr(model_args, "vision_compression", "token_pooling")
            # ========= Qwen2-VL =========
            if model_args.model_backbone == QWEN2_VL:
                if mode == "token_pooling":
                    BaseCls = Qwen2VLForConditionalGenerationTokenPooling
                    print_master("[VisionCompression-load] Qwen2-VL using TokenPooling modeling")
                elif mode == "visionzip":
                    BaseCls = Qwen2VLForConditionalGenerationVisionZip
                    print_master("[VisionCompression-load] Qwen2-VL using VisionZip modeling")
                else:
                    BaseCls = backbone2model[model_args.model_backbone]
                    print_master(f"[VisionCompression-load] Qwen2-VL using vanilla backbone (mode={mode})")
            # ========= Qwen2.5-VL =========
            elif model_args.model_backbone == QWEN2_5_VL:
                if mode == "token_pooling":
                    BaseCls = Qwen2_5VLForConditionalGenerationTokenPooling
                    print_master("[VisionCompression-load] Qwen2.5-VL using TokenPooling modeling")
                elif mode == "visionzip":
                    BaseCls = Qwen2_5VLForConditionalGenerationVisionZip
                    print_master("[VisionCompression-load] Qwen2.5-VL using VisionZip modeling")
                else:
                    BaseCls = backbone2model[model_args.model_backbone]
                    print_master(f"[VisionCompression-load] Qwen2.5-VL using vanilla backbone (mode={mode})")
            # Other backbones fall back to the standard mapping.
            else:
                BaseCls = backbone2model[model_args.model_backbone]
            base_model = BaseCls.from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                config=config,
            )
        elif model_args.model_backbone == PHI3V:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                torch_dtype=torch.bfloat16, trust_remote_code=True,
            )
            base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            print_master(f'Loading backbone [{model_args.model_backbone}] from src/model/vlm_backbone/internvideo2/')
            config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/", trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                "src/model/vlm_backbone/internvideo2/", config=config, trust_remote_code=True
            )
        elif model_args.model_backbone == GME:
            base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA:
            base_model = LamRAQwen2VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA_QWEN2_5:
            base_model = LamRAQwen25VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == COLPALI:
            base_model = ColPali.from_pretrained(model_args.model_name)
            setattr(base_model, 'config', config)
        else:
            # Loading an external base model from HF
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_name_or_path, **kwargs, config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )

        # <-- call after base_model is assigned
        _ensure_pad_token_id_on_model(base_model)

        # Building the model on top of the base
        if model_args.lora:
            print_master(f'Loading LoRA from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)
            lora_model = PeftModel.from_pretrained(
                base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable
            )
            lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
            if not is_trainable:
                lora_model = lora_model.merge_and_unload()
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature,
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature,
            )
        model.model_backbone = model_args.model_backbone
        return model

    def save(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)

    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
        qry_reps = self.encode_input(qry) if qry else None  # (bsz_per_device, dim)
        tgt_reps = self.encode_input(tgt) if tgt else None  # (bsz_per_device, dim)
        if qry_reps is None or tgt_reps is None:
            return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}

        if self.is_ddp:
            all_qry_reps = self._dist_gather_tensor(qry_reps)
            all_tgt_reps = self._dist_gather_tensor(tgt_reps)
        else:
            all_qry_reps = qry_reps
            all_tgt_reps = tgt_reps

        scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
        scores = scores.view(all_qry_reps.size(0), -1)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (all_qry_reps.size(0) // all_tgt_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
        if self.is_ddp:
            loss = loss * self.world_size
        return loss

    def _dist_gather_tensor(self, t: Tensor):
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        return torch.matmul(q_reps, p_reps.transpose(0, 1))
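

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of the math inside MMEBModel._pooling() and
# MMEBModel.forward(), runnable without any backbone, processor, or dataset.
# The batch size, embedding dim, and toy attention mask below are illustrative
# values, not part of the training code; only the 0.02 temperature mirrors the
# class default above.
if __name__ == "__main__":
    torch.manual_seed(0)

    # EOS-index selection as in _pooling: last valid position per row of a padded mask.
    attn = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.long)  # right-padded toy mask
    eos_idx = (attn.sum(dim=1) - 1).clamp(min=0, max=attn.shape[1] - 1)
    print(f"eos indices: {eos_idx.tolist()}")  # -> [2, 1]

    # In-batch contrastive objective as in forward(): each query's positive target sits on the diagonal.
    bsz, dim, temperature = 4, 8, 0.02
    qry_reps = torch.nn.functional.normalize(torch.randn(bsz, dim), dim=-1)
    tgt_reps = torch.nn.functional.normalize(torch.randn(bsz, dim), dim=-1)
    scores = torch.matmul(qry_reps, tgt_reps.transpose(0, 1))  # [bsz, bsz] similarity matrix
    target = torch.arange(scores.size(0), dtype=torch.long)    # diagonal targets
    loss = torch.nn.functional.cross_entropy(scores / temperature, target)
    print(f"demo contrastive loss: {loss.item():.4f}")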