import logging
import random
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Sampler

logger = logging.getLogger(__name__)

class EarlyExitClassifier(nn.Module):
    def __init__(self, input_dim=5, hidden_dim=64):
        """
        input_dim=5: [Top1_Score, Margin, Entropy, Norm, Variance]
        """
        super().__init__()
        # Modality embedding: 0 = Text-Only, 1 = Multimodal
        self.modality_emb = nn.Embedding(2, 4)
        total_input_dim = input_dim + 4
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, scalar_feats, modality_idx):
        # scalar_feats: (B, input_dim) float tensor; modality_idx: (B,) long tensor
        mod_feat = self.modality_emb(modality_idx)
        x = torch.cat([scalar_feats, mod_feat], dim=1)
        return self.mlp(x)
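
# A quick shape check (illustrative sketch, not part of the original module):
# with the default input_dim=5, the classifier maps a batch of scalar
# confidence features plus a modality index to a per-sample early-exit
# probability in (0, 1). The tensor shapes below are assumptions consistent
# with forward(); note that BatchNorm1d requires batch size > 1 in training mode.
#
#   clf = EarlyExitClassifier()
#   scalar_feats = torch.randn(8, 5)          # [Top1_Score, Margin, Entropy, Norm, Variance]
#   modality_idx = torch.randint(0, 2, (8,))  # 0 = Text-Only, 1 = Multimodal
#   prob = clf(scalar_feats, modality_idx)    # shape (8, 1)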

class HomogeneousBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.groups = defaultdict(list)
        logger.info("Grouping data by dataset source for Homogeneous Sampling...")
        try:
            # Simple traversal strategy: for a ConcatDataset / MixedDataset we
            # assume each sub-dataset is internally homogeneous and that
            # dataset[i] is cheap to fetch (acceptable for lazy loading).
            # NOTE: if the dataset is very large and lazily loaded, this
            # initialization will be slow. In that case, either do the check
            # in the collator or pass a precomputed dataset_indices.json via
            # DataArguments.
            for idx in range(len(dataset)):
                item = dataset[idx]
                d_name = item.get('global_dataset_name', 'unknown')
                self.groups[d_name].append(idx)
        except Exception as e:
            logger.warning(
                f"Error grouping dataset: {e}. "
                "Falling back to simple index chunking (NOT HOMOGENEOUS)."
            )
            # Reset any partial grouping so indices are not duplicated.
            self.groups = defaultdict(list)
            self.groups['all'] = list(range(len(dataset)))
        logger.info(f"Grouped data into {len(self.groups)} dataset groups.")

    def __iter__(self):
        batch_list = []
        for d_name, indices in self.groups.items():
            random.shuffle(indices)  # shuffle within each source group
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue
                if len(batch) < 2:
                    # Drop batches too small for contrastive learning.
                    continue
                batch_list.append(batch)
        random.shuffle(batch_list)  # interleave batches from different sources
        for batch in batch_list:
            yield batch

    def __len__(self):
        # Mirrors the filtering in __iter__: with drop_last=False, a partial
        # batch is counted only if it keeps at least 2 samples.
        count = 0
        for indices in self.groups.values():
            if self.drop_last:
                count += len(indices) // self.batch_size
            else:
                full = len(indices) // self.batch_size
                remainder = len(indices) % self.batch_size
                count += full + (1 if remainder >= 2 else 0)
        return count
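

if __name__ == "__main__":
    # Minimal self-contained demo (illustrative, not part of the original
    # module): a toy dataset whose items carry a 'global_dataset_name' key,
    # showing that every yielded batch is drawn from a single source.
    from torch.utils.data import Dataset

    class _ToyDataset(Dataset):
        def __init__(self):
            self.items = (
                [{'global_dataset_name': 'caption', 'id': i} for i in range(10)]
                + [{'global_dataset_name': 'retrieval', 'id': i} for i in range(7)]
            )

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    sampler = HomogeneousBatchSampler(_ToyDataset(), batch_size=4)
    for batch in sampler:
        # Indices in each batch all point into one source group; in training,
        # pass the sampler via DataLoader(..., batch_sampler=sampler).
        print(batch)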