import torch
import torch.nn as nn
from torch.utils.data import Sampler
from collections import defaultdict
import random
import logging

logger = logging.getLogger(__name__)

class EarlyExitClassifier(nn.Module):
    def __init__(self, input_dim=5, hidden_dim=64):
        """
        input_dim=5 scalar confidence features per sample:
        [Top1_Score, Margin, Entropy, Norm, Variance]
        """
        super().__init__()
        # Modality embedding: 0 = Text-Only, 1 = Multimodal
        self.modality_emb = nn.Embedding(2, 4)
        
        total_input_dim = input_dim + 4
        
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, scalar_feats, modality_idx):
        mod_feat = self.modality_emb(modality_idx)
        x = torch.cat([scalar_feats, mod_feat], dim=1)
        return self.mlp(x)

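# Example usage of EarlyExitClassifier (a minimal sketch; shapes and values
# below are illustrative only, not taken from the training code):
#
#   clf = EarlyExitClassifier(input_dim=5, hidden_dim=64)
#   feats = torch.randn(8, 5)             # [Top1_Score, Margin, Entropy, Norm, Variance]
#   modality = torch.randint(0, 2, (8,))  # 0 = text-only, 1 = multimodal
#   exit_prob = clf(feats, modality)      # shape (8, 1), probabilities in (0, 1)
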
class HomogeneousBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.groups = defaultdict(list)

        logger.info("Grouping data by dataset source for Homogeneous Sampling...")
        try:
            # Try to access the MixedDataset's data attribute first.
            if hasattr(dataset, 'datasets'):
                # If this is a ConcatDataset or MixedDataset, assume each
                # sub-dataset is internally homogeneous (a single source).
                # We would still need a way to obtain global_dataset_name per
                # sub-dataset; if the data volume is too large, prefer doing
                # the check in the collator or precomputing an index file.
                # Placeholder only: for demonstration we assume dataset[i] is
                # fast enough (acceptable even with lazy loading) and fall
                # through to the per-item traversal below.
                current_idx = 0
                for sub_ds in dataset.datasets:
                    pass

            # Simple traversal strategy (works when the dataset is in memory).
            # Note: if the dataset is very large and lazily loaded, this
            # initialization step will be slow.
            # Recommendation: pass a precomputed dataset_indices.json via
            # DataArguments instead.
            for idx in range(len(dataset)):
                item = dataset[idx]
                d_name = item.get('global_dataset_name', 'unknown')
                self.groups[d_name].append(idx)
                
        except Exception as e:
            logger.warning(f"Error grouping dataset: {e}. Falling back to simple index chunking (NOT HOMOGENEOUS).")
            self.groups['all'] = list(range(len(dataset)))

        logger.info(f"Grouped data into {len(self.groups)} datasets.")

    def __iter__(self):
        batch_list = []
        for d_name, indices in self.groups.items():
            random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue
                if len(batch) < 2:  # drop batches too small for contrastive learning
                    continue
                batch_list.append(batch)
        
        random.shuffle(batch_list)
        for batch in batch_list:
            yield batch

    def __len__(self):
        count = 0
        for indices in self.groups.values():
            if self.drop_last:
                count += len(indices) // self.batch_size
            else:
                # Count the trailing partial batch only if it has at least 2
                # samples, mirroring the size-2 filter in __iter__.
                remainder = len(indices) % self.batch_size
                full = len(indices) // self.batch_size
                count += full + (1 if remainder >= 2 else 0)
        return count
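
if __name__ == "__main__":
    # Minimal smoke test (a sketch with synthetic toy data; the _ToySet class
    # and its field values are made up for illustration and are not part of
    # the original training pipeline).
    from torch.utils.data import DataLoader, Dataset

    class _ToySet(Dataset):
        """Tiny in-memory dataset with two sources, 'A' and 'B'."""
        def __init__(self):
            self.items = (
                [{'global_dataset_name': 'A', 'x': i} for i in range(10)]
                + [{'global_dataset_name': 'B', 'x': i} for i in range(7)]
            )

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    ds = _ToySet()
    sampler = HomogeneousBatchSampler(ds, batch_size=4, drop_last=False)
    loader = DataLoader(ds, batch_sampler=sampler, collate_fn=lambda b: b)
    for batch in loader:
        # Every batch should come from a single dataset source.
        assert len({item['global_dataset_name'] for item in batch}) == 1
    print(f"{len(sampler)} homogeneous batches OK")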