# code_SAS_VLM2Vec/src/classifier_utils_V3.py
# Provenance: uploaded by MgGladys via the upload-large-folder tool (commit 0a937d7, verified).
import torch
import torch.nn as nn
import logging
logger = logging.getLogger(__name__)
class EarlyExitClassifier(nn.Module):
    """Early-exit decision head.

    Predicts a single logit from per-sample statistical features plus a
    learned modality embedding, optionally augmented with a compressed
    semantic embedding taken from the backbone's hidden states.
    """

    def __init__(self, input_dim=27, hidden_dim=128, embedding_dim=0,
                 num_modalities=2, modality_emb_dim=4, proj_dim=64,
                 dropout=0.2):
        """
        Args:
            input_dim: number of statistical (scalar) features per sample.
            hidden_dim: width of the hidden MLP layer.
            embedding_dim: if > 0, the model also consumes a backbone
                hidden-state vector of this size in ``forward``.
            num_modalities: number of distinct modality indices (default 2).
            modality_emb_dim: width of the learned modality embedding.
            proj_dim: width the semantic embedding is projected down to.
            dropout: dropout probability in the MLP head.
        """
        super().__init__()
        # Improvement 1: input normalization (crucial!).
        # Aligns the score distributions of different tasks so a single
        # decision threshold can be shared across them.
        self.scalar_bn = nn.BatchNorm1d(input_dim)
        # Learned modality embedding.
        self.modality_emb = nn.Embedding(num_modalities, modality_emb_dim)
        # Improvement 2: optional semantic-embedding branch.
        self.use_embedding = embedding_dim > 0
        if self.use_embedding:
            # Compress the high-dimensional embedding (e.g. 2560-d) so it
            # does not dominate the rest of the network.
            self.emb_proj = nn.Sequential(
                nn.Linear(embedding_dim, proj_dim),
                nn.LayerNorm(proj_dim),
                nn.ReLU(),
            )
            # Total = scalar features + modality emb + compressed semantic emb.
            total_input_dim = input_dim + modality_emb_dim + proj_dim
        else:
            total_input_dim = input_dim + modality_emb_dim
        # MLP head.
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # mid-layer BN to stabilize gradients
            nn.ReLU(),
            nn.Dropout(dropout),  # dropout to reduce overfitting
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, scalar_feats, modality_idx, qry_emb=None):
        """
        Args:
            scalar_feats: [B, input_dim] float tensor of statistical features.
            modality_idx: [B] long tensor of modality indices.
            qry_emb: optional [B, embedding_dim] backbone hidden state.
                Required iff the model was built with ``embedding_dim > 0``;
                ignored otherwise.

        Returns:
            [B, 1] tensor of logits.

        Raises:
            ValueError: if the model expects a semantic embedding but
                ``qry_emb`` is None.
        """
        # 1. Normalize the statistical features.
        s_feat = self.scalar_bn(scalar_feats)
        # 2. Modality features.
        m_feat = self.modality_emb(modality_idx)
        features = [s_feat, m_feat]
        # 3. Optional semantic features.
        if self.use_embedding:
            if qry_emb is None:
                raise ValueError("Model initialized with embedding_dim > 0 but qry_emb is None")
            features.append(self.emb_proj(qry_emb))
        # Concatenate all feature groups and score.
        x = torch.cat(features, dim=1)
        return self.mlp(x)