import torch
import torch.nn as nn


class EarlyExitClassifier(nn.Module):
    """
    V5 classifier: integrates lightweight TTA (LayerNorm) and a signed log1p
    feature transform.
    """

    def __init__(self, input_dim=27, hidden_dim=128, embedding_dim=0, dropout_prob=0.2):
        super().__init__()
        # =====================================================================
        # [Change 1: TTA core - LayerNorm]
        # Use LayerNorm instead of BatchNorm.
        # At inference time, LayerNorm normalizes each sample independently
        # with statistics computed from that sample alone, so it adapts to
        # feature-distribution shifts across datasets (i.e., Test-Time
        # Adaptation).
        # =====================================================================
        self.scalar_ln = nn.LayerNorm(input_dim)
        self.modality_emb = nn.Embedding(2, 4)
        self.use_embedding = embedding_dim > 0
        if self.use_embedding:
            # Projection layer for the semantic (query-embedding) features
            self.emb_proj = nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),  # LN here as well, for consistency
                nn.ReLU()
            )
            # scalar (after LN) + modality (4) + embedding (hidden_dim // 2)
            total_input_dim = input_dim + 4 + (hidden_dim // 2)
        else:
            total_input_dim = input_dim + 4

        # Main MLP
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # keep using LN
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim // 2, 1),
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=0.02)

    def forward(self, scalar_feats, modality_idx, qry_emb=None):
        # scalar_feats: [B, input_dim] (should be float32)
        # =====================================================================
        # [Change 2: feature preprocessing - log1p transform]
        # Apply a log transform to skewed features such as margin and entropy
        # to stretch the resolution of the low-value range.
        # Keep the sign; apply log1p to the absolute value.
        # =====================================================================
        scalar_feats_log = torch.sign(scalar_feats) * torch.log1p(torch.abs(scalar_feats))

        # 1. Scalar features (apply the TTA LayerNorm)
        s_feat = self.scalar_ln(scalar_feats_log)

        # 2. Modality feature
        m_feat = self.modality_emb(modality_idx)  # [B, 4]

        features = [s_feat, m_feat]

        # 3. Semantic features (if present)
        if self.use_embedding:
            if qry_emb is None:
                raise ValueError(
                    "Classifier was initialized with embedding_dim > 0 "
                    "but forward() received qry_emb=None"
                )
            # Ensure the input is float32
            if qry_emb.dtype != torch.float32:
                qry_emb = qry_emb.float()
            e_feat = self.emb_proj(qry_emb)
            features.append(e_feat)

        # Concatenate all feature groups along the feature dimension
        x = torch.cat(features, dim=1)

        # MLP outputs logits
        logits = self.mlp(x)  # [B, 1]
        return logits
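

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The batch size, the random
# feature values, and the 384-dim query embedding below are assumptions for
# the demo, not values prescribed by the model; adjust them to your pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size = 8

    # Quick check of the signed log1p transform: it preserves sign and
    # compresses large magnitudes, stretching the low-value range.
    x = torch.tensor([-5.0, -0.1, 0.0, 0.1, 5.0])
    print(torch.sign(x) * torch.log1p(torch.abs(x)))
    # tensor([-1.7918, -0.0953,  0.0000,  0.0953,  1.7918])

    # Scalar-only variant (embedding_dim=0): expects [B, 27] float features
    # and a [B] long tensor of modality indices in {0, 1}.
    clf = EarlyExitClassifier(input_dim=27, hidden_dim=128, embedding_dim=0)
    clf.eval()
    scalar_feats = torch.randn(batch_size, 27)
    modality_idx = torch.randint(0, 2, (batch_size,))
    with torch.no_grad():
        logits = clf(scalar_feats, modality_idx)
    print(logits.shape)  # torch.Size([8, 1])

    # Variant with semantic query embeddings (384-dim assumed here).
    # A half-precision input is fine: forward() casts qry_emb to float32.
    clf_emb = EarlyExitClassifier(input_dim=27, hidden_dim=128, embedding_dim=384)
    clf_emb.eval()
    qry_emb = torch.randn(batch_size, 384, dtype=torch.float16)
    with torch.no_grad():
        logits = clf_emb(scalar_feats, modality_idx, qry_emb=qry_emb)
    print(logits.shape)  # torch.Size([8, 1])

    # One plausible exit rule (an assumption, defined by the training setup):
    # treat sigmoid(logit) above a threshold as "exit early".
    probs = torch.sigmoid(logits)
    print((probs > 0.5).squeeze(1))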