File size: 3,458 Bytes
0a937d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
from torch.utils.data import Sampler
from collections import defaultdict
import random
import logging

logger = logging.getLogger(__name__)


class EarlyExitClassifier(nn.Module):
    """Single-logit classifier over per-sample scalar features plus modality.

    Each sample's scalar feature vector is concatenated with a learned
    modality embedding and fed through a small MLP
    (Linear -> BatchNorm1d -> ReLU -> Linear) that emits one logit per sample
    (presumably the exit / no-exit decision; confirm against the training loss).
    """

    def __init__(self, input_dim=27, hidden_dim=64, *, num_modalities=2,
                 modality_emb_dim=4):
        """
        Args:
            input_dim: Number of scalar features per sample. The default (27)
                is meant to match the scalar features of the offline
                ``feat_cols_single``, namely:
                  [s1_mid, s2_mid, margin_mid, z_margin_mid, mad_margin_mid,
                   p1_mid, H_mid, gini_mid,
                   topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
                   s1_over_mean, s1_over_med,
                   p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk,
                   tail_mean, head5_mean,
                   mask_ratio, mask_len, mask_runs]
            hidden_dim: Width of the hidden layer.
            num_modalities: Number of distinct modality ids. The default (2)
                covers 0 = text-only and 1 = multimodal (image + text).
            modality_emb_dim: Size of the learned modality embedding that is
                concatenated to the scalar features, so the MLP input width
                is ``input_dim + modality_emb_dim``.
        """
        super().__init__()
        # Created before the MLP (as before) so parameter initialisation order
        # and state_dict keys stay compatible with existing checkpoints.
        self.modality_emb = nn.Embedding(num_modalities, modality_emb_dim)

        total_input_dim = input_dim + modality_emb_dim

        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # raw logit, no activation
        )

    def forward(self, scalar_feats, modality_idx):
        """Compute one logit per sample.

        Args:
            scalar_feats: Tensor of shape [B, input_dim].
            modality_idx: Integer tensor of shape [B] with values in
                ``[0, num_modalities)``.

        Returns:
            Tensor of shape [B, 1] holding unnormalised logits.

        Note:
            In training mode BatchNorm1d rejects batches with a single sample,
            so callers must supply B > 1 while training.
        """
        mod_feat = self.modality_emb(modality_idx)  # [B, modality_emb_dim]
        x = torch.cat([scalar_feats, mod_feat], dim=1)
        logits = self.mlp(x)  # [B, 1]
        return logits


class HomogeneousBatchSampler(Sampler):
    """Batch sampler whose batches are each drawn from a single sub-dataset.

    Indices are grouped by the ``global_dataset_name`` field of each item, so
    every batch comes from one sub-dataset, as in-batch contrastive learning
    requires. If grouping fails, all indices fall into a single ``"all"``
    group and batches are no longer homogeneous.

    Each yielded batch is a list of dataset indices (presumably used as
    ``DataLoader(batch_sampler=...)``). Batches with fewer than
    ``_MIN_BATCH`` indices are never emitted, and ``__len__`` counts exactly
    the batches that ``__iter__`` yields.
    """

    # Smallest batch ever yielded. Presumably because a lone sample has no
    # in-batch partner to contrast with; BatchNorm1d in training mode also
    # rejects single-sample batches.
    _MIN_BATCH = 2

    def __init__(self, dataset, batch_size, drop_last=False):
        """
        Args:
            dataset: Sized, indexable dataset whose items support
                ``item.get("global_dataset_name", ...)`` (e.g. dicts). Items
                without that key are grouped under ``"unknown"``.
            batch_size: Maximum number of indices per batch; must be >= 1.
            drop_last: If True, drop each group's trailing partial batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        logger.info("Grouping data by dataset source for Homogeneous Sampling...")
        try:
            self.groups = self._group_indices(dataset)
        except Exception as e:
            # Deliberate best-effort fallback. Start from a fresh mapping so
            # indices grouped before the failure are not sampled twice.
            logger.warning(
                "Error grouping dataset: %s. "
                "Falling back to simple index chunking (NOT HOMOGENEOUS).",
                e,
            )
            self.groups = defaultdict(list, {"all": list(range(len(dataset)))})

        logger.info("Grouped data into %d datasets.", len(self.groups))

    @staticmethod
    def _group_indices(dataset):
        """Return a ``defaultdict(list)`` mapping dataset name -> indices.

        NOTE(review): ``dataset[idx]`` loads every full item just to read one
        key; if items are expensive to build, a cheaper name lookup on the
        dataset would speed this up considerably.
        """
        groups = defaultdict(list)
        for idx in range(len(dataset)):
            name = dataset[idx].get("global_dataset_name", "unknown")
            groups[name].append(idx)
        return groups

    def _keep(self, size):
        """Return True if a batch of ``size`` indices should be emitted."""
        if size < self._MIN_BATCH:
            return False
        if self.drop_last and size < self.batch_size:
            return False
        return True

    def __iter__(self):
        batches = []
        for indices in self.groups.values():
            # In-place reshuffle gives each epoch a fresh order within the group.
            random.shuffle(indices)
            for start in range(0, len(indices), self.batch_size):
                batch = indices[start : start + self.batch_size]
                if self._keep(len(batch)):
                    batches.append(batch)

        # Mix groups at batch level; each batch itself stays homogeneous.
        random.shuffle(batches)
        yield from batches

    def __len__(self):
        total = 0
        for indices in self.groups.values():
            full, remainder = divmod(len(indices), self.batch_size)
            if self._keep(self.batch_size):
                total += full
            if self._keep(remainder):
                total += 1
        return total