import torch
import torch.nn as nn
from torch.utils.data import Sampler
from collections import defaultdict
import random
import logging
logger = logging.getLogger(__name__)


class EarlyExitClassifier(nn.Module):
    def __init__(self, input_dim=27, hidden_dim=64):
        """
        input_dim: matches the number of scalar features in the offline
        feat_cols_single list (27). The actual features are:
        [s1_mid, s2_mid, margin_mid, z_margin_mid, mad_margin_mid,
         p1_mid, H_mid, gini_mid,
         topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
         s1_over_mean, s1_over_med,
         p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk,
         tail_mean, head5_mean,
         mask_ratio, mask_len, mask_runs]
        A 4-dim modality embedding is concatenated on top, so
        total_input_dim = input_dim + 4.
        """
        super().__init__()
        # 0 = text-only, 1 = multimodal (image + text)
        self.modality_emb = nn.Embedding(2, 4)
        total_input_dim = input_dim + 4
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # output logits
        )

    def forward(self, scalar_feats, modality_idx):
        """
        scalar_feats: [B, input_dim]
        modality_idx: [B], values in {0, 1}
        """
        mod_feat = self.modality_emb(modality_idx)      # [B, 4]
        x = torch.cat([scalar_feats, mod_feat], dim=1)  # [B, input_dim + 4]
        logits = self.mlp(x)                            # [B, 1]
        return logits
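

# A minimal usage sketch, not part of the original module: the batch size (8)
# and the random feature tensor are placeholders standing in for the 27
# offline scalar features. It only demonstrates the expected tensor shapes.
def _demo_early_exit_classifier():
    model = EarlyExitClassifier(input_dim=27, hidden_dim=64)
    model.eval()  # eval mode: BatchNorm1d uses running stats, any batch size works
    scalar_feats = torch.randn(8, 27)               # [B, input_dim]
    modality_idx = torch.randint(0, 2, (8,))        # [B], values in {0, 1}
    with torch.no_grad():
        logits = model(scalar_feats, modality_idx)  # [B, 1]
    return torch.sigmoid(logits)                    # per-sample exit probability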


class HomogeneousBatchSampler(Sampler):
    """
    Groups indices by global_dataset_name so that each batch is drawn from a
    single sub-dataset wherever possible, which makes in-batch contrastive
    learning well defined.
    """

    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.groups = defaultdict(list)
        logger.info("Grouping data by dataset source for homogeneous sampling...")
        try:
            for idx in range(len(dataset)):
                item = dataset[idx]
                d_name = item.get("global_dataset_name", "unknown")
                self.groups[d_name].append(idx)
        except Exception as e:
            logger.warning(
                f"Error grouping dataset: {e}. "
                "Falling back to simple index chunking (NOT HOMOGENEOUS)."
            )
            # Discard any partial grouping so indices are not duplicated.
            self.groups = defaultdict(list)
            self.groups["all"] = list(range(len(dataset)))
        logger.info(f"Grouped data into {len(self.groups)} datasets.")

    def __iter__(self):
        batch_list = []
        for indices in self.groups.values():
            random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue
                if len(batch) < 2:
                    # Singleton batches break BatchNorm1d and in-batch contrast.
                    continue
                batch_list.append(batch)
        # Shuffle across groups so consecutive batches vary in source dataset.
        random.shuffle(batch_list)
        yield from batch_list

    def __len__(self):
        # Mirrors the skip rules in __iter__ (assuming batch_size >= 2, so
        # full batches never hit the len(batch) < 2 check).
        count = 0
        for indices in self.groups.values():
            if self.drop_last:
                count += len(indices) // self.batch_size
            else:
                full = len(indices) // self.batch_size
                remainder = len(indices) % self.batch_size
                # A trailing partial batch survives only with >= 2 samples.
                count += full + (1 if remainder >= 2 else 0)
        return count
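

# A minimal usage sketch, not part of the original module. _ToyDataset is a
# hypothetical stand-in for a real map-style dataset whose items carry
# "global_dataset_name". Note that passing `batch_sampler` replaces the
# batch_size, shuffle, and drop_last arguments on the DataLoader itself.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class _ToyDataset(Dataset):
        """20 items: indices 0-9 belong to "ds_a", 10-19 to "ds_b"."""

        def __len__(self):
            return 20

        def __getitem__(self, idx):
            name = "ds_a" if idx < 10 else "ds_b"
            return {"global_dataset_name": name, "x": torch.tensor(float(idx))}

    ds = _ToyDataset()
    sampler = HomogeneousBatchSampler(ds, batch_size=4)
    loader = DataLoader(ds, batch_sampler=sampler)
    for batch in loader:
        # Each batch's names are identical, since batches never cross groups.
        print(batch["global_dataset_name"], batch["x"].shape)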