import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import timm  # noqa: F401
except ImportError:
    # timm is not referenced by BRAINHybridLoss; keep the import for any code
    # that may rely on it, but don't make it a hard requirement of this module.
    # NOTE(review): drop entirely if nothing else uses it.
    timm = None


class BRAINHybridLoss(nn.Module):
    """Cross-entropy plus a weighted batch-hard triplet term.

    total = CE(logits, labels)
            + triplet_weight * mean_i relu(d(i, p_i) - d(i, n_i) + margin)

    For each anchor ``i`` in the batch, ``p_i`` is the farthest sample with
    the same label (the anchor itself is included, at distance 0) and ``n_i``
    is the closest sample with a different label. ``d`` is Euclidean
    distance between embeddings.

    Args:
        margin: Triplet margin added inside the ReLU.
        triplet_weight: Multiplier applied to the triplet term. Defaults to
            0.1, the value that was previously hard-coded.
    """

    # Additive penalty used to exclude same-label pairs from the hard-negative
    # search. An anchor with no different-label sample in the batch gets a
    # ~1e6 "negative" distance, so (for any realistic distance scale) it
    # contributes 0 to the triplet term while still counting in the mean.
    _NEG_MASK_PENALTY = 1e6

    def __init__(self, margin=1.0, triplet_weight=0.1):
        super().__init__()
        self.margin = margin
        self.triplet_weight = triplet_weight
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, logits, labels, embeddings):
        """Compute the hybrid loss for one batch.

        Args:
            logits: Classifier outputs passed to ``nn.CrossEntropyLoss``
                (presumably (batch, num_classes) — TODO confirm).
            labels: Class-index targets; also used to build the same/different
                label masks (assumed shape (batch,) — TODO confirm).
            embeddings: Feature vectors compared with ``torch.cdist``
                (assumed shape (batch, dim) — TODO confirm).

        Returns:
            Scalar tensor: mean cross-entropy plus ``triplet_weight`` times the
            mean batch-hard triplet loss.
        """
        loss_ce = self.ce_loss(logits, labels)

        # Pairwise Euclidean distances between all embeddings in the batch.
        dist_mat = torch.cdist(embeddings, embeddings, p=2.0)

        # Column vector vs. its transpose broadcasts to a pairwise label table.
        labels_col = labels.unsqueeze(1)
        mask_pos = (labels_col == labels_col.T).float()
        mask_neg = (labels_col != labels_col.T).float()

        # Hardest negative: push same-label entries far away, then take the min.
        negative_dist = dist_mat + (1.0 - mask_neg) * self._NEG_MASK_PENALTY
        hard_negatives = negative_dist.min(dim=1)[0]

        # Hardest positive: zero out different-label entries, then take the
        # max. Distances are non-negative, so the zeros never beat a real
        # positive distance.
        positive_dist = dist_mat * mask_pos
        hard_positives = positive_dist.max(dim=1)[0]

        loss_triplet = F.relu(hard_positives - hard_negatives + self.margin).mean()
        return loss_ce + self.triplet_weight * loss_triplet