code_SAS_VLM2Vec/src/classifier_utils_V5.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class EarlyExitClassifier(nn.Module):
"""
    V5 classifier: integrates lightweight TTA (LayerNorm) and a log1p feature transform
"""
def __init__(self, input_dim=27, hidden_dim=128, embedding_dim=0, dropout_prob=0.2):
super().__init__()
        # =====================================================================
        # [Change 1: Core TTA mechanism - LayerNorm]
        # LayerNorm is used in place of BatchNorm.
        # At inference time, LayerNorm normalizes each sample using statistics
        # computed from that sample alone, so it adapts to differences in
        # feature distributions across datasets (i.e., Test-Time Adaptation).
        # =====================================================================
self.scalar_ln = nn.LayerNorm(input_dim)
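        # What self.scalar_ln computes per sample (standard PyTorch LayerNorm):
        #     (x - x.mean()) / sqrt(x.var(unbiased=False) + eps) * weight + bias
        # over the last dimension, so no running statistics from the training
        # distribution are needed at test time.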
self.modality_emb = nn.Embedding(2, 4)
self.use_embedding = embedding_dim > 0
if self.use_embedding:
            # Semantic-feature projection layer
self.emb_proj = nn.Sequential(
nn.Linear(embedding_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),  # use LayerNorm here as well, for consistency
nn.ReLU()
)
            # scalar (after LayerNorm) + modality (4) + embedding (hidden_dim // 2)
total_input_dim = input_dim + 4 + (hidden_dim // 2)
else:
total_input_dim = input_dim + 4
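        # With the default sizes (input_dim=27, hidden_dim=128) this gives
        # 27 + 4 + 64 = 95 input features when embeddings are used,
        # and 27 + 4 = 31 otherwise.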
        # Main MLP head
self.mlp = nn.Sequential(
nn.Linear(total_input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # keep using LayerNorm
nn.ReLU(),
nn.Dropout(dropout_prob),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout_prob),
nn.Linear(hidden_dim // 2, 1),
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None: nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, mean=0, std=0.02)
    def forward(self, scalar_feats, modality_idx, qry_emb=None):
        # scalar_feats: [B, input_dim] (should be float32)
        # =====================================================================
        # [Change 2: Feature preprocessing - log1p transform]
        # Apply a log transform to skewed features such as margin and entropy
        # to stretch out the resolution of the low-value range.
        # Keep the sign and apply log1p to the absolute value.
        scalar_feats_log = torch.sign(scalar_feats) * torch.log1p(torch.abs(scalar_feats))
        # =====================================================================
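        # Worked example of the signed log1p (values are illustrative only):
        #   scalar_feats = [-5.0, 0.0, 0.2] -> [-log1p(5.0), 0.0, log1p(0.2)]
        #                                    ≈ [-1.792, 0.000, 0.182]
        # Small magnitudes are nearly unchanged while large outliers are damped,
        # so the subsequent LayerNorm statistics are not dominated by extremes.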
        # 1. Scalar features (apply the TTA LayerNorm)
        s_feat = self.scalar_ln(scalar_feats_log)
        # 2. Modality feature
        m_feat = self.modality_emb(modality_idx)  # [B, 4]
features = [s_feat, m_feat]
        # 3. Semantic features (if enabled)
        if self.use_embedding:
            if qry_emb is None:
                raise ValueError("Classifier was initialized with embedding_dim > 0 but forward() received qry_emb=None")
            # Make sure the input is float32
if qry_emb.dtype != torch.float32:
qry_emb = qry_emb.float()
e_feat = self.emb_proj(qry_emb)
features.append(e_feat)
        # Concatenate all feature groups
        x = torch.cat(features, dim=1)
        # The MLP head outputs the logits
        logits = self.mlp(x)  # [B, 1]
return logits
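

# Minimal usage sketch (not part of the original file). It assumes the classifier
# is fed a [B, input_dim] float32 scalar-feature tensor, a [B] long tensor of
# modality ids in {0, 1}, and an optional [B, embedding_dim] query embedding;
# the batch size and embedding_dim below are made-up illustrative values.
if __name__ == "__main__":
    clf = EarlyExitClassifier(input_dim=27, hidden_dim=128, embedding_dim=768, dropout_prob=0.2)
    clf.eval()  # LayerNorm still normalizes per sample at inference (the TTA behaviour)
    scalar_feats = torch.randn(4, 27)
    modality_idx = torch.randint(0, 2, (4,))
    qry_emb = torch.randn(4, 768)
    with torch.no_grad():
        logits = clf(scalar_feats, modality_idx, qry_emb)  # [4, 1]
        probs = torch.sigmoid(logits)  # assuming the single logit is a binary exit score
    print(logits.shape, probs.squeeze(1))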