|
|
| import torch |
| import torch.nn as nn |
| import timm |
|
|
class BRAINHybridLoss(nn.Module):
    """Cross-entropy loss combined with a batch-hard triplet loss.

    The returned value is ``ce + triplet_weight * triplet``, where the
    triplet term mines, for every sample in the batch, its farthest
    same-label sample and its nearest different-label sample.

    Args:
        margin: Margin added inside the triplet hinge.
        triplet_weight: Multiplier applied to the triplet term before it is
            added to the cross-entropy term.
    """

    def __init__(self, margin=1.0, triplet_weight=0.1):
        super().__init__()
        self.margin = margin
        self.triplet_weight = triplet_weight
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Return the combined classification + triplet loss as a scalar.

        Args:
            logits: Class scores passed to ``nn.CrossEntropyLoss``.
            labels: Targets, one per sample. Assumed to be a 1-D tensor of
                length N (it is unsqueezed to (N, 1) for the pairwise
                comparison) -- TODO confirm against callers.
            embeddings: Feature vectors compared with pairwise L2 distance.
                Assumed to be (N, D) so the distance matrix is (N, N) --
                TODO confirm against callers.

        Returns:
            ``ce_loss + triplet_weight * mean(hinge)``.
        """
        loss_ce = self.ce_loss(logits, labels)

        # Pairwise Euclidean distances between all embeddings in the batch.
        dist_mat = torch.cdist(embeddings, embeddings, p=2.0)

        # same_label[i, j] is True when samples i and j share a label; the
        # diagonal is always True (each sample matches itself).
        labels_col = labels.unsqueeze(1)
        same_label = labels_col == labels_col.T
        mask_pos = same_label.float()

        # Hardest negative: closest different-label sample. Same-label
        # entries are shifted by 1e6 so they never win the min. If a row has
        # no negatives at all, the result is ~1e6 and the hinge below is 0
        # for any realistic margin.
        negative_dist = dist_mat + mask_pos * 1e6
        hard_negatives = negative_dist.min(dim=1)[0]

        # Hardest positive: farthest same-label sample. Different-label
        # entries are zeroed; the self-distance (0) keeps the max
        # well-defined when a sample has no other positive in the batch.
        positive_dist = dist_mat * mask_pos
        hard_positives = positive_dist.max(dim=1)[0]

        # torch.relu is used instead of F.relu: ``F`` is not imported in
        # this module, so the former call raised NameError.
        hinge = torch.relu(hard_positives - hard_negatives + self.margin)
        loss_triplet = hinge.mean()

        return loss_ce + (self.triplet_weight * loss_triplet)
|
|