import torch
import torch.nn as nn
import torch.nn.functional as F

class EarlyExitClassifier(nn.Module):
    """
    V5 版本分类器:集成轻量级 TTA (LayerNorm) 和 Log1p 特征变换
    """
    def __init__(self, input_dim=27, hidden_dim=128, embedding_dim=0, dropout_prob=0.2):
        super().__init__()
        
        # =====================================================================
        # [Change 1: core TTA mechanism - LayerNorm]
        # Use LayerNorm instead of BatchNorm.
        # At inference time, LayerNorm normalizes each sample with statistics
        # computed from that sample alone, so it adapts to feature-distribution
        # shifts across datasets (i.e., Test-Time Adaptation).
        # =====================================================================
        self.scalar_ln = nn.LayerNorm(input_dim)
        
        self.modality_emb = nn.Embedding(2, 4)  # 2 modality ids -> 4-dim embedding

        self.use_embedding = embedding_dim > 0
        if self.use_embedding:
            # Projection for semantic (query embedding) features
            self.emb_proj = nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),  # LayerNorm here as well, for consistency
                nn.ReLU()
            )
            # scalar (after LN) + modality (4) + embedding (hidden_dim // 2)
            total_input_dim = input_dim + 4 + (hidden_dim // 2)
        else:
            total_input_dim = input_dim + 4

        # Main MLP
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # keep using LayerNorm
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim // 2, 1),
        )
        
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=0.02)

    def forward(self, scalar_feats, modality_idx, qry_emb=None):
        # scalar_feats: [B, input_dim] float32 scalar features
        # modality_idx: [B] long tensor of modality ids (0 or 1)
        # qry_emb:      [B, embedding_dim] optional semantic query embedding
        
        # =====================================================================
        # [Change 2: feature preprocessing - log1p transform]
        # Apply a log transform to skewed features such as margin and entropy
        # to stretch the resolution of the low-value range.
        # Keep the sign and apply log1p to the absolute value.
        scalar_feats_log = torch.sign(scalar_feats) * torch.log1p(torch.abs(scalar_feats))
        # =====================================================================

        # 1. Scalar features (apply the TTA LayerNorm)
        s_feat = self.scalar_ln(scalar_feats_log)
        
        # 2. Modality feature
        m_feat = self.modality_emb(modality_idx) # [B, 4]
        
        features = [s_feat, m_feat]

        # 3. Semantic features (if enabled)
        if self.use_embedding:
            if qry_emb is None:
                raise ValueError("Classifier was initialized with embedding_dim > 0, but forward() received qry_emb=None")
            # Ensure the embedding is float32
            if qry_emb.dtype != torch.float32:
                qry_emb = qry_emb.float()
            e_feat = self.emb_proj(qry_emb)
            features.append(e_feat)
        
        # Concatenate all feature groups
        x = torch.cat(features, dim=1)
        
        # MLP produces logits
        logits = self.mlp(x) # [B, 1]
        return logits
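
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the batch size,
    # random inputs, and the 768-dim query embedding below are illustrative
    # assumptions used only to show the expected tensor shapes.
    torch.manual_seed(0)
    batch_size = 8

    # Variant without semantic embeddings: 27 scalar features per sample.
    clf = EarlyExitClassifier(input_dim=27, hidden_dim=128, embedding_dim=0)
    clf.eval()
    scalar_feats = torch.randn(batch_size, 27)
    modality_idx = torch.randint(0, 2, (batch_size,))  # modality id per sample
    with torch.no_grad():
        logits = clf(scalar_feats, modality_idx)       # [8, 1]
        probs = torch.sigmoid(logits)                  # early-exit probabilities
    print(logits.shape, probs.squeeze(-1))

    # Variant with a query embedding (dimension 768 is an assumption).
    clf_emb = EarlyExitClassifier(input_dim=27, hidden_dim=128, embedding_dim=768)
    clf_emb.eval()
    qry_emb = torch.randn(batch_size, 768).half()      # cast to float32 inside forward()
    with torch.no_grad():
        logits_emb = clf_emb(scalar_feats, modality_idx, qry_emb)  # [8, 1]
    print(logits_emb.shape)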