# src/model/model_multi_layer_distill.py
from typing import Dict
import torch
import torch.distributed as dist
from torch import nn, Tensor
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel
from src.arguments_multi_layer import ModelArguments, TrainingArguments
from src.model.processor import (LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, PHI3V, E5_V,
                                 QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION,
                                 INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI,
                                 VLM_IMAGE_TOKENS, backbone2model, get_backbone_name,
                                 print_master)
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration
from transformers import modeling_utils
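# Compatibility shim: some transformers releases ship modeling_utils.ALL_PARALLEL_STYLES
# as missing or None, and model-loading code paths iterate over it, so we
# backfill a sensible list of tensor-parallel styles here.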
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class MMEBModel(nn.Module):
TRANSFORMER_CLS = AutoModelForCausalLM
def __init__(self,
encoder: PreTrainedModel,
pooling: str = 'last',
normalize: bool = False,
temperature: float = 0.02,
):
super().__init__()
self.config = encoder.config
self.encoder = encoder
self.pooling = pooling
self.normalize = normalize
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.is_ddp = dist.is_initialized()
if self.is_ddp:
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.layer_indices = [20, -1]
        self.supervise_layers = [20, -1]  # -1 must denote the last layer
        self.supervise_weights = [0.15, 0.85]  # aligned one-to-one with supervise_layers
@property
def device(self):
        # Obtain the device as robustly as possible
try:
return next(self.encoder.parameters()).device
except StopIteration:
try:
return next(self.parameters()).device
except StopIteration:
return torch.device("cpu")
def _has_image(self, batch_input):
"""
基于输入是否包含像素张量来判断是否含图像。
True:存在 'pixel_values' 且非None且元素数>0;或存在 'images'(部分backbone)
False:否则
"""
B = None
if 'attention_mask' in batch_input:
B = batch_input['attention_mask'].shape[0]
elif 'input_ids' in batch_input:
B = batch_input['input_ids'].shape[0]
has_img = False
if 'pixel_values' in batch_input and batch_input['pixel_values'] is not None:
            # pixel_values may be shaped [B, ...] or [B, 1, ...]
pv = batch_input['pixel_values']
has_img = pv.numel() > 0
if B is None:
B = pv.shape[0]
elif 'images' in batch_input and batch_input['images'] is not None:
            has_img = True  # list/placeholder entries count as images
if B is None:
            # Fallback: treat the batch as image-free
            return torch.zeros(1, dtype=torch.float32, device=self.device)
        val = 1.0 if has_img else 0.0
        return torch.full((B,), fill_value=val, dtype=torch.float32, device=self.device)
@staticmethod
def _masked_mean(loss_vec: Tensor, weight_mask: Tensor) -> Tensor:
denom = torch.clamp(weight_mask.sum(), min=1.0)
return (loss_vec * weight_mask).sum() / denom
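    # Example: _masked_mean(torch.tensor([2., 4.]), torch.tensor([1., 0.]))
    # returns 2.0: only the first element contributes, and the denominator is
    # clamped to >= 1 so an all-zero mask cannot divide by zero.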
def _normalize_layers(self, hs_len: int, layers: list[int]) -> list[int]:
Lmax = hs_len - 1
out = []
for idx in layers:
if idx < 0:
idx = hs_len + idx
idx = max(1, min(idx, Lmax))
out.append(idx)
if (hs_len - 1) not in out:
out.append(hs_len - 1)
return out
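    # Example: with hs_len = 29 (embedding output + 28 transformer blocks),
    # layers = [20, -1] maps to [20, 28]: negative indices wrap around, every
    # index is clamped to [1, hs_len - 1], and the final-layer index is
    # appended whenever it is missing.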
def _encode_multi(self, input):
"""
通用多层编码:返回 [B, K, D],K=len(self.supervise_layers,经规范化且包含最后一层)。
"""
mb = getattr(self, "model_backbone", None)
def norm(x):
return F.normalize(x, p=2, dim=-1) if self.normalize else x
        # Generic branch for backbones that expose hidden_states
if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI]:
out = self.encoder(**input, return_dict=True, output_hidden_states=True)
            hs = out.hidden_states  # tuple of length L + 1 (embedding output + L blocks)
            idxs = self._normalize_layers(len(hs), list(dict.fromkeys(self.supervise_layers)))  # dedupe, preserve order
reps = []
for idx in idxs:
r = self._pooling(hs[idx], input['attention_mask'])
reps.append(norm(r))
return torch.stack(reps, dim=1) # [B, K, D]
        # LLAVA_NEXT: hidden_states are still available; squeeze the extra image dims first
if mb == LLAVA_NEXT:
input = dict(input)
input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
out = self.encoder(**input, return_dict=True, output_hidden_states=True)
hs = out.hidden_states
idxs = self._normalize_layers(len(hs), list(dict.fromkeys(self.supervise_layers)))
reps = []
for idx in idxs:
r = self._pooling(hs[idx], input['attention_mask'])
reps.append(norm(r))
return torch.stack(reps, dim=1)
        # Other backbones without hidden_states: degrade to repeating the last layer
last = self.encode_input(input) # [B, D]
last = norm(last)
K = len(self.supervise_layers)
return torch.stack([last for _ in range(K)], dim=1) # [B, K, D]
def encode_input(self, input, layer_indices=None):
if getattr(self, "model_backbone", None) == INTERNVIDEO2:
if "input_ids" in input.keys():
# text side
text_output = self.encoder.get_text_encoder()(
input["input_ids"],
attention_mask=input["attention_mask"],
return_dict=True,
mode="text",
)
text_embeds = text_output.last_hidden_state
pooled_text_embeds = text_embeds[:, 0]
pooled_output = self.encoder.text_proj(pooled_text_embeds)
pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
return pooled_output
else:
_, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
vfeat = self.encoder.vision_proj(vfeat)
vfeat /= vfeat.norm(dim=-1, keepdim=True)
return vfeat
elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]:
texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] # we are actually passing video queries so this should not happen
images = []
for imgs in input['images']:
# if multi images are given, select the middle frame only
if isinstance(imgs, list):
imgs = imgs[len(imgs) // 2]
assert not isinstance(imgs, list) # make sure we have extracted the middle frame and it is no longer a list
images.append(imgs)
else:
images.append(imgs)
pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
return pooled_output
elif getattr(self, "model_backbone", None) == COLPALI:
pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
return pooled_output
        elif getattr(self, "model_backbone", None) == LLAVA_NEXT:
            input = dict(input)  # copy so the caller's batch is not mutated
            input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
            input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
            outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
            pooled_output = self._pooling(hidden_states, input['attention_mask'])
            return pooled_output
else:
            # Default HF model: hidden_states are supported
out = self.encoder(**input, return_dict=True, output_hidden_states=True)
hs_list = out.hidden_states
if layer_indices is None or isinstance(layer_indices, int):
h = hs_list[-1] if layer_indices is None else hs_list[layer_indices]
reps = self._pooling(h, input['attention_mask'])
return reps
else:
reps_list = []
for idx in layer_indices:
h = hs_list[idx]
r = self._pooling(h, input['attention_mask'])
reps_list.append(r)
reps = torch.stack(reps_list, dim=1) # [B, L, D]
return reps
def _pooling(self, last_hidden_state, attention_mask):
if self.pooling == 'last' or self.pooling == 'eos':
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
batch_size = last_hidden_state.shape[0]
if left_padding:
# Get the vectors at the last position
reps = last_hidden_state[torch.arange(batch_size), -1, :]
else:
                # Index of the last non-padded (attended) token per sequence
                eos_indices = attention_mask.sum(dim=1) - 1
                # Gather the hidden state at that position for each row
reps = last_hidden_state[
torch.arange(batch_size, device=last_hidden_state.device), eos_indices]
else:
raise NotImplementedError
if self.normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
@classmethod
def build(cls, model_args: ModelArguments, **kwargs):
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
variant = getattr(config, "backbone_variant", None)
if variant == "layerprune":
model_backbone = "QWEN2_VL_LayerPrune"
else:
model_backbone = get_backbone_name(hf_config=config)
print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')
# Loading the base model
if model_backbone == PHI3V:
config._attn_implementation = "eager"
config.padding_side = "right"
config.use_cache = False
base_model = Phi3VForCausalLM.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone == LLAVA_NEXT:
config.use_cache = False
config.padding_side = "left"
base_model = LlavaNextForConditionalGeneration.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL, "QWEN2_VL_LayerPrune"]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
from .utils import parse_layer_type
lm_qwen_layer = 28
vis_qwen_layer = 32
lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
lm_skip_layer=lm_skip_layer,
vis_skip_layer=vis_skip_layer,
)
else:
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name, **kwargs, config=config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
trust_remote_code=True)
if model_args.lora:
            print_master(f'Initializing LoRA adapter on {model_args.model_name}')
lora_config = LoraConfig(
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
target_modules=model_args.lora_target_modules.split(','),
lora_dropout=model_args.lora_dropout,
init_lora_weights="gaussian",
use_dora=True,
inference_mode=False
)
lora_model = get_peft_model(base_model, lora_config)
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
        # Multi-layer supervision settings (attached before returning the model)
        def _parse_list(val, tp=float):
            if val is None:
                return None
            if isinstance(val, (list, tuple)):
                return [tp(x) for x in val]
            s = str(val).strip()
            if s == "":
                return None
            return [tp(v.strip()) for v in s.split(",") if v.strip() != ""]
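        # e.g. _parse_list("20,-1", tp=int) -> [20, -1]; _parse_list(None) -> None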
layers = _parse_list(getattr(model_args, "supervise_layers", None), tp=int)
weights = _parse_list(getattr(model_args, "supervise_weights", None), tp=float)
        if layers is None:
            # Fall back to the legacy two-layer setup
            layers = [getattr(model_args, 'dual_layer_idx', 20), -1]
        if -1 not in layers:
            layers = list(layers) + [-1]  # the last layer must always be supervised
        if weights is None or len(weights) != len(layers):
            # Reasonable default when weights are missing or mismatched: the
            # last layer dominates, mirroring the constructor default [0.15, 0.85].
            K = len(layers)
            if K > 1:
                base = [0.15 / (K - 1)] * (K - 1)
                weights = base + [1.0 - sum(base)]
            else:
                weights = [1.0]
        # Normalize to a non-negative distribution summing to 1
        s = sum(max(0.0, w) for w in weights)
        weights = [max(0.0, w) / max(s, 1e-8) for w in weights]
setattr(model, 'supervise_layers', layers)
setattr(model, 'supervise_weights', weights)
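        # e.g. supervise_layers="20,-1" with supervise_weights="0.15,0.85"
        # reproduces the constructor defaults.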
        # Gating and distillation hyper-parameters
setattr(model, 'gate_by_image', getattr(model_args, 'gate_by_image', True))
setattr(model, 'misalign_mid_ce', float(getattr(model_args, 'misalign_mid_ce', 0.0)))
setattr(model, 'distill_beta', float(getattr(model_args, 'distill_beta', 1.0)))
setattr(model, 'distill_on_aligned', bool(getattr(model_args, 'distill_on_aligned', False)))
return model
@classmethod
def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
# Loading the base model
model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config._attn_implementation = "flash_attention_2"
config.vision_config._attn_implementation = "flash_attention_2"
base_model = backbone2model[model_args.model_backbone].from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=config
)
elif model_args.model_backbone == PHI3V:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
config.padding_side = "right"
base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
torch_dtype=torch.bfloat16, trust_remote_code=True)
base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            internvideo2_path = "src/model/vlm_backbone/internvideo2/"
            print_master(f'Loading backbone [{model_args.model_backbone}] from {internvideo2_path}')
            config = AutoConfig.from_pretrained(internvideo2_path, trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                internvideo2_path, config=config, trust_remote_code=True)
elif model_args.model_backbone == GME:
base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA:
base_model = LamRAQwen2VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA_QWEN2_5:
base_model = LamRAQwen25VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == COLPALI:
base_model = ColPali.from_pretrained(model_args.model_name)
setattr(base_model, 'config', config)
else:
# Loading external base model from HF
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_name_or_path, **kwargs, config=config,
torch_dtype=torch.bfloat16,
trust_remote_code=True)
# Building the model on top of the base
if model_args.lora:
print_master(f'Loading LoRA from {model_name_or_path}')
lora_config = LoraConfig.from_pretrained(model_name_or_path)
lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
if not is_trainable:
lora_model = lora_model.merge_and_unload()
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
model.model_backbone = model_args.model_backbone
return model
def save(self, output_dir: str):
self.encoder.save_pretrained(output_dir)
def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
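        """
        GradCache one-sided call: encode the given side and return its
        multi-layer representations [B, K, D]. Two-sided call: compute the
        training loss, i.e. contrastive CE on the last layer plus, for each
        intermediate layer k,
            w[k] * (gated CE + distill_beta * gated (1 - cos) self-distillation),
        where the gates depend on whether query/target image presence matches.
        """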
        # GradCache: only one side is given -> return its multi-layer representations
if qry is not None and tgt is None:
qry_reps = self._encode_multi(qry) # [B, K, D]
return {"qry_reps": qry_reps, "tgt_reps": None}
if tgt is not None and qry is None:
tgt_reps = self._encode_multi(tgt) # [B, K, D]
return {"qry_reps": None, "tgt_reps": tgt_reps}
        # Not GradCache: both sides given; compute the layer-wise paired weighted CE directly
q_multi = self._encode_multi(qry) # [B, K, D]
p_multi = self._encode_multi(tgt) # [B, K, D]
# DDP gather
if self.is_ddp:
q_multi_all = self._dist_gather_tensor(q_multi) # [B*, K, D]
p_multi_all = self._dist_gather_tensor(p_multi) # [B*, K, D]
else:
q_multi_all, p_multi_all = q_multi, p_multi
Bglob, K, D = q_multi_all.shape
assert p_multi_all.shape[:2] == (Bglob, K), f"Shape mismatch: q {q_multi_all.shape}, p {p_multi_all.shape}"
target = torch.arange(Bglob, device=q_multi_all.device, dtype=torch.long)
w = torch.tensor(self.supervise_weights, dtype=torch.float32, device=q_multi_all.device)
w = torch.clamp(w, min=0)
w = w / max(w.sum().item(), 1e-8)
        # Alignment gate: a pair counts as "aligned" when both sides contain
        # images, or both are text-only. Compute on the local rank first, then
        # all_gather so the masks line up with q_multi_all / p_multi_all.
q_has_img_local = self._has_image(qry) # [B_local]
p_has_img_local = self._has_image(tgt) # [B_local]
if self.is_ddp:
q_has_img = self._dist_gather_tensor(q_has_img_local)
p_has_img = self._dist_gather_tensor(p_has_img_local)
else:
q_has_img, p_has_img = q_has_img_local, p_has_img_local
aligned_mask = (q_has_img == p_has_img).float() # [Bglob]
misaligned_mask = 1.0 - aligned_mask
loss = 0.0
last_idx = K - 1
        # 1) Last layer: always trained with the contrastive loss
logits_last = torch.matmul(q_multi_all[:, last_idx, :], p_multi_all[:, last_idx, :].transpose(0, 1)) / self.temperature
loss_last = self.cross_entropy(logits_last, target)
loss = loss + w[last_idx] * loss_last
        # 2) Intermediate layers: aligned -> contrastive; misaligned -> self-distillation (plus an optional small CE term)
for k in range(0, last_idx):
            # 2.1 Intermediate-layer contrastive loss (per-sample masked mean)
logits_k = torch.matmul(q_multi_all[:, k, :], p_multi_all[:, k, :].transpose(0, 1)) / self.temperature
loss_vec = torch.nn.functional.cross_entropy(logits_k, target, reduction='none') # [Bglob]
            if getattr(self, 'gate_by_image', True):
                # Aligned samples get weight 1; misaligned samples get weight
                # misalign_mid_ce (0 by default).
                weight_mask = aligned_mask + self.misalign_mid_ce * misaligned_mask
            else:
                # No gating: every sample gets weight 1.
                weight_mask = torch.ones_like(aligned_mask)
mid_ce = self._masked_mean(loss_vec, weight_mask)
            # 2.2 Intermediate-layer self-distillation (per-sample, stop-grad teacher)
do_distill = (self.distill_beta is not None) and (self.distill_beta > 0.0)
if do_distill:
q_teacher = q_multi_all[:, last_idx, :].detach()
p_teacher = p_multi_all[:, last_idx, :].detach()
                # cosine similarity -> (1 - cos) distance
dist_q = 1.0 - torch.nn.functional.cosine_similarity(q_multi_all[:, k, :], q_teacher, dim=-1) # [Bglob]
dist_p = 1.0 - torch.nn.functional.cosine_similarity(p_multi_all[:, k, :], p_teacher, dim=-1) # [Bglob]
dist_vec = dist_q + dist_p # [Bglob]
                if getattr(self, 'gate_by_image', True):
                    if getattr(self, 'distill_on_aligned', False):
                        dist_mask = torch.ones_like(aligned_mask)  # distill aligned and misaligned samples alike
                    else:
                        dist_mask = misaligned_mask  # distill misaligned samples only
                else:
                    dist_mask = torch.ones_like(aligned_mask)
mid_distill = self._masked_mean(dist_vec, dist_mask)
mid_total = mid_ce + self.distill_beta * mid_distill
else:
mid_total = mid_ce
loss = loss + w[k] * mid_total
if self.is_ddp:
loss = loss * self.world_size
return loss
def _dist_gather_tensor(self, t: Tensor):
t = t.contiguous()
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(all_tensors, t)
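        # all_gather returns tensors detached from autograd; re-inserting the
        # local tensor keeps this rank's shard differentiable.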
all_tensors[self.process_rank] = t
all_tensors = torch.cat(all_tensors, dim=0)
return all_tensors
def compute_similarity(self, q_reps, p_reps):
return torch.matmul(q_reps, p_reps.transpose(0, 1))
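

# Minimal usage sketch (hypothetical argument values; the real fields are
# defined in src.arguments_multi_layer.ModelArguments):
#
#   model_args = ModelArguments(model_name="Qwen/Qwen2-VL-2B-Instruct",
#                               pooling="last", normalize=True, temperature=0.02)
#   model = MMEBModel.build(model_args)
#   loss = model(qry=qry_batch, tgt=tgt_batch)   # both sides -> scalar loss
#   reps = model(qry=qry_batch)["qry_reps"]      # one side  -> [B, K, D]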