import torch
import torch.nn as nn
import logging

logger = logging.getLogger(__name__)

class EarlyExitClassifier(nn.Module):
    def __init__(self, input_dim=27, hidden_dim=128, embedding_dim=0):
        """
        Args:
            input_dim: dimensionality of the statistical features (27)
            hidden_dim: hidden layer size
            embedding_dim: if > 0, the classifier also receives the backbone's hidden_states as input
        """
        super().__init__()
        
        # === Improvement 1: input normalization (critical!) ===
        # This addresses the problem that score distributions differ across tasks,
        # which otherwise prevents using a single, unified threshold.
        self.scalar_bn = nn.BatchNorm1d(input_dim)
        
        # Modality embedding
        self.modality_emb = nn.Embedding(2, 4)

        # === Improvement 2: incorporate a semantic embedding ===
        self.use_embedding = embedding_dim > 0
        if self.use_embedding:
            # Project the high-dimensional embedding (e.g. 2560) down so it
            # does not dominate the rest of the network
            self.emb_proj = nn.Sequential(
                nn.Linear(embedding_dim, 64),
                nn.LayerNorm(64),
                nn.ReLU()
            )
            # Total dim = statistical features (27) + modality (4) + projected semantics (64)
            total_input_dim = input_dim + 4 + 64
        else:
            total_input_dim = input_dim + 4

        # MLP trunk
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), # BN on the hidden layer as well to mitigate vanishing gradients
            nn.ReLU(),
            nn.Dropout(0.2),            # dropout to reduce overfitting
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, scalar_feats, modality_idx, qry_emb=None):
        """
        scalar_feats: [B, 27]
        modality_idx: [B]
        qry_emb: [B, hidden_size] (New!)
        """
        # 1. Normalize the statistical features
        s_feat = self.scalar_bn(scalar_feats)
        
        # 2. Modality features
        m_feat = self.modality_emb(modality_idx)
        
        features = [s_feat, m_feat]

        # 3. Semantic features (optional)
        if self.use_embedding:
            if qry_emb is None:
                raise ValueError("Model initialized with embedding_dim > 0 but qry_emb is None")
            e_feat = self.emb_proj(qry_emb)
            features.append(e_feat)
        
        # Concatenate all features and classify
        x = torch.cat(features, dim=1)
        logits = self.mlp(x)
        return logits
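

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # the batch size, the 2560-dim backbone hidden size, and the random tensors
    # below are assumptions chosen only to demonstrate the expected shapes.
    model = EarlyExitClassifier(input_dim=27, hidden_dim=128, embedding_dim=2560)
    model.eval()  # eval mode so BatchNorm1d uses running stats regardless of batch size

    scalar_feats = torch.randn(8, 27)          # [B, 27] statistical features
    modality_idx = torch.randint(0, 2, (8,))   # [B] modality index (0 or 1)
    qry_emb = torch.randn(8, 2560)             # [B, hidden_size] backbone embedding

    with torch.no_grad():
        logits = model(scalar_feats, modality_idx, qry_emb)
    print("logits shape:", tuple(logits.shape))  # -> (8, 1)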