import torch
import torch.nn as nn
from torch.utils.data import Sampler
from collections import defaultdict
import random
import logging

logger = logging.getLogger(__name__)


class EarlyExitClassifier(nn.Module):
    def __init__(self, input_dim=27, hidden_dim=64):
        """
        input_dim: matches the number of scalar features in the offline
            feat_cols_single (27). The actual features are:
            [s1_mid, s2_mid, margin_mid, z_margin_mid, mad_margin_mid, p1_mid,
             H_mid, gini_mid, topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
             s1_over_mean, s1_over_med, p1, p2, shape_H, shape_gini, slope, z1,
             s1_over_sk, tail_mean, head5_mean, mask_ratio, mask_len, mask_runs]
        A 4-dim modality embedding is concatenated on top,
        so total_input_dim = input_dim + 4.
        """
        super().__init__()
        # 0 = Text-only, 1 = Multimodal (Image + Text)
        self.modality_emb = nn.Embedding(2, 4)
        total_input_dim = input_dim + 4
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # outputs logits
        )

    def forward(self, scalar_feats, modality_idx):
        """
        scalar_feats: [B, input_dim]
        modality_idx: [B] in {0, 1}
        """
        mod_feat = self.modality_emb(modality_idx)
        x = torch.cat([scalar_feats, mod_feat], dim=1)
        logits = self.mlp(x)  # [B, 1]
        return logits


class HomogeneousBatchSampler(Sampler):
    """
    Groups indices by global_dataset_name so that each batch is drawn from a
    single sub-dataset whenever possible, which makes in-batch contrastive
    learning easier.
    """

    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.groups = defaultdict(list)

        logger.info("Grouping data by dataset source for Homogeneous Sampling...")
        try:
            for idx in range(len(dataset)):
                item = dataset[idx]
                d_name = item.get("global_dataset_name", "unknown")
                self.groups[d_name].append(idx)
        except Exception as e:
            logger.warning(
                f"Error grouping dataset: {e}. "
                "Falling back to simple index chunking (NOT HOMOGENEOUS)."
            )
            self.groups["all"] = list(range(len(dataset)))

        logger.info(f"Grouped data into {len(self.groups)} datasets.")

    def __iter__(self):
        batch_list = []
        for _, indices in self.groups.items():
            random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue
                if len(batch) < 2:
                    continue
                batch_list.append(batch)
        random.shuffle(batch_list)
        for batch in batch_list:
            yield batch

    def __len__(self):
        count = 0
        for indices in self.groups.values():
            if self.drop_last:
                count += len(indices) // self.batch_size
            else:
                remainder = len(indices) % self.batch_size
                full = len(indices) // self.batch_size
                count += full + (1 if remainder >= 2 else 0)
        return count
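

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the source pipeline): shows how
# HomogeneousBatchSampler plugs into a DataLoader via `batch_sampler=` and how
# EarlyExitClassifier could be trained with BCEWithLogitsLoss. The toy dataset,
# its field names other than "global_dataset_name", and the binary label are
# hypothetical assumptions made for this sketch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import Dataset, DataLoader

    class ToyDataset(Dataset):
        """Synthetic stand-in exposing the fields the sampler/model expect."""

        def __init__(self, n=256):
            self.items = []
            for _ in range(n):
                self.items.append({
                    # Assumed sub-dataset names, only used for grouping.
                    "global_dataset_name": random.choice(["vqa", "caption", "text_qa"]),
                    "scalar_feats": torch.randn(27),
                    "modality_idx": random.randint(0, 1),
                    "label": float(random.randint(0, 1)),  # 1 = safe to exit early (assumed)
                })

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    def collate(batch):
        feats = torch.stack([b["scalar_feats"] for b in batch])
        mods = torch.tensor([b["modality_idx"] for b in batch], dtype=torch.long)
        labels = torch.tensor([b["label"] for b in batch]).unsqueeze(1)
        return feats, mods, labels

    dataset = ToyDataset()
    # batch_sampler yields whole index lists, so batch_size/shuffle stay unset.
    sampler = HomogeneousBatchSampler(dataset, batch_size=32, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate)

    model = EarlyExitClassifier(input_dim=27, hidden_dim=64)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()

    model.train()
    for feats, mods, labels in loader:
        optimizer.zero_grad()
        logits = model(feats, mods)   # [B, 1]
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()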