| | import torch |
| | from torch import Tensor |
| | from torch.nn import functional as F |
| | from torch import distributed as dist |
| |
|
| | from src import dist_utils |
| |
|
| |
|
class InExampleContrastiveLoss:
    """
    Categorization loss: cross_entropy of 1 out of K classes (target labels)
    x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim]

    The positive candidate is assumed to be at index 0 of the label axis
    for every example (targets are all zeros).
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
        # One positive plus n_hard_negatives candidates scored per query.
        self.target_per_qry = n_hard_negatives + 1
        # Multiplicative logit scale (an inverse-temperature factor).
        self.temperature = temperature
        # Optional truncation of the embedding dimension before scoring.
        self.ndim = ndim

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        """Return (loss, loss_detail) for in-example contrastive classification.

        Args:
            x: query embeddings, shape [bsz, hdim].
            y: candidate embeddings, shape [bsz, num_label, hdim].
            reduction: forwarded to F.cross_entropy.

        Returns:
            (loss, {"logits", "labels", "preds"}).
        """
        # Use the module-level `dist` alias, consistent with the rest of the file.
        if dist.is_initialized():
            x = dist_utils.dist_gather(x)
            y = dist_utils.dist_gather(y)
        bsz, ndim = x.size(0), x.size(1)
        # Positive candidate sits at label index 0 for every example.
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim is not None:
            ndim = self.ndim
            x = x[:, :ndim]
            # BUGFIX: truncate the *embedding* (last) axis. The previous
            # `y[:, :ndim]` sliced the label axis of the 3-D candidate tensor,
            # dropping candidates and then mixing dimensions via view() below.
            y = y[..., :ndim]
        # [bsz, 1, d] x [bsz, num_label, d] -> [bsz, num_label] similarity scores.
        logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail
| |
|
| |
|
class SimpleContrastiveLoss:
    """In-batch contrastive loss.

    Each query's positive candidate appears at a fixed stride
    (`target_per_qry`) inside the candidate batch; every other row of the
    candidate batch serves as a negative.
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, *args, **kwargs):
        # One positive plus n_hard_negatives hard negatives per query.
        self.target_per_qry = n_hard_negatives + 1
        # Multiplicative logit scale (an inverse-temperature factor).
        self.temperature = temperature

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
        """Cross-entropy over query/candidate dot-product scores.

        x: [num_qry, hdim] queries; y: [num_qry * target_per_qry, hdim]
        candidates. When target is None, positives are assumed to occur
        every `target_per_qry` rows of y.
        """
        if target is None:
            assert x.size(0) * self.target_per_qry == y.size(0)
            target = torch.arange(0, y.size(0), step=self.target_per_qry, dtype=torch.long, device=x.device)
        scores = (x @ y.t()) * self.temperature
        predictions = scores.argmax(dim=-1)
        loss = F.cross_entropy(scores, target, reduction=reduction)
        return loss, {"logits": scores, "labels": target, "preds": predictions}
| |
|
| |
|
class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """In-batch contrastive loss computed over embeddings gathered from
    every distributed rank, enlarging the pool of in-batch negatives."""

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, *args, **kwargs):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(n_hard_negatives=n_hard_negatives, temperature=temperature)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        # Score each query against candidates from all ranks, not just local ones.
        return super().__call__(self.gather_tensor(x), self.gather_tensor(y), **kwargs)

    def gather_tensor(self, t):
        """All-gather `t` across ranks and concatenate along dim 0.

        The slot belonging to this rank is overwritten with the original
        tensor so that gradients still flow through the local embeddings
        (all_gather outputs do not carry autograd history).
        """
        slots = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(slots, t)
        slots[self.rank] = t  # restore the local, grad-tracking tensor
        return torch.cat(slots, dim=0)
| |
|
| |
|
# Registry mapping configuration loss-name strings to their loss classes.
LossName2LossCls = {
    "inexample_contrastive": InExampleContrastiveLoss,
    "inbatch_contrastive": SimpleContrastiveLoss,
    "distributed_inbatch_contrastive": DistributedContrastiveLoss,
}