| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from scipy.stats import entropy |
| |
|
| |
|
class AdaptiveAugmentation:
    """
    Implements adaptive data-driven augmentation for HARCNet.
    Dynamically adjusts geometric and MixUp augmentations based on data distribution.
    """

    def __init__(self, alpha=0.5, beta=0.5, gamma=2.0):
        """
        Args:
            alpha: Weight for the (batch-normalized) variance term of the
                geometric augmentation strength.
            beta: Weight for the (batch-normalized) entropy term of the
                geometric augmentation strength.
            gamma: Scaling factor applied to the batch label entropy to obtain
                the MixUp Beta-distribution concentration.
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # Kept for callers that read it. Tensors created inside this class are
        # placed on the device of the inputs instead (see get_mixup_params).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def compute_variance(self, x):
        """
        Per-sample variance across dimension 1, averaged over all remaining dims.

        For an image batch (presumably laid out as (B, C, H, W)) this is the
        variance across dim 1 at each spatial position, averaged over positions.
        Any input of rank >= 2 is accepted.

        Returns:
            Tensor of shape (B,).
        """
        var = torch.var(x, dim=1, keepdim=True)
        # Reduce every dim except the batch dim. This was previously hard-coded to
        # [1, 2, 3], which only worked for 4-D input. Results are identical for 4-D.
        return var.mean(dim=tuple(range(1, var.dim())))

    def compute_entropy(self, probs):
        """
        Shannon entropy (natural log) of each distribution along dim 1.

        Probabilities are clamped to [1e-8, 1] so log(0) never occurs.
        """
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        return -torch.sum(probs * torch.log(probs), dim=1)

    def get_geometric_strength(self, x, model=None, probs=None):
        """
        Compute per-sample geometric augmentation strength.

            S_g(x_i) = α·Var(x_i) + β·Entropy(x_i)

        Both terms are min-max normalized across the batch before weighting.

        Args:
            x: Input batch, passed to compute_variance and (if used) to `model`.
            model: Optional network used to obtain predictions when `probs` is None.
            probs: Optional class probabilities with classes along dim 1.

        Returns:
            Tensor of shape (B,).
        """
        var = self.compute_variance(x)

        if probs is None and model is not None:
            # NOTE(review): the model runs in whatever train/eval mode it is
            # currently in. In train mode, layers such as BatchNorm will update
            # their running statistics during this extra forward pass.
            with torch.no_grad():
                probs = F.softmax(model(x), dim=1)

        if probs is not None:
            ent = self.compute_entropy(probs)
        else:
            # No prediction signal: use a constant placeholder. The min-max
            # normalization below maps it to zero, so only the variance term remains.
            ent = torch.ones_like(var)

        var = (var - var.min()) / (var.max() - var.min() + 1e-8)
        ent = (ent - ent.min()) / (ent.max() - ent.min() + 1e-8)

        return self.alpha * var + self.beta * ent

    def get_mixup_params(self, y, num_classes=100):
        """
        Sample a MixUp coefficient and a pairing permutation for the batch.

            λ ~ Beta(a, a),  a = clip(γ·Entropy(y), 0.1, 2.0)

        Entropy(y) is the entropy of the batch's empirical label distribution.

        Args:
            y: Integer class labels (expected shape (B,)).
            num_classes: Number of classes for the one-hot encoding.

        Returns:
            (lam, index):
                lam: a Python float drawn from the Beta distribution.
                index: a random permutation of range(B), on the same device as `y`.
        """
        y_onehot = F.one_hot(y, num_classes=num_classes).float()
        label_dist = y_onehot.mean(dim=0, keepdim=True)
        batch_entropy = self.compute_entropy(label_dist).item()

        concentration = max(0.1, min(self.gamma * batch_entropy, 2.0))
        lam = float(np.random.beta(concentration, concentration))

        # Build the permutation on the labels' device. The fixed self.device broke
        # CPU batches on CUDA machines, because indexing a CPU tensor with a CUDA
        # index tensor raises.
        index = torch.randperm(y.size(0), device=y.device)
        return lam, index

    def apply_mixup(self, x, y, num_classes=100):
        """
        Mix every sample with a randomly chosen partner from the same batch.

        Returns:
            (mixed_x, y_a, y_b, lam), where:
                mixed_x = lam * x + (1 - lam) * x[index]
                y_a = y
                y_b = y[index]
            Presumably the caller combines losses as
            lam * loss(pred, y_a) + (1 - lam) * loss(pred, y_b); confirm with caller.
        """
        lam, index = self.get_mixup_params(y, num_classes)
        # `index` lives on y.device, so move it to x's device before indexing x.
        mixed_x = lam * x + (1 - lam) * x[index.to(x.device)]
        return mixed_x, y, y[index], lam
| |
|
| |
|
class TemporalConsistencyRegularization:
    """
    Implements decayed temporal consistency regularization for HARCNet.
    Reduces noise in pseudo-labels by incorporating past predictions.

    For each sample index, the last `memory_size` predictions are stored oldest
    first. Their exponentially decayed average, with the most recent prediction
    weighted highest, is the target for the current prediction.
    """

    def __init__(self, memory_size=5, decay_rate=2.0, consistency_weight=0.1):
        """
        Args:
            memory_size: Number of past predictions to store per sample (K).
            decay_rate: Decay temperature τ. Larger values spread weight more
                evenly over older predictions.
            consistency_weight: Weight for the consistency loss (λ_consistency).
                It is stored only; no method of this class applies it.
        """
        self.memory_size = memory_size
        self.decay_rate = decay_rate
        self.consistency_weight = consistency_weight
        # sample index (int) -> list of detached predictions, oldest first.
        self.prediction_history = {}

    def compute_decay_weights(self):
        """
        Normalized exponential decay weights for lags k = 1..K:

            ω_k = e^(-k/τ) / Σ_j e^(-j/τ)

        Entry 0 (k = 1, the most recent past prediction) is the largest.
        """
        weights = torch.exp(-torch.arange(1, self.memory_size + 1) / self.decay_rate)
        return weights / weights.sum()

    def update_history(self, indices, predictions):
        """
        Append each sample's current prediction to its history.

        Args:
            indices: Sample indices; predictions[i] belongs to indices[i].
            predictions: Per-sample predictions. They are detached before storage.
        """
        for i, idx in enumerate(indices):
            history = self.prediction_history.setdefault(idx.item(), [])
            history.append(predictions[i].detach())
            # Keep only the newest `memory_size` entries.
            if len(history) > self.memory_size:
                history.pop(0)

    def get_aggregated_predictions(self, indices):
        """
        Decay-weighted average of each sample's stored predictions:

            ỹ_i = Σ_k ω_k · ŷ_i^(t-k)

        When fewer than K predictions are stored, the first len(history)
        weights are used and renormalized.

        Returns:
            List aligned with `indices`. Each entry is the aggregated prediction,
            or None for samples with no stored history.
        """
        weights = self.compute_decay_weights()
        aggregated_preds = []

        for idx in indices:
            history = self.prediction_history.get(idx.item())
            if not history:
                aggregated_preds.append(None)
                continue

            # History is stored oldest -> newest, but ω_1 (the largest weight)
            # belongs to the newest entry. Iterate newest-first so the weights
            # line up. The previous code paired the largest weight with the
            # oldest prediction.
            newest_first = history[::-1]
            w = weights[: len(newest_first)].to(newest_first[0].device)
            w = w / w.sum()

            weighted_sum = torch.zeros_like(newest_first[0])
            for w_k, pred in zip(w, newest_first):
                weighted_sum += w_k * pred
            aggregated_preds.append(weighted_sum)

        return aggregated_preds

    def compute_consistency_loss(self, current_preds, indices):
        """
        Mean over samples of the MSE between current and aggregated past predictions:

            L_consistency(x_i) = mean((ŷ_i^(t) - Σ_k ω_k · ŷ_i^(t-k))^2)

        F.mse_loss averages over elements, so this is a mean squared error
        rather than the squared L2 norm. Samples without history are skipped.

        Returns:
            0-d tensor; 0.0 on current_preds.device if no sample has history.
        """
        aggregated = self.get_aggregated_predictions(indices)
        losses = [
            F.mse_loss(current_preds[i], target)
            for i, target in enumerate(aggregated)
            if target is not None
        ]
        if not losses:
            return torch.tensor(0.0, device=current_preds.device)
        return torch.stack(losses).mean()
| |
|