| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
def infoNCE_loss1(mol_features, ms_features, temperature=0.1, norm=True):
    """Symmetric InfoNCE (NT-Xent) loss between two aligned feature batches.

    Row i of ``mol_features`` and row i of ``ms_features`` form the positive
    pair; every other row in the batch serves as an in-batch negative.

    Args:
        mol_features: (B, D) molecule embeddings.
        ms_features: (B, D) spectrum embeddings, row-aligned with mol_features.
        temperature: scaling applied to the cosine/dot-product logits.
        norm: if True, L2-normalize both inputs so logits are cosine similarities.

    Returns:
        Scalar tensor: mean of the mol→ms and ms→mol cross-entropy losses.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    # Pairwise similarity matrix, scaled by the temperature.
    logits = torch.mm(mol_features, ms_features.t()) / temperature

    # The diagonal holds the positive pairs, so target class i for row i.
    targets = torch.arange(logits.size(0), device=logits.device)

    # Average the two retrieval directions (mol→ms and ms→mol).
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
|
|
|
def infoNCE_loss2(mol_features, ms_features, temperature=0.1, alpha=0.75, norm=True):
    """Weighted bidirectional InfoNCE loss.

    Like :func:`infoNCE_loss1`, but the two retrieval directions are combined
    with a convex weight instead of a plain average: ``alpha`` scales the
    mol→ms term and ``1 - alpha`` the ms→mol term.

    Args:
        mol_features: (B, D) molecule embeddings.
        ms_features: (B, D) spectrum embeddings, row-aligned with mol_features.
        temperature: scaling applied to the similarity logits.
        alpha: weight of the mol→ms direction in [0, 1].
        norm: if True, L2-normalize both inputs first.

    Returns:
        Scalar tensor: ``alpha * loss_ab + (1 - alpha) * loss_ba``.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    targets = torch.arange(mol_features.size(0), device=mol_features.device)

    # ms→mol logits are just the transpose of mol→ms; no second matmul needed.
    logits_ab = torch.matmul(mol_features, ms_features.t()) / temperature
    loss_ab = F.cross_entropy(logits_ab, targets)
    loss_ba = F.cross_entropy(logits_ab.t(), targets)

    return alpha * loss_ab + (1 - alpha) * loss_ba
|
|
|
|
|
def contrastive_loss_with_hard_negatives(features1, features2, margin=1.0, hard_negative_ratio=0.3):
    """Contrastive loss with hard-negative mining.

    For each row i, the positive term pulls ``sim(features1[i], features2[i])``
    toward 1; the negative term pushes the top-k most similar *off-diagonal*
    pairs ("hard negatives") below ``margin`` via a hinge.

    Fixes over the previous version:
      * the boolean off-diagonal mask is created on the features' device
        (previously CPU-only, crashing on GPU inputs);
      * k is clamped to [1, batch_size - 1] — ``int(batch_size * ratio)`` was 0
        for batch_size < 4 at the default ratio, making ``mean`` of an empty
        tensor return NaN;
      * batch_size == 1 (no negatives available) is handled explicitly;
      * the per-sample Python loop is replaced by the equivalent vectorized
        mean (sum / batch_size == mean, so values are unchanged).

    Args:
        features1: (B, D) first view of the batch.
        features2: (B, D) second view, row-aligned with features1.
        margin: hinge margin for negative similarities.
        hard_negative_ratio: fraction of the batch mined as hard negatives.

    Returns:
        Scalar tensor: mean over the batch of (positive + hard-negative) loss.
    """
    batch_size = features1.shape[0]

    # NOTE(review): inputs are not normalized here, so "similarity" is a raw
    # dot product — callers presumably pass unit-norm features; confirm.
    similarity = torch.matmul(features1, features2.t())

    # Positive term: drive diagonal similarity toward 1.
    pos_loss = 1 - torch.diag(similarity)

    if batch_size > 1:
        # Off-diagonal entries only; mask must live on the same device as
        # `similarity` or indexing fails for CUDA tensors.
        mask = ~torch.eye(batch_size, dtype=torch.bool, device=similarity.device)
        negatives = similarity[mask].view(batch_size, batch_size - 1)

        # At least 1 and at most batch_size-1 hard negatives per row.
        k = max(1, min(int(batch_size * hard_negative_ratio), batch_size - 1))
        hard_negatives, _ = torch.topk(negatives, k=k, dim=1)

        # Hinge: only negatives above the margin contribute.
        neg_loss = torch.clamp(hard_negatives - margin, min=0).mean(dim=1)
    else:
        # A single-sample batch has no negatives to mine.
        neg_loss = torch.zeros_like(pos_loss)

    return (pos_loss + neg_loss).mean()