File size: 5,909 Bytes
0a937d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from torch import Tensor
import torch.distributed as dist
import torch
import torch.nn.functional as F
import os  # newly added; currently unused in this file
class SimpleContrastiveLoss:
    """In-batch contrastive (InfoNCE) loss with optional multi-view weighting.

    Similarity logits are dot products divided by ``temperature``; the positive
    for query ``i`` is candidate ``i * target_per_qry``. Supports 2-D and 3-D
    query/candidate tensors — see ``__call__`` for the accepted combinations.
    """

    def __init__(self, temperature: float = 0.02, alpha: float = 0.05, weights=None):
        """
        Args:
            temperature: softmax temperature; logits are divided by it.
            alpha: weight of view 0 in the K == 2 fallback (see below).
            weights: list[float] or None.
                - If given, weights the per-view CE losses; its length must
                  equal the number of views K and it is normalized at call time.
                - If None and K == 2, falls back to [alpha, 1 - alpha].
                - If None and K > 2, uniform weights are used.
        """
        self.temperature = temperature
        self.alpha = alpha
        self.weights = weights  # e.g. [0.1, 0.2, 0.7]

    def _get_weights(self, K: int, device) -> Tensor:
        """Return a non-negative, sum-to-one weight vector of length K."""
        if self.weights is not None:
            assert len(self.weights) == K, f"weights length {len(self.weights)} != K={K}"
            w = torch.clamp(torch.tensor(self.weights, dtype=torch.float32, device=device), min=0)
            s = w.sum().item()
            if s <= 0:
                # Every weight was clamped away -> fall back to uniform.
                return torch.ones(K, device=device) / K
            return w / s
        if K == 2:
            w = torch.tensor([self.alpha, 1.0 - self.alpha], dtype=torch.float32, device=device)
            w = torch.clamp(w, min=0)
            # Fix: normalize by the *clamped* sum. The original divided by the
            # pre-clamp sum (always 1.0 here), so alpha outside [0, 1] produced
            # an unnormalized weight vector such as [1.5, 0].
            return w / max(w.sum().item(), 1e-8)
        # Default: uniform weights across views.
        return torch.ones(K, dtype=torch.float32, device=device) / K

    def _pair_loss(self, q: Tensor, c: Tensor, target: Tensor, reduction: str) -> Tensor:
        """CE over temperature-scaled similarities for one [B, D] x [N, D] pair."""
        logits = torch.matmul(q, c.transpose(0, 1)) / self.temperature
        return F.cross_entropy(logits, target, reduction=reduction)

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        """Compute the (possibly view-weighted) contrastive loss.

        Accepted shapes:
            - x=[B, D],    y=[B, D]    -> single view
            - x=[B, K, D], y=[B, D]    -> K query views vs. one candidate view
            - x=[B, D],    y=[B, K, D] -> one query view vs. K candidate views
            - x=[B, K, D], y=[B, K, D] -> per-view pairing (k <-> k), weighted

        Args:
            target: optional [B] long tensor of positive indices; defaults to
                query i matching candidate i * (y rows per query).
            reduction: forwarded to F.cross_entropy.
        """
        B = x.size(0)
        if target is None:
            # Default positives: candidate i * target_per_qry belongs to query i.
            target_per_qry = y.size(0) // B
            target = torch.arange(0, B * target_per_qry, target_per_qry, device=x.device, dtype=torch.long)
        if x.dim() == 2 and y.dim() == 2:
            return self._pair_loss(x, y, target, reduction)
        if x.dim() == 3 and y.dim() == 2:
            K = x.size(1)
            w = self._get_weights(K, x.device)
            return sum(w[k] * self._pair_loss(x[:, k, :], y, target, reduction) for k in range(K))
        if x.dim() == 2 and y.dim() == 3:
            K = y.size(1)
            w = self._get_weights(K, x.device)
            return sum(w[k] * self._pair_loss(x, y[:, k, :], target, reduction) for k in range(K))
        if x.dim() == 3 and y.dim() == 3:
            Kx, Ky = x.size(1), y.size(1)
            assert Kx == Ky, f"view mismatch: {Kx} vs {Ky}"
            w = self._get_weights(Kx, x.device)
            return sum(w[k] * self._pair_loss(x[:, k, :], y[:, k, :], target, reduction) for k in range(Kx))
        raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}")
class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """Cross-rank variant of SimpleContrastiveLoss.

    All-gathers both embedding tensors before computing the loss so that each
    rank's in-batch negatives include the examples held by every other rank.
    Requires an initialized torch.distributed process group.
    """
    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02, alpha: float = 0.05, weights=None):
        """
        Args:
            n_target: accepted but unused in this class.
            scale_loss: if True, multiply the loss by the world size —
                presumably to counteract gradient averaging across ranks
                (TODO confirm against the training loop).
            temperature/alpha/weights: forwarded to SimpleContrastiveLoss.
        """
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature, alpha=alpha, weights=weights)
        # NOTE(review): "word_size" is a typo for world_size; kept as-is since
        # external code may read this attribute.
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss
    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        """Gather x and y from all ranks, then compute the base-class loss."""
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            loss = loss * self.word_size
        return loss
    def gather_tensor(self, t):
        """All-gather `t` across ranks into a [world_size * B, ...] tensor.

        dist.all_gather fills `gathered` with copies that are detached from the
        autograd graph; re-inserting the local tensor at this rank's slot keeps
        this rank's gradient path intact.
        """
        gathered = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(gathered, t)
        gathered[self.rank] = t  # keep this rank's gradient
        return torch.cat(gathered, dim=0)
class InExampleContrastiveLoss:
    """Per-example contrastive loss.

    Each query is scored only against its own candidate list:
    x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim] (or a pre-flattened
    [bsz * num_label, hdim]); the positive is always at label index 0.
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
        """
        Args:
            n_hard_negatives: hard negatives per query (positive at index 0).
            temperature: logits are *multiplied* by this value — NOTE the
                sibling SimpleContrastiveLoss divides by it instead.
            ndim: if set, truncate embeddings to their first `ndim` features.
        """
        self.target_per_qry = n_hard_negatives + 1
        self.temperature = temperature
        self.ndim = ndim

    @staticmethod
    def _gather(t: Tensor) -> Tensor:
        """All-gather `t` across ranks, keeping the local shard's autograd graph."""
        world_size = dist.get_world_size()
        if world_size == 1:
            return t
        gathered = [torch.empty_like(t) for _ in range(world_size)]
        dist.all_gather(gathered, t)
        # all_gather output is detached; restore the local tensor for gradients
        # (same trick as DistributedContrastiveLoss.gather_tensor).
        gathered[dist.get_rank()] = t
        return torch.cat(gathered, dim=0)

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        """Return (loss, detail dict with logits/labels/preds)."""
        if dist.is_initialized():
            # Fix: the original called dist_utils.dist_gather, but `dist_utils`
            # is never imported in this file -> NameError in distributed runs.
            x = self._gather(x)
            y = self._gather(y)
        bsz, ndim = x.size(0), x.size(1)
        # The positive candidate sits at index 0 of every example's list.
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            # Fix: slice the *feature* axis. The original `y[:, :ndim]` cut the
            # num_label axis when y is the documented 3-D [bsz, num_label, hdim];
            # `...` also remains correct for a pre-flattened 2-D y.
            x = x[..., :ndim]
            y = y[..., :ndim]
        logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail