from typing import Dict
import os
import torch
import torch.distributed as dist
from torch import nn, Tensor
import torch.nn.functional as F  # in case it is not already imported at the top of the file
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel
from src.model.processor import QWEN2_5_VL_TOKENSELECTION
from src.arguments_multi_layer import ModelArguments, TrainingArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL,
    backbone2model, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V,
)
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, INTERNVIDEO2,
    QWEN2_VL_TOKENSELECTION, backbone2model, GME, VLM_IMAGE_TOKENS, LamRA, LamRA_QWEN2_5, COLPALI,
)
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration

from transformers import modeling_utils

# Make sure `ALL_PARALLEL_STYLES` exists and is not None before any model is
# instantiated.
# NOTE(review): presumably a compatibility shim for some transformers
# versions -- confirm which versions need it.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", 'rowwise']

from contextlib import contextmanager


class _AOPSwitch:
    """Temporarily toggle ``module.aop_prune_config`` for one forward call.

    On entry, if ``enable`` is false and a config is present, the config is
    replaced by a shallow copy with ``"enabled": False`` (or by ``None`` when
    the stored config is not a dict). On exit the original object is put back.

    Args:
        module: Module carrying the optional ``aop_prune_config`` attribute.
        enable: Whether pruning should stay active inside the ``with`` scope.
    """

    def __init__(self, module: nn.Module, enable: bool):
        self.module = module
        self.enable = bool(enable)
        # Snapshot taken at construction time; restored verbatim in __exit__.
        self._old = getattr(module, "aop_prune_config", None)

    def __enter__(self):
        # No config set: nothing to toggle.
        if self._old is None:
            return self
        if not self.enable:
            # Disable only for this scope; the original dict is left untouched
            # so it can be restored unchanged.
            if isinstance(self._old, dict):
                cfg = dict(self._old)
                cfg["enabled"] = False
                setattr(self.module, "aop_prune_config", cfg)
            else:
                setattr(self.module, "aop_prune_config", None)
        # enable=True: keep the config as it is.
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore the original config; never suppress exceptions.
        setattr(self.module, "aop_prune_config", self._old)
        return False


class MMEBModel(nn.Module):
    """Embedding model wrapper with multi-layer contrastive supervision.

    Wraps an encoder, pools one vector per supervised hidden layer, and trains
    with a weighted sum of per-layer in-batch contrastive losses. Optional
    token pruning is driven by the encoder's ``aop_prune_config`` attribute.
    """

    TRANSFORMER_CLS = AutoModelForCausalLM

    # Share of the loss given to the last layer when no usable weights are
    # supplied; matches the ``[0.15, 0.85]`` default set in ``__init__``.
    _DEFAULT_LAST_LAYER_WEIGHT = 0.85

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'last',
                 normalize: bool = False,
                 temperature: float = 0.02,
                 ):
        """Store the encoder and the pooling / loss hyper-parameters.

        Args:
            encoder: Backbone model; its ``config`` is mirrored on this module.
            pooling: Pooling strategy; ``_pooling`` supports ``'last'``/``'eos'``.
            normalize: Whether pooled representations are L2-normalised.
            temperature: Divisor applied to the similarity logits.
        """
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        # Defaults: one (or more) intermediate layer plus the last layer.
        self.layer_indices = [20, -1]
        # -1 must denote the last layer.
        self.supervise_layers = [20, -1]
        # Aligned position-by-position with ``supervise_layers``.
        self.supervise_weights = [0.15, 0.85]

    @property
    def device(self) -> torch.device:
        """Device of the first parameter, or CPU when there are no parameters."""
        try:
            return next(self.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _want_prune_for(self, side: str) -> bool:
        """Return whether pruning is enabled for ``side`` (``"qry"`` or ``"tgt"``).

        Reads ``encoder.aop_prune_config``; pruning applies when the config is
        a dict with a truthy ``"enabled"`` and ``"apply_to"`` is ``"both"``
        (the default) or equals ``side`` case-insensitively.
        """
        cfg = getattr(self.encoder, "aop_prune_config", None)
        if not isinstance(cfg, dict) or not cfg.get("enabled", False):
            return False
        apply_to = str(cfg.get("apply_to", "both")).lower()
        return (apply_to == "both") or (apply_to == side.lower())

    def _normalize_layers(self, hs_len: int, layers: list[int]) -> list[int]:
        """Map layer indices onto valid positions of a hidden-state tuple.

        Negative indices count from the end; every index is clamped to
        ``[1, hs_len - 1]``; the last index is appended when it is missing.
        Order is preserved and duplicates produced by clamping are kept.

        Args:
            hs_len: Length of the hidden-state tuple.
            layers: Requested layer indices (may be negative).

        Returns:
            The resolved, clamped indices.
        """
        last = hs_len - 1
        resolved = []
        for idx in layers:
            if idx < 0:
                idx = hs_len + idx
            resolved.append(max(1, min(idx, last)))
        if last not in resolved:
            resolved.append(last)
        return resolved

    @staticmethod
    def _select_mask(h: Tensor, preferred, alternate) -> Tensor:
        """Pick the attention mask whose sequence length matches ``h``.

        Tries ``preferred`` first, then ``alternate``; when neither matches
        (or both are ``None``) an all-ones mask of shape ``h.shape[:2]`` is
        returned so that pooling can still proceed.
        """
        for candidate in (preferred, alternate):
            if candidate is not None and candidate.size(1) == h.size(1):
                return candidate
        return torch.ones(h.size(0), h.size(1), dtype=torch.long, device=h.device)

    def _encode_multi(self, input):
        """Encode ``input`` and pool one vector per supervised layer.

        Args:
            input: Keyword arguments for the encoder; must contain
                ``'attention_mask'``.

        Returns:
            A tensor stacking the pooled vectors along dim 1, i.e. one entry
            per index returned by ``_normalize_layers`` for
            ``self.supervise_layers`` (de-duplicated, last layer included).

        Raises:
            NotImplementedError: For the GME / LamRA / InternVideo2 / ColPali
                backbones, which this path does not handle.
        """
        mb = getattr(self, "model_backbone", None)
        if mb in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI]:
            # Previously this case silently produced no result; fail loudly.
            raise NotImplementedError(
                f"Multi-layer encoding is not supported for backbone [{mb}]")

        # Generic path for encoders that return hidden states (Qwen2-VL etc.).
        out = self.encoder(**input, return_dict=True, output_hidden_states=True)
        # Per the original note: len = num_layers + 1 -- TODO confirm layout.
        hs_list = out.hidden_states

        # Mask after pruning, if the encoder output carries one.
        post_mask = getattr(out, "attention_mask", None)
        pre_mask = input['attention_mask']

        # De-duplicate (order-preserving), resolve, and force the last layer in.
        idxs = self._normalize_layers(
            len(hs_list), list(dict.fromkeys(self.supervise_layers)))

        # Pruning layer from the config. Per the original note it is 1-based
        # and pruning happens before entering that layer, so only hidden
        # states with idx >= cut_layer + 1 have the post-pruning length.
        aop_cfg = getattr(self.encoder, "aop_prune_config", None)
        cut_layer = None
        if isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False):
            try:
                cut_layer = int(aop_cfg.get("layer_idx") or 0)
            except (TypeError, ValueError):
                cut_layer = None
            else:
                if cut_layer <= 0:
                    cut_layer = None

        reps = []
        for idx in idxs:
            h = hs_list[idx]
            use_post = (
                post_mask is not None
                and cut_layer is not None
                and idx >= cut_layer + 1
            )
            # Prefer the mask expected for this layer, fall back to the other
            # one on a length mismatch, and finally to an all-ones mask.
            if use_post:
                mask_this = self._select_mask(h, post_mask, pre_mask)
            else:
                mask_this = self._select_mask(h, pre_mask, post_mask)
            r = self._pooling(h, mask_this)
            # `_pooling` already normalises when `self.normalize` is set; the
            # second normalisation is kept from the original implementation.
            reps.append(F.normalize(r, p=2, dim=-1) if self.normalize else r)
        return torch.stack(reps, dim=1)

    def _aop_monitor(self, input, out, hs_list, pre_mask, post_mask):
        """Print per-sample sequence lengths before/after pruning (best effort).

        Prints for at most the first 3 calls and the first 8 samples of each.
        Any exception is reported and swallowed so that monitoring can never
        interrupt training.
        """
        try:
            B = pre_mask.size(0) if pre_mask is not None else hs_list[-1].size(0)
            # Overall valid lengths.
            if pre_mask is not None:
                pre_len = pre_mask.sum(dim=1).detach().cpu().tolist()
            else:
                pre_len = [hs_list[-1].size(1)] * B
            if post_mask is not None:
                post_len = post_mask.sum(dim=1).detach().cpu().tolist()
            else:
                post_len = pre_len

            # Most recently sampled keep ratios, read from the config and
            # printed for reference only.
            aop_cfg = getattr(self.encoder, "aop_prune_config", None)
            kr_t = aop_cfg.get("_last_sampled_keep_ratio_text") if isinstance(aop_cfg, dict) else None
            kr_v = aop_cfg.get("_last_sampled_keep_ratio_vision") if isinstance(aop_cfg, dict) else None

            # Optional text / vision breakdown.
            pre_txt_cnt = pre_vis_cnt = post_txt_cnt = post_vis_cnt = None
            input_ids = input.get("input_ids", None)
            if input_ids is not None and pre_mask is not None:
                cfg = self.encoder.config
                valid_pre = pre_mask.bool()
                vis_pre = (input_ids == getattr(cfg, "image_token_id", -999))
                if hasattr(cfg, "video_token_id") and cfg.video_token_id is not None and cfg.video_token_id >= 0:
                    vis_pre = vis_pre | (input_ids == cfg.video_token_id)
                special_pre = torch.zeros_like(input_ids, dtype=torch.bool)
                for name in ["bos_token_id", "eos_token_id", "pad_token_id"]:
                    tid = getattr(cfg, name, None)
                    if tid is not None and tid >= 0:
                        special_pre |= (input_ids == tid)
                pre_vis_cnt = (vis_pre & valid_pre).sum(dim=1).detach().cpu().tolist()
                pre_txt_cnt = (valid_pre & (~vis_pre) & (~special_pre)).sum(dim=1).detach().cpu().tolist()

            vis_post_mask = getattr(out, "image_token_bool_masks", None)
            txt_post_mask = getattr(out, "text_token_bool_masks", None)
            if vis_post_mask is not None:
                post_vis_cnt = vis_post_mask.sum(dim=1).detach().cpu().tolist()
            if txt_post_mask is not None:
                post_txt_cnt = txt_post_mask.sum(dim=1).detach().cpu().tolist()

            # Limit the number of printed batches to avoid flooding the log.
            if not hasattr(self, "_aop_mon_prints"):
                self._aop_mon_prints = 0
            if self._aop_mon_prints < 3:
                print(f"[AOP][monitor] B={B} sampled: kr_text={kr_t}, kr_vision={kr_v}")
                for b in range(min(B, 8)):
                    preL = int(pre_len[b])
                    postL = int(post_len[b])
                    keep = (postL / (preL + 1e-9))
                    msg = f"  b={b}: pre_len={preL}, post_len={postL}, keep={keep:.3f}"
                    if pre_txt_cnt is not None and post_txt_cnt is not None:
                        kt = (post_txt_cnt[b] / (pre_txt_cnt[b] + 1e-9)) if pre_txt_cnt[b] > 0 else float('nan')
                        msg += f", txt_keep={kt:.3f}"
                    if pre_vis_cnt is not None and post_vis_cnt is not None:
                        kv = (post_vis_cnt[b] / (pre_vis_cnt[b] + 1e-9)) if pre_vis_cnt[b] > 0 else float('nan')
                        msg += f", vis_keep={kv:.3f}"
                    print(msg)
                self._aop_mon_prints += 1
        except Exception as e:
            # Deliberately broad: monitoring must not affect the training run.
            print(f"[AOP][monitor] warn: monitor failed with error: {e}")

    def encode_input(self, input, layer_indices=None):
        """Encode ``input`` into one pooled embedding per sample.

        Dispatches on ``self.model_backbone`` (``None`` selects the generic
        Hugging Face path).

        Args:
            input: Backbone-specific inputs (a dict of encoder kwargs, or
                ``texts``/``images`` for GME / LamRA).
            layer_indices: Accepted for interface compatibility; not used by
                any branch. NOTE(review): confirm whether callers expect
                multi-layer output when this is passed.

        Returns:
            The pooled embedding (for ColPali, the raw encoder output).
        """
        backbone = getattr(self, "model_backbone", None)

        if backbone == INTERNVIDEO2:
            if "input_ids" in input.keys():
                # Text side: first-token embedding, projected and normalised.
                text_output = self.encoder.get_text_encoder()(
                    input["input_ids"],
                    attention_mask=input["attention_mask"],
                    return_dict=True,
                    mode="text",
                )
                text_embeds = text_output.last_hidden_state
                pooled_text_embeds = text_embeds[:, 0]
                pooled_output = self.encoder.text_proj(pooled_text_embeds)
                pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
                return pooled_output
            # Vision side.
            _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
            vfeat = self.encoder.vision_proj(vfeat)
            vfeat /= vfeat.norm(dim=-1, keepdim=True)
            return vfeat

        if backbone in [GME, LamRA, LamRA_QWEN2_5]:
            # Strip the image placeholder; per the original note video queries
            # are passed here, so the placeholder is not expected to occur.
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]
            images = []
            for imgs in input['images']:
                # When several frames are given, keep only the middle one.
                if isinstance(imgs, list):
                    imgs = imgs[len(imgs) // 2]
                    # The selected middle frame must itself be a single image.
                    assert not isinstance(imgs, list)
                images.append(imgs)
            pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
            return pooled_output

        if backbone == COLPALI:
            pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
            return pooled_output

        if backbone == LLAVA_NEXT:
            # Note: mutates the caller's dict in place, as before.
            input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
            input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
            hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
            hidden_states = hidden_states.hidden_states[-1]
            pooled_output = self._pooling(hidden_states, input['attention_mask'])
            return pooled_output

        # Default HF model: hidden states are available (possibly pruned).
        out = self.encoder(**input, return_dict=True, output_hidden_states=True)
        hs_list = out.hidden_states
        post_mask = getattr(out, "attention_mask", None)
        pre_mask = input['attention_mask']

        # AOP_MONITOR=1 prints sequence lengths before/after pruning.
        if os.getenv("AOP_MONITOR", "0") == "1":
            self._aop_monitor(input, out, hs_list, pre_mask, post_mask)

        # Pool the last hidden state with whichever mask matches its length.
        # The original fell off the end here and returned None.
        last_hidden = hs_list[-1]
        mask = self._select_mask(last_hidden, post_mask, pre_mask)
        return self._pooling(last_hidden, mask)

    def _pooling(self, last_hidden_state, attention_mask):
        """Pool one vector per sample from a sequence of hidden states.

        For ``'last'``/``'eos'`` pooling: if the final mask column is all
        ones the batch is treated as left-padded and the last position is
        taken; otherwise the position ``mask.sum(dim=1) - 1`` of each sample
        is taken. The result is L2-normalised when ``self.normalize`` is set.

        Raises:
            NotImplementedError: For any other pooling mode.
        """
        if self.pooling not in ('last', 'eos'):
            raise NotImplementedError(f"Unsupported pooling mode: {self.pooling!r}")

        batch_size = last_hidden_state.shape[0]
        rows = torch.arange(batch_size, device=last_hidden_state.device)
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            # Advanced indexing (a copy, not a view) so the full hidden-state
            # tensor is not kept alive by the pooled result.
            reps = last_hidden_state[rows, -1, :]
        else:
            # Index of the last attended position of each sample.
            eos_indices = attention_mask.sum(dim=1) - 1
            reps = last_hidden_state[rows, eos_indices]
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    @staticmethod
    def _parse_number_list(val, tp=float):
        """Parse a list/tuple or a comma-separated string into ``[tp(x), ...]``.

        Returns ``None`` for ``None`` or a blank string.
        """
        if val is None:
            return None
        if isinstance(val, (list, tuple)):
            return [tp(x) for x in val]
        s = str(val).strip()
        if s == "":
            return None
        return [tp(v.strip()) for v in s.split(",") if v.strip() != ""]

    @classmethod
    def _configure_supervision(cls, model, model_args) -> None:
        """Set supervised layers / weights on ``model`` from ``model_args``.

        Reads the optional ``supervise_layers``, ``supervise_weights`` and
        ``dual_layer_idx`` attributes. ``-1`` (the last layer) is always
        included. When weights are missing or their count does not match the
        layers, the last layer receives ``_DEFAULT_LAST_LAYER_WEIGHT`` and the
        remaining layers share the rest equally. Weights are clamped to be
        non-negative and normalised to sum to one.
        """
        layers = cls._parse_number_list(getattr(model_args, "supervise_layers", None), tp=int)
        weights = cls._parse_number_list(getattr(model_args, "supervise_weights", None), tp=float)

        if layers is None:
            # Fall back to the legacy two-layer setup.
            layers = [getattr(model_args, 'dual_layer_idx', 20), -1]
        if -1 not in layers:
            # The last layer is always supervised.
            layers = list(layers) + [-1]

        count = len(layers)
        if weights is None or len(weights) != count:
            if count == 1:
                weights = [1.0]
            else:
                last_w = cls._DEFAULT_LAST_LAYER_WEIGHT
                shared = (1.0 - last_w) / (count - 1)
                weights = [last_w if layer == -1 else shared for layer in layers]

        # Clamp and normalise; guard against an all-zero weight vector.
        weights = [max(0.0, w) for w in weights]
        total = sum(weights)
        if total > 0:
            weights = [w / total for w in weights]
        else:
            weights = [1.0 / count] * count

        setattr(model, 'supervise_layers', layers)
        setattr(model, 'supervise_weights', weights)
        # Legacy attribute names kept for backward compatibility.
        setattr(model, 'dual_layer_idx', layers[0])
        setattr(model, 'dual_alpha', weights[0] if len(weights) > 1 else 1.0)
        setattr(model, 'layer_indices', layers)

    @classmethod
    def build(cls, model_args: ModelArguments, **kwargs):
        """Create a model for training from ``model_args.model_name``.

        Loads the backbone matching the checkpoint config, optionally wraps
        it with a fresh LoRA adapter, and configures multi-layer supervision.

        Args:
            model_args: Model options (name, LoRA, pooling, supervision ...).
            **kwargs: Forwarded to ``from_pretrained`` on the generic path only.

        Returns:
            The constructed ``MMEBModel``.
        """
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        variant = getattr(config, "backbone_variant", None)
        if variant == "layerprune":
            model_backbone = "QWEN2_VL_LayerPrune"
        else:
            model_backbone = get_backbone_name(hf_config=config)
        print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')

        # Loading the base model
        if model_backbone == PHI3V:
            config._attn_implementation = "eager"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone == LLAVA_NEXT:
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL, "QWEN2_VL_LayerPrune"]:
            # The plain Qwen2-VL variants and the layer-prune variant are
            # loaded identically.
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False

            from .utils import parse_layer_type
            # Layer counts passed to parse_layer_type for the language model
            # and the vision tower respectively.
            lm_qwen_layer = 28
            vis_qwen_layer = 32
            lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
            vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)

            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                lm_skip_layer=lm_skip_layer,
                vis_skip_layer=vis_skip_layer,
            )
        else:
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        if model_args.lora:
            print_master(f'Loading lora adapter from {base_model}')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=model_args.lora_target_modules.split(','),
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            encoder = get_peft_model(base_model, lora_config)
        else:
            encoder = base_model

        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )
        cls._configure_supervision(model, model_args)
        return model

    @classmethod
    def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
        """Load a model (and optionally its LoRA adapter) from a checkpoint.

        Args:
            model_args: Model options; ``checkpoint_path`` takes precedence
                over ``model_name`` for the config and the adapter.
            is_trainable: Passed to PEFT; when false the adapter is merged
                into the base weights.
            **kwargs: Forwarded to ``from_pretrained`` for some backbones;
                GME requires ``kwargs['processor']``.

        Returns:
            The loaded ``MMEBModel`` with ``model_backbone`` set.
        """
        # Loading the base model
        model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
            model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
            setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')

        if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION,
                                         QWEN2_5_VL_TOKENSELECTION, E5_V}:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config._attn_implementation = "flash_attention_2"
            config.vision_config._attn_implementation = "flash_attention_2"
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                config=config
            )
        elif model_args.model_backbone == PHI3V:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
                                                          torch_dtype=torch.bfloat16, trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            internvideo_dir = "src/model/vlm_backbone/internvideo2/"
            print_master(f'Loading backbone [{model_args.model_backbone}] from {internvideo_dir}')
            config = AutoConfig.from_pretrained(internvideo_dir, trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                internvideo_dir, config=config, trust_remote_code=True)
        elif model_args.model_backbone == GME:
            base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA:
            base_model = LamRAQwen2VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA_QWEN2_5:
            base_model = LamRAQwen25VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == COLPALI:
            base_model = ColPali.from_pretrained(model_args.model_name)
            setattr(base_model, 'config', config)
        else:
            # Loading external base model from HF
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_name_or_path, **kwargs, config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        # Building the model on top of the base
        if model_args.lora:
            print_master(f'Loading LoRA from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)
            lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config,
                                                   is_trainable=is_trainable)
            lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
            if not is_trainable:
                lora_model = lora_model.merge_and_unload()
            encoder = lora_model
        else:
            encoder = base_model

        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )
        model.model_backbone = model_args.model_backbone
        # Keep the supervised layers / weights consistent with `build`, so a
        # loaded model encodes the same layers it was trained on.
        cls._configure_supervision(model, model_args)
        return model

    def save(self, output_dir: str):
        """Save the encoder to ``output_dir`` via ``save_pretrained``."""
        self.encoder.save_pretrained(output_dir)

    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
        """Encode one side, or compute the multi-layer contrastive loss.

        Args:
            qry: Query-side encoder inputs, or ``None``.
            tgt: Target-side encoder inputs, or ``None``.

        Returns:
            When only one side is given (GradCache usage): a dict with
            ``"qry_reps"`` and ``"tgt_reps"``, the missing side being ``None``.
            When both are given: the scalar loss, a weighted sum over layers
            of in-batch cross-entropy losses (multiplied by the world size
            under DDP).

        Raises:
            ValueError: If both sides are ``None``, or if fewer weights than
                supervised layers are configured.
        """
        if qry is None and tgt is None:
            raise ValueError("forward() requires at least one of `qry` or `tgt`")

        # GradCache: a single side is given -> return its multi-layer reps.
        if tgt is None:
            with _AOPSwitch(self.encoder, self._want_prune_for("qry")):
                qry_reps = self._encode_multi(qry)
            return {"qry_reps": qry_reps, "tgt_reps": None}
        if qry is None:
            with _AOPSwitch(self.encoder, self._want_prune_for("tgt")):
                tgt_reps = self._encode_multi(tgt)
            return {"qry_reps": None, "tgt_reps": tgt_reps}

        with _AOPSwitch(self.encoder, self._want_prune_for("qry")):
            q_multi = self._encode_multi(qry)
        with _AOPSwitch(self.encoder, self._want_prune_for("tgt")):
            p_multi = self._encode_multi(tgt)

        # Gather representations from all ranks under DDP.
        if self.is_ddp:
            q_multi_all = self._dist_gather_tensor(q_multi)
            p_multi_all = self._dist_gather_tensor(p_multi)
        else:
            q_multi_all, p_multi_all = q_multi, p_multi

        Bglob, K, D = q_multi_all.shape
        assert p_multi_all.shape[:2] == (Bglob, K), f"Shape mismatch: q {q_multi_all.shape}, p {p_multi_all.shape}"

        # The i-th query is paired with the i-th target.
        target = torch.arange(Bglob, device=q_multi_all.device, dtype=torch.long)

        w = torch.tensor(self.supervise_weights, dtype=torch.float32, device=q_multi_all.device)
        if w.numel() < K:
            raise ValueError(
                f"supervise_weights has {w.numel()} entries but {K} layers are supervised")
        w = torch.clamp(w, min=0)
        # Normalise on-device (no host synchronisation).
        w = w / torch.clamp(w.sum(), min=1e-8)

        loss = 0.0
        for k in range(K):
            # Layer-wise pairing: query layer k against target layer k.
            logits_k = torch.matmul(q_multi_all[:, k, :], p_multi_all[:, k, :].transpose(0, 1)) / self.temperature
            loss_k = self.cross_entropy(logits_k, target)
            loss = loss + w[k] * loss_k

        if self.is_ddp:
            loss = loss * self.world_size
        return loss

    def _dist_gather_tensor(self, t: Tensor):
        """All-gather ``t`` from every rank and concatenate along dim 0.

        The local slot is replaced by ``t`` itself so that gradients flow
        through this rank's own representations.
        """
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        """Return the dot-product similarity matrix ``q_reps @ p_reps.T``."""
        return torch.matmul(q_reps, p_reps.transpose(0, 1))