# brain/model.py — uploaded by santacl ("Upload 2 files"), commit 4140e34 (verified).
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
class BRAINHybridLoss(nn.Module):
    """Cross-entropy plus a weighted batch-hard triplet margin loss.

    The total loss is ``CE(logits, labels) + triplet_weight * triplet``.
    ``triplet`` is computed on pairwise Euclidean distances between
    ``embeddings``. For every anchor in the batch, the farthest other sample
    with the same label is the hard positive, and the closest sample with a
    different label is the hard negative. The per-anchor term is
    ``relu(d_pos - d_neg + margin)``, averaged over all anchors.

    Args:
        margin: Margin of the triplet term.
        triplet_weight: Weight of the triplet term relative to the
            cross-entropy term. Defaults to 0.1, the previously hard-coded
            value.
    """

    def __init__(self, margin=1.0, triplet_weight=0.1):
        super().__init__()
        self.margin = margin
        self.triplet_weight = triplet_weight
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, logits, labels, embeddings):
        """Return the combined scalar loss.

        Args:
            logits: Classifier outputs, passed with ``labels`` to
                ``nn.CrossEntropyLoss``.
            labels: Integer class labels. Presumably shape ``(B,)``; they are
                compared pairwise below to build ``(B, B)`` masks.
            embeddings: Presumably ``(B, D)`` feature vectors aligned with
                ``labels`` (TODO confirm against the caller).

        Returns:
            Scalar tensor ``loss_ce + triplet_weight * loss_triplet``.

        Notes:
            An anchor with no other same-label sample in the batch gets a
            hard-positive distance of 0. An anchor with no different-label
            sample contributes exactly 0 to the triplet term.
        """
        loss_ce = self.ce_loss(logits, labels)

        dist_mat = torch.cdist(embeddings, embeddings, p=2.0)

        same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
        # Exclude each anchor's distance to itself. It is 0 only in exact
        # arithmetic: cdist's matmul-based path may leave a small positive
        # residue, which must not be chosen as the hardest positive.
        self_pair = torch.eye(
            same_label.size(0), dtype=torch.bool, device=same_label.device
        )
        mask_pos = same_label & ~self_pair
        mask_neg = ~same_label

        # Farthest same-label sample. Masked entries are 0, so an anchor
        # without any positive gets 0, with no gradient through the mask.
        hard_positives = dist_mat.masked_fill(~mask_pos, 0.0).max(dim=1).values

        # Closest different-label sample. Masked entries are +inf, so an
        # anchor without any negative yields relu(-inf) = 0 below.
        hard_negatives = (
            dist_mat.masked_fill(~mask_neg, float("inf")).min(dim=1).values
        )

        loss_triplet = F.relu(hard_positives - hard_negatives + self.margin).mean()
        return loss_ce + self.triplet_weight * loss_triplet