# code_SAS_VLM2Vec/src/model/model_vision_compression.py
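"""MMEBModel with optional vision-token compression for Qwen2-VL / Qwen2.5-VL backbones.

The Qwen2-VL / Qwen2.5-VL branches pick a modeling variant (TokenPooling, VisionZip, or the
vanilla backbone) based on ``model_args.vision_compression``; other backbones are loaded
through the existing ``backbone2model`` mapping or their dedicated wrappers. The class
exposes ``build`` / ``load`` constructors, EOS pooling, and an in-batch contrastive ``forward``.
"""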
from typing import Dict
import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel
from src.arguments import ModelArguments, TrainingArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION,
    PHI3V, E5_V, INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI,
    VLM_IMAGE_TOKENS, get_backbone_name, print_master, backbone2model,
)
# Newly added: import the TokenPooling and VisionZip modeling variants separately
from src.model.vlm_backbone.qwen2_vl_token_pooling.modeling_qwen2_vl import (
Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationTokenPooling,
)
from src.model.vlm_backbone.qwen2_vl_visionzip.modeling_qwen2_vl import (
Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationVisionZip,
)
from src.model.vlm_backbone.qwen2_5_vl_token_pooling.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration as Qwen2_5VLForConditionalGenerationTokenPooling,
)
from src.model.vlm_backbone.qwen2_5_vl_visionzip.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration as Qwen2_5VLForConditionalGenerationVisionZip,
)
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration
from transformers import modeling_utils
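# Compatibility shim (assumption based on the guard below): some transformers releases leave
# modeling_utils.ALL_PARALLEL_STYLES unset or None, which breaks tensor-parallel plan validation
# inside from_pretrained, so patch in a reasonable default before any model is loaded.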
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
def _ensure_pad_token_id_on_model(base_model):
"""
Ensure base_model.config.pad_token_id is a valid int.
Fallback order: config.pad_token_id -> config.eos_token_id -> 0
Also sync generation_config.pad_token_id if present.
"""
pad_id = getattr(base_model.config, "pad_token_id", None)
if pad_id is None:
pad_id = getattr(base_model.config, "eos_token_id", None)
if pad_id is None:
pad_id = 0
base_model.config.pad_token_id = pad_id
gen_cfg = getattr(base_model, "generation_config", None)
if gen_cfg is not None and getattr(gen_cfg, "pad_token_id", None) is None:
gen_cfg.pad_token_id = base_model.config.pad_token_id
class MMEBModel(nn.Module):
TRANSFORMER_CLS = AutoModelForCausalLM
def __init__(self,
encoder: PreTrainedModel,
pooling: str = 'last',
normalize: bool = False,
temperature: float = 0.02,
):
super().__init__()
self.config = encoder.config
self.encoder = encoder
self.pooling = pooling
self.normalize = normalize
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.is_ddp = dist.is_initialized()
if self.is_ddp:
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()
@property
def device(self):
try:
return next(self.parameters()).device
except StopIteration:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def encode_input(self, input):
if getattr(self, "model_backbone", None) == INTERNVIDEO2:
if "input_ids" in input.keys():
# text side
text_output = self.encoder.get_text_encoder()(
input["input_ids"],
attention_mask=input["attention_mask"],
return_dict=True,
mode="text",
)
text_embeds = text_output.last_hidden_state
pooled_text_embeds = text_embeds[:, 0]
pooled_output = self.encoder.text_proj(pooled_text_embeds)
pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
return pooled_output
else:
_, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
vfeat = self.encoder.vision_proj(vfeat)
vfeat /= vfeat.norm(dim=-1, keepdim=True)
return vfeat
elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]:
# pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] # we are actually passing video queries so this should not happen
images = []
for imgs in input['images']:
# if multi images are given, select the middle frame only
if isinstance(imgs, list):
imgs = imgs[len(imgs) // 2]
assert not isinstance(imgs, list) # make sure we have extracted the middle frame and it is no longer a list
images.append(imgs)
else:
images.append(imgs)
pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
return pooled_output
elif getattr(self, "model_backbone", None) == COLPALI:
pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
return pooled_output
elif getattr(self, "model_backbone", None) == LLAVA_NEXT:
input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
hidden_states = hidden_states.hidden_states[-1]
pooled_output = self._pooling(hidden_states, input['attention_mask'])
return pooled_output
else:
outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]  # [B, L', D] (L' may shrink after VisionZip)
            # Prefer the post-compression attention_mask returned by the model's forward pass; otherwise fall back to the input mask
post_mask = getattr(outputs, "attention_mask", None)
src_mask = input.get("attention_mask", None)
use_mask = post_mask if (post_mask is not None) else src_mask
pooled_output = self._pooling(last_hidden, use_mask)
return pooled_output
def _pooling(self, last_hidden_state, attention_mask):
"""
健壮的 eos pooling:
- 若 attention_mask 为空或长度与 last_hidden_state 不一致,则回退到每样本最后一位(左 padding 默认成立)
- 正常情况下用 mask.sum(dim=1)-1 取有效最后位,并做 clamp 防越界
"""
if self.pooling in ('last', 'eos'):
B, L, D = last_hidden_state.shape
device = last_hidden_state.device
            # Fallback: no mask, or mask length does not match the hidden states
if (attention_mask is None) or (attention_mask.shape[1] != L):
reps = last_hidden_state[:, -1, :]
else:
                # Compute each row's valid length (>= 1) and turn it into an index in [0, L-1]
                # Note: attention_mask may be float/bfloat16, so cast to long before summing
valid_len = attention_mask.to(torch.long).sum(dim=1) # [B]
eos_idx = (valid_len - 1).clamp(min=0, max=L - 1) # [B]
reps = last_hidden_state[torch.arange(B, device=device), eos_idx, :]
else:
raise NotImplementedError
if self.normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
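    # Toy trace of the pooling above (hypothetical shapes, not executed):
    #   last_hidden_state: [2, 4, D], attention_mask = [[1, 1, 0, 0], [1, 1, 1, 0]] (right padding)
    #   valid_len = [2, 3] -> eos_idx = [1, 2] -> reps takes position 1 of sample 0 and position 2 of sample 1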
@classmethod
def build(cls, model_args: ModelArguments, **kwargs):
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
model_backbone = get_backbone_name(hf_config=config)
print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')
base_model = None # <-- ensure defined before branches
# Loading the base model
if model_backbone == PHI3V:
config._attn_implementation = "eager"
config.padding_side = "right"
config.use_cache = False
base_model = Phi3VForCausalLM.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone == LLAVA_NEXT:
config.use_cache = False
config.padding_side = "left"
base_model = LlavaNextForConditionalGeneration.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
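            # "vision_compression" selects the modeling variant: "token_pooling" (default),
            # "visionzip", or any other value to fall back to the vanilla backbone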
mode = getattr(model_args, "vision_compression", "token_pooling")
# ========= Qwen2-VL =========
if model_backbone == QWEN2_VL:
if mode == "token_pooling":
BaseCls = Qwen2VLForConditionalGenerationTokenPooling
print_master("[VisionCompression] Qwen2-VL using TokenPooling modeling")
elif mode == "visionzip":
BaseCls = Qwen2VLForConditionalGenerationVisionZip
print_master("[VisionCompression] Qwen2-VL using VisionZip modeling")
                else:  # "none" or unknown
BaseCls = backbone2model[model_backbone]
print_master(f"[VisionCompression] Qwen2-VL using vanilla backbone (mode={mode})")
# ========= Qwen2.5-VL =========
elif model_backbone == QWEN2_5_VL:
if mode == "token_pooling":
BaseCls = Qwen2_5VLForConditionalGenerationTokenPooling
print_master("[VisionCompression] Qwen2.5-VL using TokenPooling modeling")
elif mode == "visionzip":
BaseCls = Qwen2_5VLForConditionalGenerationVisionZip
print_master("[VisionCompression] Qwen2.5-VL using VisionZip modeling")
else:
BaseCls = backbone2model[model_backbone]
print_master(f"[VisionCompression] Qwen2.5-VL using vanilla backbone (mode={mode})")
# =============================
base_model = BaseCls.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
from .utils import parse_layer_type
lm_qwen_layer = 28
vis_qwen_layer = 32
lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
lm_skip_layer=lm_skip_layer,
vis_skip_layer=vis_skip_layer,
)
else:
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name, **kwargs, config=config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
# <-- call after base_model is assigned
_ensure_pad_token_id_on_model(base_model)
# Build MMEBModel
if model_args.lora:
            print_master('Initializing a new LoRA adapter on the base model')
lora_config = LoraConfig(
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
target_modules=model_args.lora_target_modules.split(','),
lora_dropout=model_args.lora_dropout,
init_lora_weights="gaussian",
use_dora=True,
inference_mode=False
)
lora_model = get_peft_model(base_model, lora_config)
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
return model
@classmethod
def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
# Loading the base model
model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
base_model = None # <-- ensure defined before branches
if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config._attn_implementation = "flash_attention_2"
if hasattr(config, "vision_config") and config.vision_config is not None:
config.vision_config._attn_implementation = "flash_attention_2"
mode = getattr(model_args, "vision_compression", "token_pooling")
# ========= Qwen2-VL =========
if model_args.model_backbone == QWEN2_VL:
if mode == "token_pooling":
BaseCls = Qwen2VLForConditionalGenerationTokenPooling
print_master("[VisionCompression-load] Qwen2-VL using TokenPooling modeling")
elif mode == "visionzip":
BaseCls = Qwen2VLForConditionalGenerationVisionZip
print_master("[VisionCompression-load] Qwen2-VL using VisionZip modeling")
else:
BaseCls = backbone2model[model_args.model_backbone]
print_master(f"[VisionCompression-load] Qwen2-VL using vanilla backbone (mode={mode})")
# ========= Qwen2.5-VL =========
elif model_args.model_backbone == QWEN2_5_VL:
if mode == "token_pooling":
BaseCls = Qwen2_5VLForConditionalGenerationTokenPooling
print_master("[VisionCompression-load] Qwen2.5-VL using TokenPooling modeling")
elif mode == "visionzip":
BaseCls = Qwen2_5VLForConditionalGenerationVisionZip
print_master("[VisionCompression-load] Qwen2.5-VL using VisionZip modeling")
else:
BaseCls = backbone2model[model_args.model_backbone]
print_master(f"[VisionCompression-load] Qwen2.5-VL using vanilla backbone (mode={mode})")
            # Other backbones fall back to the original backbone2model mapping
else:
BaseCls = backbone2model[model_args.model_backbone]
base_model = BaseCls.from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=config,
)
elif model_args.model_backbone == PHI3V:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
config.padding_side = "right"
base_model = Phi3VForCausalLM.from_pretrained(
model_args.model_name, **kwargs, config=config,
torch_dtype=torch.bfloat16, trust_remote_code=True
)
base_model.padding_side = "right"
elif model_args.model_backbone == INTERNVIDEO2:
print_master(f'Loading backbone [{model_args.model_backbone}] from {"src/model/vlm_backbone/internvideo2/"}')
config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/", trust_remote_code=True)
base_model = backbone2model[model_args.model_backbone].from_pretrained(
"src/model/vlm_backbone/internvideo2/", config=config, trust_remote_code=True
)
elif model_args.model_backbone == GME:
base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA:
base_model = LamRAQwen2VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA_QWEN2_5:
base_model = LamRAQwen25VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == COLPALI:
base_model = ColPali.from_pretrained(model_args.model_name)
setattr(base_model, 'config', config)
else:
# Loading external base model from HF
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_name_or_path, **kwargs, config=config,
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
# <-- call after base_model is assigned
_ensure_pad_token_id_on_model(base_model)
# Building the model on top of the base
if model_args.lora:
print_master(f'Loading LoRA from {model_name_or_path}')
lora_config = LoraConfig.from_pretrained(model_name_or_path)
lora_model = PeftModel.from_pretrained(
base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable
)
lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
if not is_trainable:
lora_model = lora_model.merge_and_unload()
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
model.model_backbone = model_args.model_backbone
return model
def save(self, output_dir: str):
self.encoder.save_pretrained(output_dir)
def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
qry_reps = self.encode_input(qry) if qry else None # (bsz_per_device, dim)
tgt_reps = self.encode_input(tgt) if tgt else None # (bsz_per_device, dim)
if qry_reps is None or tgt_reps is None:
return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}
if self.is_ddp:
all_qry_reps = self._dist_gather_tensor(qry_reps)
all_tgt_reps = self._dist_gather_tensor(tgt_reps)
else:
all_qry_reps = qry_reps
all_tgt_reps = tgt_reps
scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
scores = scores.view(all_qry_reps.size(0), -1)
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        # Each query's positive sits at index i * (num_targets // num_queries) in the gathered target matrix
        target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
loss = self.cross_entropy(scores / self.temperature, target)
if self.is_ddp:
loss = loss * self.world_size
return loss
def _dist_gather_tensor(self, t: Tensor):
t = t.contiguous()
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(all_tensors, t)
all_tensors[self.process_rank] = t
all_tensors = torch.cat(all_tensors, dim=0)
return all_tensors
def compute_similarity(self, q_reps, p_reps):
return torch.matmul(q_reps, p_reps.transpose(0, 1))
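# Minimal usage sketch (hypothetical argument values; the real entrypoints live in the training
# and evaluation scripts):
#   model_args = ModelArguments(model_name="Qwen/Qwen2-VL-2B-Instruct", pooling="last",
#                               normalize=True, temperature=0.02, lora=False)
#   model = MMEBModel.build(model_args)                       # fresh model for training
#   # model = MMEBModel.load(model_args, is_trainable=False)  # load a finetuned checkpoint
# Note: `load` sets `model.model_backbone` (used by encode_input dispatch), while `build` does not,
# so callers relying on backbone-specific encoding should set it explicitly after `build`.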