from typing import Dict

import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel

from src.arguments_gp import ModelArguments, TrainingArguments
from src.model.vlm_backbone.qwen2_5_vl_gp.configuration import Qwen2_5_VL_GPConfig
from src.model.vlm_backbone.qwen2_vl_gp.configuration import Qwen2VL_GPConfig
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_GP, QWEN2_5_VL_GP, PHI3V, E5_V,
    QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, INTERNVIDEO2,
    GME, LamRA, LamRA_QWEN2_5, COLPALI, VLM_IMAGE_TOKENS,
    backbone2model, get_backbone_name, print_master,
)

from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration

from transformers import modeling_utils

# Compatibility shim: some transformers versions leave ALL_PARALLEL_STYLES unset or None,
# which breaks validation during model loading when the attribute is iterated.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]


class MMEBModel(nn.Module):
    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'last',
                 normalize: bool = False,
                 temperature: float = 0.02,
                 ):
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

        # Outputs cached from the most recent GP forward pass (token-selection masks etc.),
        # so callers can retrieve them after encode_input()/forward().
        self._last_image_token_bool_masks = None
        self._last_image_token_mask_logits = None
        self._last_le_loss = None

    def encode_input(self, input):
        self._last_image_token_bool_masks = None
        if getattr(self, "model_backbone", None) == INTERNVIDEO2:
            if "input_ids" in input.keys():
                # Text side: pool the [CLS] token of the text encoder and project it.
                text_output = self.encoder.get_text_encoder()(
                    input["input_ids"],
                    attention_mask=input["attention_mask"],
                    return_dict=True,
                    mode="text",
                )
                text_embeds = text_output.last_hidden_state
                pooled_text_embeds = text_embeds[:, 0]
                pooled_output = self.encoder.text_proj(pooled_text_embeds)
                pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
                return pooled_output
            else:
                # Vision side: encode frames and project into the shared embedding space.
                _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
                vfeat = self.encoder.vision_proj(vfeat)
                vfeat /= vfeat.norm(dim=-1, keepdim=True)
                return vfeat
        elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]:
            # These baselines fuse text and image internally; strip the image placeholder
            # token from the text and pick the middle frame when a list of frames is given.
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]
            images = []
            for imgs in input['images']:
                if isinstance(imgs, list):
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list)
                    images.append(imgs)
                else:
                    images.append(imgs)
            pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
            return pooled_output
        elif getattr(self, "model_backbone", None) == COLPALI:
            pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
            return pooled_output
        elif getattr(self, "model_backbone", None) == LLAVA_NEXT:
            input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
            input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
            hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
            hidden_states = hidden_states.hidden_states[-1]
            pooled_output = self._pooling(hidden_states, input['attention_mask'])
            return pooled_output
        elif getattr(self, "model_backbone", None) in {QWEN2_VL_GP, QWEN2_5_VL_GP}:
            # GP backbones run visual-token selection inside the forward pass and return
            # the selection masks alongside the hidden states.
            outputs = self.encoder(
                **input,
                return_dict=True,
                output_hidden_states=True,
                do_selection=True,
                delay_selection=False,
                use_cache=False,
                use_ref_masks=False,
            )
            self._last_image_token_bool_masks = getattr(outputs, "image_token_bool_masks", None)
            self._last_image_token_mask_logits = getattr(outputs, "image_token_mask_logits", None)
            self._last_le_loss = getattr(outputs, "le_loss", None)

            hs = outputs.hidden_states
            last_hidden = hs[-1] if isinstance(hs, (list, tuple)) else hs
            # Token selection may shrink the sequence, so prefer the attention mask returned
            # by the encoder and fall back to the input mask.
            attn_mask = getattr(outputs, "attention_mask", None)
            if attn_mask is None:
                attn_mask = input["attention_mask"]
            pooled_output = self._pooling(last_hidden, attn_mask)
            return pooled_output
        else:
            outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)

            img_masks = getattr(outputs, "image_token_bool_masks", None)
            if img_masks is None:
                try:
                    img_masks = outputs.get("image_token_bool_masks", None)
                except Exception:
                    pass
            self._last_image_token_bool_masks = img_masks

            hs = outputs.hidden_states
            if isinstance(hs, (list, tuple)):
                last_hidden = hs[-1]
            else:
                last_hidden = hs
            attn = getattr(outputs, "attention_mask", None)
            if attn is None:
                attn = input["attention_mask"]

            pooled_output = self._pooling(last_hidden, attn)
            return pooled_output

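    # Note: for GP backbones the selection outputs (image_token_bool_masks, mask logits,
    # le_loss) are cached on the instance by encode_input(); forward() surfaces the boolean
    # masks when only one side (query or target) is encoded.
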
    def _pooling(self, last_hidden_state, attention_mask):
        if self.pooling == 'last' or self.pooling == 'eos':
            left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
            batch_size = last_hidden_state.shape[0]
            if left_padding:
                # With left padding, every sequence ends at the last position.
                reps = last_hidden_state[torch.arange(batch_size), -1, :]
            else:
                # With right padding, the last real token sits at (number of real tokens - 1).
                eos_indices = attention_mask.sum(dim=1) - 1
                reps = last_hidden_state[
                    torch.arange(batch_size, device=last_hidden_state.device), eos_indices]
        else:
            raise NotImplementedError
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

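    # Illustration of the padding check above (made-up masks, not from the code):
    #   attention_mask = [[0, 1, 1, 1]]  -> every row ends in 1, so left padding; take position -1
    #   attention_mask = [[1, 1, 1, 0]]  -> right padding; eos_indices = sum(mask) - 1 = 2
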
    @classmethod
    def build(cls, model_args: ModelArguments, **kwargs):
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        model_backbone = get_backbone_name(hf_config=config)
        print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')

        if model_backbone == PHI3V:
            config._attn_implementation = "eager"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone == LLAVA_NEXT:
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL_GP, QWEN2_5_VL_GP]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False

            from .utils import parse_layer_type
            lm_qwen_layer = 28
            vis_qwen_layer = 32
            lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
            vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)

            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                lm_skip_layer=lm_skip_layer,
                vis_skip_layer=vis_skip_layer,
            )
        else:
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        if model_args.lora:
            print_master(f'Adding a LoRA adapter to the base model loaded from {model_args.model_name}')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=model_args.lora_target_modules.split(','),
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            lora_model = get_peft_model(base_model, lora_config)
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )

        return model

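    # Minimal usage sketch for build() (assumes `model_args` is a populated
    # src.arguments_gp.ModelArguments, e.g. from the training entrypoint):
    #   model = MMEBModel.build(model_args)
    #   loss = model(qry=qry_batch, tgt=tgt_batch)
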
    @classmethod
    def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
        # Resolve the checkpoint to load and the backbone family it belongs to.
        model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
            model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
            setattr(model_args, 'model_backbone', model_backbone)

        print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
        if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config._attn_implementation = "flash_attention_2"
            config.vision_config._attn_implementation = "flash_attention_2"
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                config=config
            )
        elif model_args.model_backbone in {QWEN2_5_VL_GP}:
            new_modules_dir = getattr(model_args, "new_modules_dir", None) or kwargs.get("new_modules_dir", None)

            if new_modules_dir:
                import os
                from huggingface_hub import hf_hub_download

                def _load_gp_config_from_dir_or_repo(repo_or_dir: str) -> Qwen2_5_VL_GPConfig:
                    # Accept either a local directory or a Hugging Face Hub repo id.
                    if os.path.isdir(repo_or_dir):
                        cfg_path = os.path.join(repo_or_dir, "config.json")
                        if not os.path.exists(cfg_path):
                            raise FileNotFoundError(f"config.json not found in {repo_or_dir}")
                    else:
                        cfg_path = hf_hub_download(repo_id=repo_or_dir, filename="config.json")
                    return Qwen2_5_VL_GPConfig.from_json_file(cfg_path)

                gp_cfg = _load_gp_config_from_dir_or_repo(new_modules_dir)
                gp_cfg._attn_implementation = "flash_attention_2"
                gp_cfg.padding_side = "left"
                gp_cfg.use_cache = False
                try:
                    gp_cfg.vision_config._attn_implementation = "flash_attention_2"
                except Exception:
                    if isinstance(gp_cfg.vision_config, dict):
                        gp_cfg.vision_config["_attn_implementation"] = "flash_attention_2"

                base_model = backbone2model[model_args.model_backbone].from_pretrained(
                    model_args.model_name,
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                    config=gp_cfg,
                )

                # Attach the trained GP modules (attn_fuser etc.) on top of the base weights.
                base_model.load_new_modules(new_modules_dir)

                try:
                    print_master(f"[GP] attn_fuser = {type(base_model.attn_fuser).__name__}")
                except Exception:
                    print_master("[GP] attn_fuser not found (unexpected).")

            else:
                # No trained GP modules provided: derive a GP config from the base model
                # config; the extra GP modules are initialized fresh.
                base_cfg = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
                gp_cfg = Qwen2_5_VL_GPConfig(**base_cfg.to_dict())
                gp_cfg._attn_implementation = "flash_attention_2"
                try:
                    gp_cfg.vision_config._attn_implementation = "flash_attention_2"
                except Exception:
                    if isinstance(gp_cfg.vision_config, dict):
                        gp_cfg.vision_config["_attn_implementation"] = "flash_attention_2"
                gp_cfg.padding_side = "left"
                gp_cfg.use_cache = False

                # Default selection layers derived from the base model depth.
                L = getattr(base_cfg, "num_hidden_layers", None) or 28
                gp_cfg.reduce_layer = min(max(L - 2, 1), L - 1)
                gp_cfg.selected_layers = [max(gp_cfg.reduce_layer - 2, 0), gp_cfg.reduce_layer]

                base_model = backbone2model[model_args.model_backbone].from_pretrained(
                    model_args.model_name,
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                    config=gp_cfg,
                )
        elif model_args.model_backbone == QWEN2_VL_GP:
            import os
            from huggingface_hub import snapshot_download
            from transformers.utils import is_flash_attn_2_available

            new_modules_dir = getattr(model_args, "new_modules_dir", None) or kwargs.get("new_modules_dir", None)
            if not new_modules_dir:
                raise ValueError("new_modules_dir is required for QWEN2_VL_GP")

            # Resolve the GP modules to a local directory (download from the Hub if needed).
            if os.path.isdir(new_modules_dir):
                local_new_modules_dir = new_modules_dir
            else:
                local_new_modules_dir = snapshot_download(repo_id=new_modules_dir)

            gp_cfg_path = os.path.join(local_new_modules_dir, "config.json")
            if not os.path.exists(gp_cfg_path):
                raise FileNotFoundError(f"config.json not found in {local_new_modules_dir}")
            gp_cfg = Qwen2VL_GPConfig.from_json_file(gp_cfg_path)

            base_cfg = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)

            # The GP modules must match the base model's hidden size.
            if hasattr(gp_cfg, "hidden_size") and hasattr(base_cfg, "hidden_size"):
                if gp_cfg.hidden_size != base_cfg.hidden_size:
                    raise ValueError(
                        f"[Mismatch] base_model hidden_size={base_cfg.hidden_size} "
                        f"!= gp_modules hidden_size={gp_cfg.hidden_size}. "
                        f"Please switch to a matching base model or the corresponding new_modules_dir."
                    )

            base_cfg._attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
            base_cfg.padding_side = "left"
            base_cfg.use_cache = False
            if hasattr(base_cfg, "vision_config"):
                vc = base_cfg.vision_config
                if isinstance(vc, dict):
                    vc["_attn_implementation"] = base_cfg._attn_implementation
                else:
                    setattr(vc, "_attn_implementation", base_cfg._attn_implementation)

            ModelClass = backbone2model[model_args.model_backbone]
            base_model = ModelClass.from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                config=base_cfg,
                trust_remote_code=True,
            )

            # Attach the trained GP modules on top of the base weights.
            base_model.load_new_modules(local_new_modules_dir)

            try:
                print_master(f"[GP] attn_fuser = {type(base_model.attn_fuser).__name__}")
            except Exception:
                print_master("[GP] attn_fuser not found (unexpected).")

        elif model_args.model_backbone == PHI3V:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
                                                          torch_dtype=torch.bfloat16, trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            print_master(f'Loading backbone [{model_args.model_backbone}] from src/model/vlm_backbone/internvideo2/')
            config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/",
                                                trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                "src/model/vlm_backbone/internvideo2/", config=config, trust_remote_code=True)
        elif model_args.model_backbone == GME:
            base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA:
            base_model = LamRAQwen2VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA_QWEN2_5:
            base_model = LamRAQwen25VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == COLPALI:
            base_model = ColPali.from_pretrained(model_args.model_name)
            setattr(base_model, 'config', config)
        else:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_name_or_path, **kwargs, config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        if model_args.lora:
            print_master(f'Loading LoRA from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)
            lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
            lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
            if not is_trainable:
                lora_model = lora_model.merge_and_unload()
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )

        model.model_backbone = model_args.model_backbone
        return model

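    # Usage sketch for load() (hedged; assumes the relevant ModelArguments fields are set):
    #   For a plain backbone:  model = MMEBModel.load(model_args, is_trainable=False)
    #   For a GP backbone, point model_args.new_modules_dir at the saved GP modules
    #   (a local directory or Hub repo id whose config.json matches the base model).
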
    def save(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)

    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
        qry_reps = self.encode_input(qry) if qry else None
        tgt_reps = self.encode_input(tgt) if tgt else None

        # Inference path: if only one side is given, return its representations
        # (plus any cached GP token-selection masks) instead of a loss.
        if qry_reps is None or tgt_reps is None:
            out = {"qry_reps": qry_reps, "tgt_reps": tgt_reps}
            img_masks = getattr(self, "_last_image_token_bool_masks", None)
            if img_masks is not None:
                out["image_token_bool_masks"] = img_masks
            return out

        # Training path: gather representations across ranks and compute the
        # in-batch contrastive loss.
        if self.is_ddp:
            all_qry_reps = self._dist_gather_tensor(qry_reps)
            all_tgt_reps = self._dist_gather_tensor(tgt_reps)
        else:
            all_qry_reps = qry_reps
            all_tgt_reps = tgt_reps

        scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
        scores = scores.view(all_qry_reps.size(0), -1)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (all_qry_reps.size(0) // all_tgt_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
        if self.is_ddp:
            # DDP averages gradients over ranks; scale the loss back up so the
            # gradient matches that of the full gathered batch.
            loss = loss * self.world_size

        return loss

    def _dist_gather_tensor(self, t: Tensor):
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        # all_gather returns tensors without gradients; put the local shard back so
        # its autograd graph is preserved on this rank.
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        return torch.matmul(q_reps, p_reps.transpose(0, 1))

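    # With normalize=True the pooled representations are unit-norm, so compute_similarity()
    # yields cosine similarities and the loss above is temperature-scaled (T=0.02 by default).

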
class AsymMMEBModel(nn.Module):
    """Asymmetric dual encoder: a GP query encoder paired with a vanilla target encoder."""

    def __init__(self, encoder_qry: PreTrainedModel, encoder_tgt: PreTrainedModel,
                 pooling: str = 'last', normalize: bool = False, temperature: float = 0.02):
        super().__init__()
        self.encoder_qry = encoder_qry
        self.encoder_tgt = encoder_tgt
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        self.world_size = dist.get_world_size() if self.is_ddp else 1

        # Whether the query encoder should run GP token selection (off by default).
        self.gp_do_selection_qry = False

        # Outputs cached from the most recent query-side forward pass.
        self._last_image_token_bool_masks = None
        self._last_image_token_mask_logits = None
        self._last_le_loss = None

    def _pool(self, last_hidden_state, attention_mask):
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        B = last_hidden_state.shape[0]
        if left_padding:
            reps = last_hidden_state[torch.arange(B), -1, :]
        else:
            idx = attention_mask.sum(dim=1) - 1
            reps = last_hidden_state[torch.arange(B, device=last_hidden_state.device), idx]
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    def encode_qry(self, inp: Dict[str, torch.Tensor]):
        # Query side: optionally run GP token selection inside the encoder.
        do_sel = self.gp_do_selection_qry
        out = self.encoder_qry(
            **inp,
            return_dict=True,
            output_hidden_states=True,
            do_selection=do_sel,
            delay_selection=not do_sel,
            use_cache=False,
            use_ref_masks=False,
        )
        self._last_image_token_bool_masks = getattr(out, "image_token_bool_masks", None)
        self._last_image_token_mask_logits = getattr(out, "image_token_mask_logits", None)
        self._last_le_loss = getattr(out, "le_loss", None)

        hs = out.hidden_states
        last = hs[-1] if isinstance(hs, (list, tuple)) else hs
        # Avoid `or` on tensors (ambiguous truth value); fall back explicitly.
        attn = getattr(out, "attention_mask", None)
        if attn is None:
            attn = inp["attention_mask"]
        rep = self._pool(last, attn)
        return rep

    def encode_tgt(self, inp: Dict[str, torch.Tensor]):
        # Target side: plain forward pass through the vanilla encoder.
        out = self.encoder_tgt(
            **inp,
            return_dict=True,
            output_hidden_states=True,
            use_cache=False,
        )
        hs = out.hidden_states
        last = hs[-1] if isinstance(hs, (list, tuple)) else hs
        attn = getattr(out, "attention_mask", None)
        if attn is None:
            attn = inp["attention_mask"]
        rep = self._pool(last, attn)
        return rep

    def forward(self, qry: Dict[str, torch.Tensor] = None, tgt: Dict[str, torch.Tensor] = None):
        qry_reps = self.encode_qry(qry) if qry else None
        tgt_reps = self.encode_tgt(tgt) if tgt else None
        if qry_reps is None or tgt_reps is None:
            return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}

        if self.is_ddp:
            qry_all = self._gather(qry_reps)
            tgt_all = self._gather(tgt_reps)
        else:
            qry_all, tgt_all = qry_reps, tgt_reps

        scores = torch.matmul(qry_all, tgt_all.transpose(0, 1)).view(qry_all.size(0), -1)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (qry_all.size(0) // tgt_all.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
        if self.is_ddp:
            loss = loss * self.world_size
        return loss

    def _gather(self, t: torch.Tensor):
        t = t.contiguous()
        outs = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(outs, t)
        outs[dist.get_rank()] = t
        return torch.cat(outs, dim=0)

    @classmethod
    def load_asym(cls, model_args, processor=None, processor_tgt=None, is_trainable=True):
        # Query side: Qwen2.5-VL GP model with the trained token-selection modules.
        from src.model.vlm_backbone.qwen2_5_vl_gp.configuration import Qwen2_5_VL_GPConfig
        from src.model.vlm_backbone.qwen2_5_vl_gp.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration_GP as GPModel
        from transformers import AutoConfig
        new_modules_dir = getattr(model_args, "new_modules_dir", None)
        assert new_modules_dir is not None, "new_modules_dir must not be empty (GP modules for the query side)"

        gp_cfg = Qwen2_5_VL_GPConfig.from_pretrained(new_modules_dir)
        gp_cfg._attn_implementation = "flash_attention_2"
        gp_cfg.padding_side = "left"
        gp_cfg.use_cache = False
        encoder_qry = GPModel.from_pretrained(
            model_args.model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            config=gp_cfg,
        )
        encoder_qry.load_new_modules(new_modules_dir)

        # Target side: vanilla Qwen2.5-VL encoder without token selection.
        from src.model.vlm_backbone.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration as VanillaModel
        cfg_tgt = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        cfg_tgt._attn_implementation = "flash_attention_2"
        cfg_tgt.use_cache = False
        if hasattr(cfg_tgt, "vision_config"):
            try:
                cfg_tgt.vision_config._attn_implementation = "flash_attention_2"
            except Exception:
                pass
        encoder_tgt = VanillaModel.from_pretrained(
            model_args.model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            config=cfg_tgt,
        )

        # Freeze both backbones when requested.
        if getattr(model_args, "image_encoder_freeze", True):
            for p in encoder_qry.parameters():
                p.requires_grad = False
            for p in encoder_tgt.parameters():
                p.requires_grad = False

        # Keep the newly added GP modules trainable.
        if hasattr(encoder_qry, "new_modules_to_be_saved"):
            for mod in encoder_qry.new_modules_to_be_saved().values():
                if isinstance(mod, torch.nn.Parameter):
                    mod.requires_grad = True
                else:
                    for p in mod.parameters():
                        p.requires_grad = True

        return cls(encoder_qry=encoder_qry, encoder_tgt=encoder_tgt,
                   pooling=model_args.pooling, normalize=model_args.normalize,
                   temperature=model_args.temperature)
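
    # Usage sketch for load_asym() (hedged; assumes model_args provides model_name,
    # new_modules_dir, pooling, normalize and temperature):
    #   model = AsymMMEBModel.load_asym(model_args)
    #   loss = model(qry=qry_batch, tgt=tgt_batch)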