import logging
import random
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Sampler

logger = logging.getLogger(__name__)


class EarlyExitClassifier(nn.Module):
    def __init__(self, input_dim=5, hidden_dim=64):
        """
        input_dim=5: [Top1_Score, Margin, Entropy, Norm, Variance]
        """
        super().__init__()
        # Modality embedding: 0 = text-only, 1 = multimodal
        self.modality_emb = nn.Embedding(2, 4)
        total_input_dim = input_dim + 4
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, scalar_feats, modality_idx):
        # Look up the modality embedding and concatenate it with the scalar features.
        mod_feat = self.modality_emb(modality_idx)
        x = torch.cat([scalar_feats, mod_feat], dim=1)
        return self.mlp(x)


class HomogeneousBatchSampler(Sampler):
    """Batch sampler that keeps every batch within a single dataset source."""

    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.groups = defaultdict(list)

        logger.info("Grouping data by dataset source for Homogeneous Sampling...")
        try:
            # If the dataset is a ConcatDataset / MixedDataset (it exposes a
            # `datasets` attribute), each sub-dataset is assumed to be internally
            # homogeneous; ideally the grouping would come from the sub-dataset
            # boundaries rather than per-item lookups. If the data volume is very
            # large, do this check in the collator instead, or precompute an
            # index file.
            #
            # Simple traversal strategy (fine when the dataset is in memory).
            # NOTE: if the dataset is huge and lazily loaded, this initialization
            # will be slow. Suggestion: pass a precomputed dataset_indices.json
            # via DataArguments.
            for idx in range(len(dataset)):
                item = dataset[idx]
                d_name = item.get('global_dataset_name', 'unknown')
                self.groups[d_name].append(idx)
        except Exception as e:
            logger.warning(
                f"Error grouping dataset: {e}. "
                "Falling back to simple index chunking (NOT HOMOGENEOUS)."
            )
            # Discard any partial grouping before falling back.
            self.groups = defaultdict(list)
            self.groups['all'] = list(range(len(dataset)))

        logger.info(f"Grouped data into {len(self.groups)} datasets.")

    def __iter__(self):
        batch_list = []
        for d_name, indices in self.groups.items():
            random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue
                if len(batch) < 2:
                    # Drop batches too small for contrastive learning.
                    continue
                batch_list.append(batch)
        # Shuffle the order of batches across dataset sources.
        random.shuffle(batch_list)
        for batch in batch_list:
            yield batch

    def __len__(self):
        count = 0
        for indices in self.groups.values():
            if self.drop_last:
                count += len(indices) // self.batch_size
            else:
                full = len(indices) // self.batch_size
                remainder = len(indices) % self.batch_size
                # A trailing partial batch only survives __iter__ if it has at least 2 items.
                count += full + (1 if remainder >= 2 else 0)
        return count
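

# --- Usage sketch (illustrative only, not part of the training pipeline) ---
# A minimal example of wiring HomogeneousBatchSampler into a DataLoader and
# running the early-exit head on dummy features. The toy dataset, feature
# values, and batch size below are hypothetical assumptions for demonstration.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class _ToyDataset(Dataset):
        """Hypothetical dataset: each item carries a 'global_dataset_name' key."""

        def __init__(self):
            self.items = (
                [{"global_dataset_name": "vqa", "idx": i} for i in range(10)]
                + [{"global_dataset_name": "caption", "idx": i} for i in range(7)]
            )

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    ds = _ToyDataset()
    sampler = HomogeneousBatchSampler(ds, batch_size=4, drop_last=False)
    loader = DataLoader(ds, batch_sampler=sampler, collate_fn=lambda b: b)
    for batch in loader:
        # Every batch should contain items from a single dataset source.
        assert len({item["global_dataset_name"] for item in batch}) == 1

    # Early-exit head on a dummy feature batch:
    # scalar features are [top1_score, margin, entropy, norm, variance].
    clf = EarlyExitClassifier(input_dim=5, hidden_dim=64).eval()
    feats = torch.rand(4, 5)
    modality = torch.tensor([0, 1, 1, 0])  # 0 = text-only, 1 = multimodal
    with torch.no_grad():
        exit_prob = clf(feats, modality)  # shape (4, 1), values in (0, 1)
    print(exit_prob.squeeze(-1))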