MgGladys's picture
Add files using upload-large-folder tool
0a937d7 verified
import torch
from torch import Tensor
from torch.nn import functional as F
from torch import distributed as dist
from src import dist_utils
class InExampleContrastiveLoss:
"""
Categorization loss: cross_entropy of 1 out of K classes (target labels)
x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim]
"""
def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
self.target_per_qry = n_hard_negatives + 1
self.temperature = temperature
self.ndim = ndim
def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
# print("gather InExampleContrastiveLoss")
if torch.distributed.is_initialized():
x = dist_utils.dist_gather(x)
y = dist_utils.dist_gather(y)
bsz, ndim = x.size(0), x.size(1)
target = torch.zeros(bsz, dtype=torch.long, device=x.device)
if self.ndim:
ndim = self.ndim
x = x[:, :ndim]
y = y[:, :ndim]
logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature
preds = torch.argmax(logits, dim=-1)
loss = F.cross_entropy(logits, target, reduction=reduction)
loss_detail = {"logits": logits, "labels": target, "preds": preds}
return loss, loss_detail
class SimpleContrastiveLoss:
def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, *args, **kwargs):
self.target_per_qry = n_hard_negatives + 1
self.temperature = temperature
def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
# print("gather SimpleContrastiveLoss")
if target is None:
assert x.size(0) * self.target_per_qry == y.size(0)
target = torch.arange(0, y.size(0), step=self.target_per_qry, dtype=torch.long, device=x.device)
logits = torch.matmul(x, y.transpose(0, 1)) * self.temperature
preds = torch.argmax(logits, dim=-1)
loss = F.cross_entropy(logits, target, reduction=reduction)
loss_detail = {"logits": logits, "labels": target, "preds": preds}
return loss, loss_detail
class DistributedContrastiveLoss(SimpleContrastiveLoss):
def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, *args, **kwargs):
assert dist.is_initialized(), "Distributed training has not been properly initialized."
super().__init__(n_hard_negatives=n_hard_negatives, temperature=temperature)
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
def __call__(self, x: Tensor, y: Tensor, **kwargs):
# print("gather DistributedContrastiveLoss")
dist_x = self.gather_tensor(x)
dist_y = self.gather_tensor(y)
return super().__call__(dist_x, dist_y, **kwargs)
def gather_tensor(self, t):
gathered = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(gathered, t)
gathered[self.rank] = t
return torch.cat(gathered, dim=0)
LossName2LossCls = {
"inexample_contrastive": InExampleContrastiveLoss,
"inbatch_contrastive": SimpleContrastiveLoss,
"distributed_inbatch_contrastive": DistributedContrastiveLoss,
}