from typing import Dict

import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel

from src.arguments import ModelArguments, TrainingArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, PHI3V, E5_V, INTERNVIDEO2,
    GME, LamRA, LamRA_QWEN2_5, COLPALI, VLM_IMAGE_TOKENS,
    QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION,
    backbone2model, get_backbone_name, print_master,
)
# Import the TokenPooling and VisionZip variants of the Qwen2-VL / Qwen2.5-VL backbones separately.
from src.model.vlm_backbone.qwen2_vl_token_pooling.modeling_qwen2_vl import (
    Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationTokenPooling,
)
from src.model.vlm_backbone.qwen2_vl_visionzip.modeling_qwen2_vl import (
    Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationVisionZip,
)
from src.model.vlm_backbone.qwen2_5_vl_token_pooling.modeling_qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration as Qwen2_5VLForConditionalGenerationTokenPooling,
)
from src.model.vlm_backbone.qwen2_5_vl_visionzip.modeling_qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration as Qwen2_5VLForConditionalGenerationVisionZip,
)
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration
from transformers import modeling_utils
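
# Some transformers releases ship modeling_utils.ALL_PARALLEL_STYLES as None (or omit it),
# which can break loading backbones that declare a tensor-parallel plan; patch a sane default.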
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]


def _ensure_pad_token_id_on_model(base_model):
    """
    Ensure base_model.config.pad_token_id is a valid int.
    Fallback order: config.pad_token_id -> config.eos_token_id -> 0.
    Also sync generation_config.pad_token_id if present.
    """
    pad_id = getattr(base_model.config, "pad_token_id", None)
    if pad_id is None:
        pad_id = getattr(base_model.config, "eos_token_id", None)
    if pad_id is None:
        pad_id = 0
    base_model.config.pad_token_id = pad_id
    gen_cfg = getattr(base_model, "generation_config", None)
    if gen_cfg is not None and getattr(gen_cfg, "pad_token_id", None) is None:
        gen_cfg.pad_token_id = base_model.config.pad_token_id


class MMEBModel(nn.Module):
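    """
    Bi-encoder wrapper around a (multimodal) backbone for embedding training and inference.
    Queries and targets are encoded by the same backbone, pooled at the last/EOS position,
    optionally L2-normalized, and trained with an in-batch contrastive (InfoNCE-style) loss.
    """
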
    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'last',
                 normalize: bool = False,
                 temperature: float = 0.02,
                 ):
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    @property
    def device(self):
        try:
            return next(self.parameters()).device
        except StopIteration:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def encode_input(self, input):
        if getattr(self, "model_backbone", None) == INTERNVIDEO2:
            if "input_ids" in input.keys():
                # text side
                text_output = self.encoder.get_text_encoder()(
                    input["input_ids"],
                    attention_mask=input["attention_mask"],
                    return_dict=True,
                    mode="text",
                )
                text_embeds = text_output.last_hidden_state
                pooled_text_embeds = text_embeds[:, 0]
                pooled_output = self.encoder.text_proj(pooled_text_embeds)
                pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
                return pooled_output
            else:
                _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
                vfeat = self.encoder.vision_proj(vfeat)
                vfeat /= vfeat.norm(dim=-1, keepdim=True)
                return vfeat
        elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]:
            # we pass video queries here, so image placeholder tokens should not appear in the text
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]
            images = []
            for imgs in input['images']:
                # if multiple images are given, keep only the middle frame
                if isinstance(imgs, list):
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list)  # the extracted middle frame must no longer be a list
                    images.append(imgs)
                else:
                    images.append(imgs)
            pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
            return pooled_output
        elif getattr(self, "model_backbone", None) == COLPALI:
            pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
            return pooled_output
        elif getattr(self, "model_backbone", None) == LLAVA_NEXT:
            input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
            input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
            hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
            hidden_states = hidden_states.hidden_states[-1]
            pooled_output = self._pooling(hidden_states, input['attention_mask'])
            return pooled_output
        else:
            outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]  # [B, L', D]; L' may shrink after VisionZip compression
            # Prefer the post-compression attention_mask returned by the model's forward,
            # and fall back to the input mask if the model does not return one.
            post_mask = getattr(outputs, "attention_mask", None)
            src_mask = input.get("attention_mask", None)
            use_mask = post_mask if (post_mask is not None) else src_mask
            pooled_output = self._pooling(last_hidden, use_mask)
            return pooled_output

    def _pooling(self, last_hidden_state, attention_mask):
        """
        Robust EOS/last-token pooling:
        - If attention_mask is missing or its length does not match last_hidden_state,
          fall back to the last position of every sample (valid under left padding).
        - Otherwise use mask.sum(dim=1) - 1 as the last valid position, clamped to stay in range.
        """
        if self.pooling in ('last', 'eos'):
            B, L, D = last_hidden_state.shape
            device = last_hidden_state.device
            # Fallback: no mask, or mask length does not match the sequence length
            if (attention_mask is None) or (attention_mask.shape[1] != L):
                reps = last_hidden_state[:, -1, :]
            else:
                # Per-row valid length (>= 1), converted to a valid index in [0, L - 1].
                # attention_mask may be float/bfloat16, so cast to long before summing.
                valid_len = attention_mask.to(torch.long).sum(dim=1)   # [B]
                eos_idx = (valid_len - 1).clamp(min=0, max=L - 1)      # [B]
                reps = last_hidden_state[torch.arange(B, device=device), eos_idx, :]
        else:
            raise NotImplementedError
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    @classmethod
    def build(cls, model_args: ModelArguments, **kwargs):
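        """Instantiate a fresh backbone from `model_args.model_name` and wrap it as an MMEBModel;
        if `model_args.lora` is set, a new (Do)LoRA adapter is attached for training."""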
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        model_backbone = get_backbone_name(hf_config=config)
        print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')
        base_model = None  # ensure the variable is defined before the branches below
        # Loading the base model
        if model_backbone == PHI3V:
            config._attn_implementation = "eager"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone == LLAVA_NEXT:
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            mode = getattr(model_args, "vision_compression", "token_pooling")
            # ========= Qwen2-VL =========
            if model_backbone == QWEN2_VL:
                if mode == "token_pooling":
                    BaseCls = Qwen2VLForConditionalGenerationTokenPooling
                    print_master("[VisionCompression] Qwen2-VL using TokenPooling modeling")
                elif mode == "visionzip":
                    BaseCls = Qwen2VLForConditionalGenerationVisionZip
                    print_master("[VisionCompression] Qwen2-VL using VisionZip modeling")
                else:  # "none" or an unrecognized mode
                    BaseCls = backbone2model[model_backbone]
                    print_master(f"[VisionCompression] Qwen2-VL using vanilla backbone (mode={mode})")
            # ========= Qwen2.5-VL =========
            elif model_backbone == QWEN2_5_VL:
                if mode == "token_pooling":
                    BaseCls = Qwen2_5VLForConditionalGenerationTokenPooling
                    print_master("[VisionCompression] Qwen2.5-VL using TokenPooling modeling")
                elif mode == "visionzip":
                    BaseCls = Qwen2_5VLForConditionalGenerationVisionZip
                    print_master("[VisionCompression] Qwen2.5-VL using VisionZip modeling")
                else:
                    BaseCls = backbone2model[model_backbone]
                    print_master(f"[VisionCompression] Qwen2.5-VL using vanilla backbone (mode={mode})")
            # =============================
            base_model = BaseCls.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            from .utils import parse_layer_type
            lm_qwen_layer = 28
            vis_qwen_layer = 32
            lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
            vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                lm_skip_layer=lm_skip_layer,
                vis_skip_layer=vis_skip_layer,
            )
        else:
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )
        # Only call this once base_model has been assigned by one of the branches above.
        _ensure_pad_token_id_on_model(base_model)
        # Build MMEBModel
        if model_args.lora:
            print_master(f'Attaching a new LoRA adapter to the base model (r={model_args.lora_r})')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=model_args.lora_target_modules.split(','),
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            lora_model = get_peft_model(base_model, lora_config)
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        return model

    @classmethod
    def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
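        """Load a backbone (and, if `model_args.lora` is set, a trained LoRA adapter checkpoint)
        from `model_args.checkpoint_path` or `model_args.model_name` and wrap it as an MMEBModel."""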
        # Loading the base model
        model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
            model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
            setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
        base_model = None  # ensure the variable is defined before the branches below
        if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config._attn_implementation = "flash_attention_2"
            if hasattr(config, "vision_config") and config.vision_config is not None:
                config.vision_config._attn_implementation = "flash_attention_2"
            mode = getattr(model_args, "vision_compression", "token_pooling")
            # ========= Qwen2-VL =========
            if model_args.model_backbone == QWEN2_VL:
                if mode == "token_pooling":
                    BaseCls = Qwen2VLForConditionalGenerationTokenPooling
                    print_master("[VisionCompression-load] Qwen2-VL using TokenPooling modeling")
                elif mode == "visionzip":
                    BaseCls = Qwen2VLForConditionalGenerationVisionZip
                    print_master("[VisionCompression-load] Qwen2-VL using VisionZip modeling")
                else:
                    BaseCls = backbone2model[model_args.model_backbone]
                    print_master(f"[VisionCompression-load] Qwen2-VL using vanilla backbone (mode={mode})")
            # ========= Qwen2.5-VL =========
            elif model_args.model_backbone == QWEN2_5_VL:
                if mode == "token_pooling":
                    BaseCls = Qwen2_5VLForConditionalGenerationTokenPooling
                    print_master("[VisionCompression-load] Qwen2.5-VL using TokenPooling modeling")
                elif mode == "visionzip":
                    BaseCls = Qwen2_5VLForConditionalGenerationVisionZip
                    print_master("[VisionCompression-load] Qwen2.5-VL using VisionZip modeling")
                else:
                    BaseCls = backbone2model[model_args.model_backbone]
                    print_master(f"[VisionCompression-load] Qwen2.5-VL using vanilla backbone (mode={mode})")
            # Other backbones fall back to the original mapping.
            else:
                BaseCls = backbone2model[model_args.model_backbone]
            base_model = BaseCls.from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                config=config,
            )
        elif model_args.model_backbone == PHI3V:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                torch_dtype=torch.bfloat16, trust_remote_code=True
            )
            base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            print_master(f'Loading backbone [{model_args.model_backbone}] from src/model/vlm_backbone/internvideo2/')
            config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/", trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                "src/model/vlm_backbone/internvideo2/", config=config, trust_remote_code=True
            )
        elif model_args.model_backbone == GME:
            base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA:
            base_model = LamRAQwen2VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA_QWEN2_5:
            base_model = LamRAQwen25VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == COLPALI:
            base_model = ColPali.from_pretrained(model_args.model_name)
            setattr(base_model, 'config', config)
        else:
            # Loading an external base model from HF
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_name_or_path, **kwargs, config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )
        # Only call this once base_model has been assigned by one of the branches above.
        _ensure_pad_token_id_on_model(base_model)
        # Building the model on top of the base
        if model_args.lora:
            print_master(f'Loading LoRA from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)
            lora_model = PeftModel.from_pretrained(
                base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable
            )
            lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
            if not is_trainable:
                lora_model = lora_model.merge_and_unload()
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        model.model_backbone = model_args.model_backbone
        return model

    def save(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)

    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
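        """Encode `qry` and/or `tgt`. If only one side is given, return its representations
        (inference mode); otherwise compute the in-batch contrastive loss over all pairs."""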
        qry_reps = self.encode_input(qry) if qry else None  # (bsz_per_device, dim)
        tgt_reps = self.encode_input(tgt) if tgt else None  # (bsz_per_device, dim)
        if qry_reps is None or tgt_reps is None:
            return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}
        if self.is_ddp:
            all_qry_reps = self._dist_gather_tensor(qry_reps)
            all_tgt_reps = self._dist_gather_tensor(tgt_reps)
        else:
            all_qry_reps = qry_reps
            all_tgt_reps = tgt_reps
        scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
        scores = scores.view(all_qry_reps.size(0), -1)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        # Each query's positive sits every (num_targets // num_queries) columns of the score matrix.
        target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
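        # DDP averages gradients across ranks; scaling the loss by world_size restores the
        # effective gradient magnitude of a single large batch with gathered in-batch negatives.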
        if self.is_ddp:
            loss = loss * self.world_size
        return loss

    def _dist_gather_tensor(self, t: Tensor):
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
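        # all_gather returns tensors without autograd history; put the local tensor back
        # so gradients still flow through this rank's own slice of the gathered batch.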
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        return torch.matmul(q_reps, p_reps.transpose(0, 1))
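

# Minimal usage sketch (assumptions: ModelArguments is the project's dataclass and the field
# values below are illustrative only; qry_inputs / tgt_inputs are batches produced by the
# matching processor/collator):
#
#   model_args = ModelArguments(model_name="Qwen/Qwen2-VL-2B-Instruct", pooling="last",
#                               normalize=True, temperature=0.02, lora=False)
#   model = MMEBModel.build(model_args)                       # fresh backbone for training
#   # model = MMEBModel.load(model_args, is_trainable=False)  # or restore a checkpoint
#   loss = model(qry=qry_inputs, tgt=tgt_inputs)              # in-batch contrastive loss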