from typing import Dict
import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from src.model.vlm_backbone.qwen2_5_vl_gp.configuration import Qwen2_5_VL_GPConfig
from src.model.vlm_backbone.qwen2_vl_gp.configuration import Qwen2VL_GPConfig
from peft import LoraConfig, get_peft_model, PeftModel
from src.arguments_gp import ModelArguments, TrainingArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_VL_GP, QWEN2_5_VL, QWEN2_5_VL_GP, PHI3V,
    QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V, INTERNVIDEO2,
    GME, LamRA, LamRA_QWEN2_5, COLPALI, VLM_IMAGE_TOKENS,
    get_backbone_name, print_master, backbone2model,
)
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration
from transformers import modeling_utils
# Compatibility shim: some transformers versions expect modeling_utils.ALL_PARALLEL_STYLES to be a list;
# provide a fallback when it is missing or None.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
# NOTE: compared with the vanilla embedding model, this variant also records and exposes the
# image-token masks produced by the GP (glimpse/pruning) backbones.
class MMEBModel(nn.Module):
TRANSFORMER_CLS = AutoModelForCausalLM
def __init__(self,
encoder: PreTrainedModel,
pooling: str = 'last',
normalize: bool = False,
temperature: float = 0.02,
):
super().__init__()
self.config = encoder.config
self.encoder = encoder
self.pooling = pooling
self.normalize = normalize
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.is_ddp = dist.is_initialized()
if self.is_ddp:
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()
self._last_image_token_bool_masks = None
self._last_image_token_mask_logits = None
self._last_le_loss = None
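# encode_input dispatches on self.model_backbone: dedicated branches handle InternVideo2,
# the GME/LamRA/ColPali baselines, LLaVA-Next and the GP (token-selection) backbones;
# everything else falls through to a generic forward pass plus last-token pooling.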
def encode_input(self, input):
self._last_image_token_bool_masks = None
if getattr(self, "model_backbone", None) == INTERNVIDEO2:
if "input_ids" in input.keys():
# text side
text_output = self.encoder.get_text_encoder()(
input["input_ids"],
attention_mask=input["attention_mask"],
return_dict=True,
mode="text",
)
text_embeds = text_output.last_hidden_state
pooled_text_embeds = text_embeds[:, 0]
pooled_output = self.encoder.text_proj(pooled_text_embeds)
pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
return pooled_output
else:
_, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
vfeat = self.encoder.vision_proj(vfeat)
vfeat /= vfeat.norm(dim=-1, keepdim=True)
return vfeat
elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]:
# pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]  # strip the image placeholder token; with video queries it should not be present anyway
images = []
for imgs in input['images']:
# if multiple images (video frames) are given, keep only the middle frame
if isinstance(imgs, list):
imgs = imgs[len(imgs) // 2]
assert not isinstance(imgs, list)  # the middle frame must be a single image, not a nested list
images.append(imgs)
pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
return pooled_output
elif getattr(self, "model_backbone", None) == COLPALI:
pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
return pooled_output
elif getattr(self, "model_backbone", None) == LLAVA_NEXT:
input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
hidden_states = hidden_states.hidden_states[-1]
pooled_output = self._pooling(hidden_states, input['attention_mask'])
return pooled_output
elif getattr(self, "model_backbone", None) in {QWEN2_VL_GP, QWEN2_5_VL_GP}:
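# GP backbones take extra arguments controlling visual-token selection; the predicted
# image-token masks (and auxiliary outputs) are cached on the model for later inspection.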
outputs = self.encoder(
**input,
return_dict=True,
output_hidden_states=True,  # can be set to False if the hidden states are not needed elsewhere
do_selection=True,
delay_selection=False,  # prune immediately once selection is enabled; training keeps all tokens by default
use_cache=False,
use_ref_masks=False,  # rely on the predicted masks at inference as well
)
self._last_image_token_bool_masks = getattr(outputs, "image_token_bool_masks", None)
self._last_image_token_mask_logits = getattr(outputs, "image_token_mask_logits", None)
self._last_le_loss = getattr(outputs, "le_loss", None)
hs = outputs.hidden_states
last_hidden = hs[-1] if isinstance(hs, (list, tuple)) else hs
attn = getattr(outputs, "attention_mask", None)
if attn is None:
attn = input["attention_mask"]
pooled_output = self._pooling(last_hidden, attn)
return pooled_output
else:
# the default branch must also handle outputs robustly (a glimpse-enabled backbone may be routed here)
outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)
img_masks = getattr(outputs, "image_token_bool_masks", None)
if img_masks is None:
try:
img_masks = outputs.get("image_token_bool_masks", None)
except Exception:
pass
self._last_image_token_bool_masks = img_masks
hs = outputs.hidden_states
if isinstance(hs, (list, tuple)):
last_hidden = hs[-1]  # [B, L, D]
else:
last_hidden = hs  # [B, L, D] (already the last hidden state if the backbone took the glimpse path)
attn = getattr(outputs, "attention_mask", None)
if attn is None:
attn = input["attention_mask"]
pooled_output = self._pooling(last_hidden, attn)
return pooled_output
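# _pooling reduces [B, L, D] hidden states to one vector per sample, taking the last valid
# (EOS) position according to the attention mask, with optional L2 normalization.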
def _pooling(self, last_hidden_state, attention_mask):
if self.pooling in ('last', 'eos'):
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
batch_size = last_hidden_state.shape[0]
if left_padding:
# Get the vectors at the last position
reps = last_hidden_state[torch.arange(batch_size), -1, :]
else:
# Calculate last 1 position in the original tensor
eos_indices = attention_mask.sum(dim=1) - 1
# Get the vectors at the last 1 position of each attention mask
reps = last_hidden_state[
torch.arange(batch_size, device=last_hidden_state.device), eos_indices]
else:
raise NotImplementedError
if self.normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
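# build(): instantiate a fresh backbone for training; each branch below sets the attention
# implementation, padding side and use_cache appropriate for that backbone family.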
@classmethod
def build(cls, model_args: ModelArguments, **kwargs):
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
model_backbone = get_backbone_name(hf_config=config)
print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')
# Loading the base model
if model_backbone == PHI3V:
config._attn_implementation = "eager"
config.padding_side = "right"
config.use_cache = False
base_model = Phi3VForCausalLM.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone == LLAVA_NEXT:
config.use_cache = False
config.padding_side = "left"
base_model = LlavaNextForConditionalGeneration.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone in [QWEN2_VL_GP, QWEN2_5_VL_GP]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
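# layer counts of the Qwen2-VL backbone (28 LM layers, 32 ViT layers); parse_layer_type
# expands the skip-layer spec strings into per-layer flags for token selection.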
from .utils import parse_layer_type
lm_qwen_layer = 28
vis_qwen_layer = 32
lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
lm_skip_layer=lm_skip_layer,
vis_skip_layer=vis_skip_layer,
)
else:
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name, **kwargs, config=config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
trust_remote_code=True)
if model_args.lora:
print_master('Creating a LoRA (DoRA) adapter on top of the base model')
lora_config = LoraConfig(
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
target_modules=model_args.lora_target_modules.split(','),
lora_dropout=model_args.lora_dropout,
init_lora_weights="gaussian",
use_dora=True,
inference_mode=False
)
lora_model = get_peft_model(base_model, lora_config)
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
# model.gp_do_selection = getattr(model_args, "gp_do_selection", False)
return model
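# load(): restore an already trained model (from model_args.checkpoint_path when given),
# covering both plain HF backbones and the GP variants that need extra "new modules" injected.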
@classmethod
def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
# Loading the base model
model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config._attn_implementation = "flash_attention_2"
config.vision_config._attn_implementation = "flash_attention_2"
base_model = backbone2model[model_args.model_backbone].from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=config
)
elif model_args.model_backbone in {QWEN2_5_VL_GP}:
new_modules_dir = getattr(model_args, "new_modules_dir", None) or kwargs.get("new_modules_dir", None)
if new_modules_dir:
import os
from huggingface_hub import hf_hub_download
def _load_gp_config_from_dir_or_repo(repo_or_dir: str) -> Qwen2_5_VL_GPConfig:
# accept either a local directory or a HF Hub repo id
if os.path.isdir(repo_or_dir):
cfg_path = os.path.join(repo_or_dir, "config.json")
if not os.path.exists(cfg_path):
raise FileNotFoundError(f"config.json not found in {repo_or_dir}")
else:
# download config.json from the HF Hub repo
cfg_path = hf_hub_download(repo_id=repo_or_dir, filename="config.json")
return Qwen2_5_VL_GPConfig.from_json_file(cfg_path)
gp_cfg = _load_gp_config_from_dir_or_repo(new_modules_dir)
gp_cfg._attn_implementation = "flash_attention_2"
gp_cfg.padding_side = "left"
gp_cfg.use_cache = False
try:
gp_cfg.vision_config._attn_implementation = "flash_attention_2"
except Exception:
if isinstance(gp_cfg.vision_config, dict):
gp_cfg.vision_config["_attn_implementation"] = "flash_attention_2"
base_model = backbone2model[model_args.model_backbone].from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=gp_cfg,
)
base_model.load_new_modules(new_modules_dir)
try:
print_master(f"[GP] attn_fuser = {type(base_model.attn_fuser).__name__}")
except Exception:
print_master("[GP] attn_fuser not found (unexpected).")
else:
# no new_modules_dir: keep the original fallback (build a GPConfig from the base config without loading new-module weights)
base_cfg = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
gp_cfg = Qwen2_5_VL_GPConfig(**base_cfg.to_dict())
gp_cfg._attn_implementation = "flash_attention_2"
try:
gp_cfg.vision_config._attn_implementation = "flash_attention_2"
except Exception:
if isinstance(gp_cfg.vision_config, dict):
gp_cfg.vision_config["_attn_implementation"] = "flash_attention_2"
gp_cfg.padding_side = "left"
gp_cfg.use_cache = False
L = getattr(base_cfg, "num_hidden_layers", None) or 28
gp_cfg.reduce_layer = min(max(L - 2, 1), L - 1)
gp_cfg.selected_layers = [max(gp_cfg.reduce_layer - 2, 0), gp_cfg.reduce_layer]
# gp_cfg.attn_fuse_type = "AttnFuserV1"
# gp_cfg.selected_visual_layers = (8,)
# gp_cfg.use_attention_logits = False
# gp_cfg.use_zero_masks = False
# gp_cfg.use_ref_masks = False
# gp_cfg.le_layers = []
# gp_cfg.le_length = 0
base_model = backbone2model[model_args.model_backbone].from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=gp_cfg,
)
elif model_args.model_backbone == QWEN2_VL_GP:
import os
from huggingface_hub import snapshot_download
from transformers.utils import is_flash_attn_2_available
new_modules_dir = getattr(model_args, "new_modules_dir", None) or kwargs.get("new_modules_dir", None)
if not new_modules_dir:
raise ValueError("new_modules_dir is required for QWEN2_VL_GP")
# resolve to a local directory
if os.path.isdir(new_modules_dir):
local_new_modules_dir = new_modules_dir
else:
local_new_modules_dir = snapshot_download(repo_id=new_modules_dir)
# read the config shipped with the GP modules (used only for validation, not to build the base model)
gp_cfg_path = os.path.join(local_new_modules_dir, "config.json")
if not os.path.exists(gp_cfg_path):
raise FileNotFoundError(f"config.json not found in {local_new_modules_dir}")
gp_cfg = Qwen2VL_GPConfig.from_json_file(gp_cfg_path)
# build from the base model's own config to avoid dimension mismatches with the checkpoint
base_cfg = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
# fail early: key structural parameters must match (hidden_size in particular)
if hasattr(gp_cfg, "hidden_size") and hasattr(base_cfg, "hidden_size"):
if gp_cfg.hidden_size != base_cfg.hidden_size:
raise ValueError(
f"[Mismatch] base_model hidden_size={base_cfg.hidden_size} "
f"!= gp_modules hidden_size={gp_cfg.hidden_size}. "
f"Please switch to a matching base model or the corresponding new_modules_dir."
)
# runtime-only patches; the model structure is left unchanged
base_cfg._attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
base_cfg.padding_side = "left"
base_cfg.use_cache = False
if hasattr(base_cfg, "vision_config"):
vc = base_cfg.vision_config
if isinstance(vc, dict):
vc["_attn_implementation"] = base_cfg._attn_implementation
else:
setattr(vc, "_attn_implementation", base_cfg._attn_implementation)
# build and load the base weights (structure consistent with the checkpoint)
ModelClass = backbone2model[model_args.model_backbone]
base_model = ModelClass.from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=base_cfg,
trust_remote_code=True,
)
# inject the GP modules (the directory must match the base model)
base_model.load_new_modules(local_new_modules_dir)
try:
print_master(f"[GP] attn_fuser = {type(base_model.attn_fuser).__name__}")
except Exception:
print_master("[GP] attn_fuser not found (unexpected).")
elif model_args.model_backbone == PHI3V:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
config.padding_side = "right"
base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
torch_dtype=torch.bfloat16, trust_remote_code=True)
base_model.padding_side = "right"
elif model_args.model_backbone == INTERNVIDEO2:
print_master(f'Loading backbone [{model_args.model_backbone}] from src/model/vlm_backbone/internvideo2/')
config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/",
trust_remote_code=True)
base_model = backbone2model[model_args.model_backbone].from_pretrained("src/model/vlm_backbone/internvideo2/", config=config,
trust_remote_code=True)
elif model_args.model_backbone == GME:
base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA:
base_model = LamRAQwen2VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA_QWEN2_5:
base_model = LamRAQwen25VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == COLPALI:
base_model = ColPali.from_pretrained(model_args.model_name)
setattr(base_model, 'config', config)
else:
# Loading external base model from HF
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_name_or_path, **kwargs, config=config,
torch_dtype=torch.bfloat16,
trust_remote_code=True)
# Building the model on top of the base
if model_args.lora:
print_master(f'Loading LoRA from {model_name_or_path}')
lora_config = LoraConfig.from_pretrained(model_name_or_path)
lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
if not is_trainable:
lora_model = lora_model.merge_and_unload()
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
model.model_backbone = model_args.model_backbone
return model
def save(self, output_dir: str):
self.encoder.save_pretrained(output_dir)
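# forward(): encode the query/target batches, all-gather embeddings under DDP, and compute an
# in-batch-negative InfoNCE loss; if either side is missing, return the embeddings instead.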
def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
qry_reps = self.encode_input(qry) if qry else None # (bsz_per_device, dim)
tgt_reps = self.encode_input(tgt) if tgt else None # (bsz_per_device, dim)
# qry_reps, tgt_reps = None, None
# if qry is not None:
#     qry_reps = self.encode_input(qry, compression_rate=0.2)  # query-side compression rate
# if tgt is not None:
#     tgt_reps = self.encode_input(tgt, compression_rate=0.4)  # candidate-side compression rate
if qry_reps is None or tgt_reps is None:
out = {"qry_reps": qry_reps, "tgt_reps": tgt_reps}
# NEW: pass through the backbone's image-token masks, if any were produced
img_masks = getattr(self, "_last_image_token_bool_masks", None)
if img_masks is not None:
out["image_token_bool_masks"] = img_masks
return out
if self.is_ddp:
all_qry_reps = self._dist_gather_tensor(qry_reps)
all_tgt_reps = self._dist_gather_tensor(tgt_reps)
else:
all_qry_reps = qry_reps
all_tgt_reps = tgt_reps
scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
scores = scores.view(all_qry_reps.size(0), -1)
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
loss = self.cross_entropy(scores / self.temperature, target)
if self.is_ddp:
loss = loss * self.world_size
return loss
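# all-gather embeddings from every rank so each device sees the full set of in-batch negatives;
# the local (gradient-carrying) tensor replaces its own slot before concatenation.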
def _dist_gather_tensor(self, t: Tensor):
t = t.contiguous()
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(all_tensors, t)
all_tensors[self.process_rank] = t
all_tensors = torch.cat(all_tensors, dim=0)
return all_tensors
def compute_similarity(self, q_reps, p_reps):
return torch.matmul(q_reps, p_reps.transpose(0, 1))
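# AsymMMEBModel pairs two separate encoders: a GP (token-selection) model for queries and a
# vanilla backbone for targets, sharing the same pooling and contrastive-loss recipe as above.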
class AsymMMEBModel(nn.Module):
def __init__(self, encoder_qry: PreTrainedModel, encoder_tgt: PreTrainedModel,
pooling: str = 'last', normalize: bool = False, temperature: float = 0.02):
super().__init__()
self.encoder_qry = encoder_qry # Qwen2.5-VL-GP
self.encoder_tgt = encoder_tgt # Qwen2.5-VL
self.pooling = pooling
self.normalize = normalize
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.is_ddp = dist.is_initialized()
self.world_size = dist.get_world_size() if self.is_ddp else 1
# only the query side controls visual-token pruning
self.gp_do_selection_qry = False  # defaults to False for training; set True at inference
# kept for debugging / inspection
self._last_image_token_bool_masks = None
self._last_image_token_mask_logits = None
self._last_le_loss = None
def _pool(self, last_hidden_state, attention_mask):
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
B = last_hidden_state.shape[0]
if left_padding:
reps = last_hidden_state[torch.arange(B), -1, :]
else:
idx = attention_mask.sum(dim=1) - 1
reps = last_hidden_state[torch.arange(B, device=last_hidden_state.device), idx]
if self.normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
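# encode_qry runs the GP encoder (optionally with token selection) and caches its masks;
# encode_tgt runs the plain encoder. Both pool with the shared _pool helper.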
def encode_qry(self, inp: Dict[str, torch.Tensor]):
# token selection is usually kept off for contrastive retrieval training and enabled at inference
do_sel = self.gp_do_selection_qry
out = self.encoder_qry(
**inp,
return_dict=True,
output_hidden_states=True,
do_selection=do_sel,
delay_selection=not do_sel,
use_cache=False,
use_ref_masks=False,
)
self._last_image_token_bool_masks = getattr(out, "image_token_bool_masks", None)
self._last_image_token_mask_logits = getattr(out, "image_token_mask_logits", None)
self._last_le_loss = getattr(out, "le_loss", None)
hs = out.hidden_states
last = hs[-1] if isinstance(hs, (list, tuple)) else hs
attn = getattr(out, "attention_mask", None)
if attn is None:
attn = inp["attention_mask"]
rep = self._pool(last, attn)
return rep
def encode_tgt(self, inp: Dict[str, torch.Tensor]):
out = self.encoder_tgt(
**inp,
return_dict=True,
output_hidden_states=True,
use_cache=False,
)
hs = out.hidden_states
last = hs[-1] if isinstance(hs, (list, tuple)) else hs
attn = getattr(out, "attention_mask", None)
if attn is None:
attn = inp["attention_mask"]
rep = self._pool(last, attn)
return rep
def forward(self, qry: Dict[str, torch.Tensor] = None, tgt: Dict[str, torch.Tensor] = None):
qry_reps = self.encode_qry(qry) if qry else None
tgt_reps = self.encode_tgt(tgt) if tgt else None
if qry_reps is None or tgt_reps is None:
return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}
if self.is_ddp:
qry_all = self._gather(qry_reps)
tgt_all = self._gather(tgt_reps)
else:
qry_all, tgt_all = qry_reps, tgt_reps
scores = torch.matmul(qry_all, tgt_all.transpose(0, 1)).view(qry_all.size(0), -1)
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
target = target * (tgt_all.size(0) // qry_all.size(0))
loss = self.cross_entropy(scores / self.temperature, target)
if self.is_ddp:
loss = loss * self.world_size
return loss
def _gather(self, t: torch.Tensor):
t = t.contiguous()
outs = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(outs, t)
outs[dist.get_rank()] = t
return torch.cat(outs, dim=0)
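# load_asym(): build the asymmetric pair from one base checkpoint, loading the GP "new modules"
# for the query encoder and freezing both towers unless the selection head is being trained.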
@classmethod
def load_asym(cls, model_args, processor=None, processor_tgt=None, is_trainable=True):
# query encoder: GP
from src.model.vlm_backbone.qwen2_5_vl_gp.configuration import Qwen2_5_VL_GPConfig
from src.model.vlm_backbone.qwen2_5_vl_gp.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration_GP as GPModel
from transformers import AutoConfig
new_modules_dir = getattr(model_args, "new_modules_dir", None)
assert new_modules_dir is not None, "new_modules_dir must not be empty (query-side GP modules are required)"
gp_cfg = Qwen2_5_VL_GPConfig.from_pretrained(new_modules_dir)
gp_cfg._attn_implementation = "flash_attention_2"
gp_cfg.padding_side = "left"
gp_cfg.use_cache = False
encoder_qry = GPModel.from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=gp_cfg,
)
encoder_qry.load_new_modules(new_modules_dir)
# target encoder: vanilla
from src.model.vlm_backbone.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration as VanillaModel
cfg_tgt = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
cfg_tgt._attn_implementation = "flash_attention_2"
cfg_tgt.use_cache = False
# keep the vision tower on the same attention implementation
if hasattr(cfg_tgt, "vision_config"):
try:
cfg_tgt.vision_config._attn_implementation = "flash_attention_2"
except Exception:
pass
encoder_tgt = VanillaModel.from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=cfg_tgt,
)
# freeze both encoders (when training only the retrieval head / LoRA)
if getattr(model_args, "image_encoder_freeze", True):
for p in encoder_qry.parameters(): p.requires_grad = False
for p in encoder_tgt.parameters(): p.requires_grad = False
# unfreeze only the query-side GP modules (when training the selection head)
if hasattr(encoder_qry, "new_modules_to_be_saved"):
for mod in encoder_qry.new_modules_to_be_saved().values():
if isinstance(mod, torch.nn.Parameter):
mod.requires_grad = True
else:
for p in mod.parameters(): p.requires_grad = True
return cls(encoder_qry=encoder_qry, encoder_tgt=encoder_tgt,
pooling=model_args.pooling, normalize=model_args.normalize,
temperature=model_args.temperature)
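# Usage sketch (illustrative only; the ModelArguments fields shown are assumptions and depend on
# how src.arguments_gp defines them):
#   model_args = ModelArguments(model_name="Qwen/Qwen2-VL-7B-Instruct", pooling="last",
#                               normalize=True, temperature=0.02, lora=False)
#   model = MMEBModel.load(model_args, is_trainable=False)
#   reps = model(qry=qry_batch)                  # {"qry_reps": ..., "tgt_reps": None}
#   loss = model(qry=qry_batch, tgt=tgt_batch)   # scalar contrastive loss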