| | from torch import Tensor |
| | import torch.distributed as dist |
| | import torch |
| | import torch.nn.functional as F |
| | import os |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
class SimpleContrastiveLoss:
    """InfoNCE-style contrastive loss with optional dual-view queries/passages.

    Supported input shapes (B = batch, D = embedding dim, N = #passages):
      * x=[B, D],    y=[N, D]    -> standard single-direction CE (InfoNCE).
      * x=[B, 2, D], y=[N, D]    -> CE per query view, combined as
                                    alpha * view0 + (1 - alpha) * view1.
      * x=[B, D],    y=[B, 2, D] -> CE against each passage view, same weighting.
      * x=[B, 2, D], y=[B, 2, D] -> matched views (0<->0, 1<->1), same weighting.
    """

    def __init__(self, temperature: float = 0.02, alpha: float = 0.05):
        # temperature: divides the similarity logits (logits = sim / temperature).
        # alpha: weight of view 0; view 1 receives (1 - alpha).
        self.temperature = temperature
        self.alpha = alpha

    def _ce(self, q: Tensor, p: Tensor, target: Tensor, reduction: str) -> Tensor:
        """Cross-entropy over the temperature-scaled similarity matrix q @ p^T."""
        logits = torch.matmul(q, p.transpose(0, 1)) / self.temperature
        return F.cross_entropy(logits, target, reduction=reduction)

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        """Compute the contrastive loss for any of the supported shape combinations.

        Args:
            x: query embeddings, [B, D] or [B, 2, D].
            y: passage embeddings, [N, D] or [B, 2, D].
            target: optional [B] long tensor of positive indices. Defaults to
                i -> i * (y.size(0) // B), the usual layout where each query's
                positive is followed by its hard negatives.
            reduction: forwarded to F.cross_entropy.

        Raises:
            ValueError: for any unsupported (x.dim(), y.dim()) combination.
        """
        B = x.size(0)
        if target is None:
            # One positive every `target_per_qry` rows of y (positive followed
            # by that query's hard negatives).
            target_per_qry = y.size(0) // B
            target = torch.arange(
                0, B * target_per_qry, target_per_qry, device=x.device, dtype=torch.long
            )

        if x.dim() == 2 and y.dim() == 2:
            return self._ce(x, y, target, reduction)

        # All multi-view cases share the same two-term weighting.
        w0, w1 = self.alpha, 1.0 - self.alpha

        if x.dim() == 3 and y.dim() == 2:
            assert x.size(1) == 2, f"Expect x=[B,2,D], got {tuple(x.size())}"
            return (w0 * self._ce(x[:, 0, :], y, target, reduction)
                    + w1 * self._ce(x[:, 1, :], y, target, reduction))

        if x.dim() == 2 and y.dim() == 3:
            assert y.size(1) == 2, f"Expect y=[B,2,D], got {tuple(y.size())}"
            return (w0 * self._ce(x, y[:, 0, :], target, reduction)
                    + w1 * self._ce(x, y[:, 1, :], target, reduction))

        if x.dim() == 3 and y.dim() == 3:
            assert x.size(1) == y.size(1) == 2, f"Expect x,y=[B,2,D], got {tuple(x.size())}, {tuple(y.size())}"
            return (w0 * self._ce(x[:, 0, :], y[:, 0, :], target, reduction)
                    + w1 * self._ce(x[:, 1, :], y[:, 1, :], target, reduction))

        raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}")
| |
|
| |
|
class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """Contrastive loss computed over embeddings all-gathered from every rank."""

    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02, alpha: float = 0.05):
        # n_target is accepted for signature compatibility but unused here.
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature, alpha=alpha)
        # NOTE(review): `word_size` looks like a typo for `world_size`; kept
        # as-is because it is externally visible instance state.
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        all_x = self.gather_tensor(x)
        all_y = self.gather_tensor(y)
        loss = super().__call__(all_x, all_y, **kwargs)
        # Gradients only flow through the local shard, so the loss is rescaled
        # by the number of ranks when scale_loss is set.
        return loss * self.word_size if self.scale_loss else loss

    def gather_tensor(self, t):
        # all_gather fills detached copies; re-insert the local tensor so
        # autograd can still propagate through this rank's slice.
        chunks = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(chunks, t)
        chunks[self.rank] = t
        return torch.cat(chunks, dim=0)
| |
|
class InExampleContrastiveLoss:
    """In-example contrastive loss: each query is scored only against its own
    candidate set, with the positive always at index 0.

    Expected shapes: x=[bsz, hdim] queries; y=[bsz, num_label, hdim]
    candidates (a flattened [bsz * num_label, hdim] y also works, since it is
    re-viewed per example before scoring).
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
        # One positive plus n hard negatives per query.
        self.target_per_qry = n_hard_negatives + 1
        # NOTE(review): temperature MULTIPLIES the logits here, unlike
        # SimpleContrastiveLoss where it divides — kept for backward compat.
        self.temperature = temperature
        # Optional truncation of the embedding dimension before scoring.
        self.ndim = ndim

    @staticmethod
    def _gather(t: Tensor) -> Tensor:
        """All-gather `t` across ranks, keeping autograd through the local shard."""
        chunks = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(chunks, t)
        chunks[dist.get_rank()] = t  # re-insert local tensor so gradients flow
        return torch.cat(chunks, dim=0)

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        if dist.is_initialized():
            # BUGFIX: the original called `dist_utils.dist_gather`, but
            # `dist_utils` is never imported (NameError under distributed
            # training). Use a local all-gather instead.
            x = self._gather(x)
            y = self._gather(y)
        bsz, ndim = x.size(0), x.size(1)
        # The positive candidate sits at index 0 of every per-example set.
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            # BUGFIX: truncate the LAST (embedding) axis. The original
            # `y[:, :ndim]` sliced the num_label axis of a 3-D y, scrambling
            # the subsequent view; `...` is correct for both 2-D and 3-D y.
            x = x[..., :ndim]
            y = y[..., :ndim]
        # Per-example dot products: logits[b, s] = <x_b, y_{b,s}>.
        logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail