|
|
from torch import Tensor |
|
|
import torch.distributed as dist |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import os |
|
|
|
|
|
class SimpleContrastiveLoss: |
|
|
def __init__(self, temperature: float = 0.02, alpha: float = 0.05, weights=None): |
|
|
""" |
|
|
weights: list[float] or None |
|
|
- 若提供 weights,则用于对多个视角/层的 CE 加权,长度需与视角数一致,训练时会归一化。 |
|
|
- 若为 None 且 K==2,退化为 [alpha, 1-alpha] |
|
|
- 若为 None 且 K>2,默认均匀权重 |
|
|
""" |
|
|
self.temperature = temperature |
|
|
self.alpha = alpha |
|
|
self.weights = weights |
|
|
|
|
|
def _get_weights(self, K: int, device): |
|
|
if self.weights is not None: |
|
|
assert len(self.weights) == K, f"weights length {len(self.weights)} != K={K}" |
|
|
w = torch.tensor(self.weights, dtype=torch.float32, device=device) |
|
|
w = torch.clamp(w, min=0) |
|
|
s = w.sum().item() |
|
|
if s <= 0: |
|
|
w = torch.ones(K, device=device) / K |
|
|
else: |
|
|
w = w / s |
|
|
return w |
|
|
if K == 2: |
|
|
w = torch.tensor([self.alpha, 1.0 - self.alpha], dtype=torch.float32, device=device) |
|
|
return torch.clamp(w, min=0) / max(w.sum().item(), 1e-8) |
|
|
|
|
|
return torch.ones(K, dtype=torch.float32, device=device) / K |
|
|
|
|
|
def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor: |
|
|
""" |
|
|
统一支持: |
|
|
- x=[B, D], y=[B, D] -> 单视角 |
|
|
- x=[B, K, D], y=[B, D] -> K 个 query 视角对单一候选视角 |
|
|
- x=[B, D], y=[B, K, D] -> 单一 query 视角对 K 个候选视角 |
|
|
- x=[B, K, D], y=[B, K, D] -> 逐视角配对(k↔k)加权 |
|
|
""" |
|
|
B = x.size(0) |
|
|
if target is None: |
|
|
target_per_qry = y.size(0) // B |
|
|
target = torch.arange(0, B * target_per_qry, target_per_qry, device=x.device, dtype=torch.long) |
|
|
|
|
|
|
|
|
if x.dim() == 2 and y.dim() == 2: |
|
|
logits = torch.matmul(x, y.transpose(0, 1)) / self.temperature |
|
|
return F.cross_entropy(logits, target, reduction=reduction) |
|
|
|
|
|
|
|
|
if x.dim() == 3 and y.dim() == 2: |
|
|
K = x.size(1) |
|
|
w = self._get_weights(K, x.device) |
|
|
loss = 0.0 |
|
|
for k in range(K): |
|
|
logits_k = torch.matmul(x[:, k, :], y.transpose(0, 1)) / self.temperature |
|
|
loss_k = F.cross_entropy(logits_k, target, reduction=reduction) |
|
|
loss = loss + w[k] * loss_k |
|
|
return loss |
|
|
|
|
|
|
|
|
if x.dim() == 2 and y.dim() == 3: |
|
|
K = y.size(1) |
|
|
w = self._get_weights(K, x.device) |
|
|
loss = 0.0 |
|
|
for k in range(K): |
|
|
logits_k = torch.matmul(x, y[:, k, :].transpose(0, 1)) / self.temperature |
|
|
loss_k = F.cross_entropy(logits_k, target, reduction=reduction) |
|
|
loss = loss + w[k] * loss_k |
|
|
return loss |
|
|
|
|
|
|
|
|
if x.dim() == 3 and y.dim() == 3: |
|
|
Kx, Ky = x.size(1), y.size(1) |
|
|
assert Kx == Ky, f"view mismatch: {Kx} vs {Ky}" |
|
|
K = Kx |
|
|
w = self._get_weights(K, x.device) |
|
|
loss = 0.0 |
|
|
for k in range(K): |
|
|
logits_k = torch.matmul(x[:, k, :], y[:, k, :].transpose(0, 1)) / self.temperature |
|
|
loss_k = F.cross_entropy(logits_k, target, reduction=reduction) |
|
|
loss = loss + w[k] * loss_k |
|
|
return loss |
|
|
|
|
|
raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}") |
|
|
|
|
|
|
|
|
class DistributedContrastiveLoss(SimpleContrastiveLoss): |
|
|
def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02, alpha: float = 0.05, weights=None): |
|
|
assert dist.is_initialized(), "Distributed training has not been properly initialized." |
|
|
super().__init__(temperature=temperature, alpha=alpha, weights=weights) |
|
|
self.word_size = dist.get_world_size() |
|
|
self.rank = dist.get_rank() |
|
|
self.scale_loss = scale_loss |
|
|
|
|
|
def __call__(self, x: Tensor, y: Tensor, **kwargs): |
|
|
dist_x = self.gather_tensor(x) |
|
|
dist_y = self.gather_tensor(y) |
|
|
loss = super().__call__(dist_x, dist_y, **kwargs) |
|
|
if self.scale_loss: |
|
|
loss = loss * self.word_size |
|
|
return loss |
|
|
|
|
|
def gather_tensor(self, t): |
|
|
gathered = [torch.empty_like(t) for _ in range(self.word_size)] |
|
|
dist.all_gather(gathered, t) |
|
|
gathered[self.rank] = t |
|
|
return torch.cat(gathered, dim=0) |
|
|
|
|
|
class InExampleContrastiveLoss: |
|
|
""" |
|
|
保持不变 |
|
|
x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim] |
|
|
""" |
|
|
def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs): |
|
|
self.target_per_qry = n_hard_negatives + 1 |
|
|
self.temperature = temperature |
|
|
self.ndim = ndim |
|
|
|
|
|
def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'): |
|
|
if torch.distributed.is_initialized(): |
|
|
x = dist_utils.dist_gather(x) |
|
|
y = dist_utils.dist_gather(y) |
|
|
bsz, ndim = x.size(0), x.size(1) |
|
|
target = torch.zeros(bsz, dtype=torch.long, device=x.device) |
|
|
if self.ndim: |
|
|
ndim = self.ndim |
|
|
x = x[:, :ndim] |
|
|
y = y[:, :ndim] |
|
|
logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature |
|
|
preds = torch.argmax(logits, dim=-1) |
|
|
loss = F.cross_entropy(logits, target, reduction=reduction) |
|
|
loss_detail = {"logits": logits, "labels": target, "preds": preds} |
|
|
return loss, loss_detail |