import torch
import torch.nn as nn
import logging

logger = logging.getLogger(__name__)


class EarlyExitClassifier(nn.Module):
    def __init__(self, input_dim=27, hidden_dim=128, embedding_dim=0):
        """
        Args:
            input_dim: dimension of the scalar statistical features (27)
            hidden_dim: hidden layer size
            embedding_dim: if > 0, the backbone's hidden_states are taken as an
                additional input
        """
        super().__init__()

        # === Improvement 1: input normalization (critical!) ===
        # Fixes the problem that score distributions differ across tasks,
        # which makes a single unified threshold impossible.
        self.scalar_bn = nn.BatchNorm1d(input_dim)

        # Modality embedding
        self.modality_emb = nn.Embedding(2, 4)

        # === Improvement 2: semantic embedding ===
        self.use_embedding = embedding_dim > 0
        if self.use_embedding:
            # Compress the high-dimensional embedding (e.g. 2560) so that it
            # does not dominate the rest of the network.
            self.emb_proj = nn.Sequential(
                nn.Linear(embedding_dim, 64),
                nn.LayerNorm(64),
                nn.ReLU(),
            )
            # Total dim = scalar features (27) + modality (4) + compressed semantics (64)
            total_input_dim = input_dim + 4 + 64
        else:
            total_input_dim = input_dim + 4

        # MLP trunk
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # BN in the hidden layer too, to prevent vanishing gradients
            nn.ReLU(),
            nn.Dropout(0.2),  # dropout to prevent overfitting
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, scalar_feats, modality_idx, qry_emb=None):
        """
        Args:
            scalar_feats: [B, 27]
            modality_idx: [B]
            qry_emb: [B, hidden_size] (new!)
        """
        # 1. Normalize the scalar statistics
        s_feat = self.scalar_bn(scalar_feats)

        # 2. Modality features
        m_feat = self.modality_emb(modality_idx)
        features = [s_feat, m_feat]

        # 3. Semantic features
        if self.use_embedding:
            if qry_emb is None:
                raise ValueError("Model initialized with embedding_dim > 0 but qry_emb is None")
            e_feat = self.emb_proj(qry_emb)
            features.append(e_feat)

        # Concatenate and classify
        x = torch.cat(features, dim=1)
        logits = self.mlp(x)
        return logits
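
# --- Minimal smoke test (illustrative sketch, not part of the original module) ---
# The batch size of 8 and the 2560-dim query embedding below are assumed values,
# chosen only to exercise both input paths of the classifier.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = EarlyExitClassifier(input_dim=27, hidden_dim=128, embedding_dim=2560)
    model.eval()  # eval mode: BatchNorm1d uses its (initial) running statistics

    scalar_feats = torch.randn(8, 27)         # [B, 27] statistical features
    modality_idx = torch.randint(0, 2, (8,))  # [B] modality ids in {0, 1}
    qry_emb = torch.randn(8, 2560)            # [B, embedding_dim] backbone hidden states

    logits = model(scalar_feats, modality_idx, qry_emb)
    exit_prob = torch.sigmoid(logits)         # [B, 1] early-exit probability
    print(exit_prob.shape)                    # torch.Size([8, 1])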