code_SAS_VLM2Vec / src /loss_multi_layer.py
MgGladys's picture
Add files using upload-large-folder tool
0a937d7 verified
from torch import Tensor
import torch.distributed as dist
import torch
import torch.nn.functional as F
import os # 新增
class SimpleContrastiveLoss:
    """In-batch contrastive (InfoNCE-style) loss with optional multi-view weighting.

    Logits are ``q @ c.T / temperature`` and the loss is
    ``F.cross_entropy(logits, target)``. When ``x`` and/or ``y`` carry a view
    axis (shape ``[B, K, D]``), one cross-entropy is computed per view and the
    per-view losses are combined with the weights from ``_get_weights``.
    """

    def __init__(self, temperature: float = 0.02, alpha: float = 0.05, weights=None):
        """
        Args:
            temperature: divisor applied to the similarity logits.
            alpha: only used when ``weights`` is None and there are exactly two
                views; the view weights are then ``[alpha, 1 - alpha]``.
            weights: optional per-view weights (list of floats, e.g.
                ``[0.1, 0.2, 0.7]``). Its length must equal the number of views
                K. Negative entries are clamped to 0 and the vector is
                normalized to sum to 1 at call time. If None and K > 2,
                uniform weights are used.
        """
        self.temperature = temperature
        self.alpha = alpha
        self.weights = weights

    def _get_weights(self, K: int, device):
        """Return K non-negative float32 view weights on ``device`` that sum to 1.

        The raw weights come from ``self.weights`` if set, else
        ``[alpha, 1 - alpha]`` when K == 2, else uniform. Raw weights are
        clamped at 0 and renormalized; if nothing positive remains, uniform
        weights are returned.

        Raises:
            AssertionError: if ``self.weights`` is set and its length != K.
        """
        if self.weights is not None:
            assert len(self.weights) == K, f"weights length {len(self.weights)} != K={K}"
            raw = self.weights
        elif K == 2:
            raw = [self.alpha, 1.0 - self.alpha]
        else:
            return torch.ones(K, dtype=torch.float32, device=device) / K

        w = torch.clamp(torch.tensor(raw, dtype=torch.float32, device=device), min=0)
        # Normalize by the sum of the *clamped* weights so the result always
        # sums to 1 (an alpha outside [0, 1] would otherwise stay unnormalized).
        total = w.sum().item()
        if total <= 0:
            return torch.ones(K, dtype=torch.float32, device=device) / K
        return w / total

    def _pair_loss(self, q: Tensor, c: Tensor, target: Tensor, reduction: str) -> Tensor:
        """Cross-entropy of one 2-D query view ``q`` against one 2-D candidate view ``c``."""
        logits = torch.matmul(q, c.transpose(0, 1)) / self.temperature
        return F.cross_entropy(logits, target, reduction=reduction)

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        """Compute the contrastive loss.

        Supported shapes:
          - x=[B, D],    y=[N, D]    -> single view
          - x=[B, K, D], y=[N, D]    -> K query views vs. one candidate view
          - x=[B, D],    y=[N, K, D] -> one query view vs. K candidate views
          - x=[B, K, D], y=[N, K, D] -> view-wise pairing (k <-> k), weighted

        Args:
            x: query embeddings.
            y: candidate embeddings.
            target: index of the positive candidate for each query. Defaults to
                ``arange(0, B * n, n)`` with ``n = N // B``, i.e. query i's
                positive is candidate row ``i * n``.
            reduction: forwarded to ``F.cross_entropy``.

        Returns:
            The (weighted) cross-entropy loss.

        Raises:
            AssertionError: if both inputs are 3-D with different view counts.
            ValueError: for any other combination of input ranks.
        """
        B = x.size(0)
        if target is None:
            per_query = y.size(0) // B
            target = torch.arange(0, B * per_query, per_query, device=x.device, dtype=torch.long)

        if x.dim() == 2 and y.dim() == 2:
            return self._pair_loss(x, y, target, reduction)
        if x.dim() not in (2, 3) or y.dim() not in (2, 3):
            raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}")

        x_multi, y_multi = x.dim() == 3, y.dim() == 3
        if x_multi and y_multi:
            Kx, Ky = x.size(1), y.size(1)
            assert Kx == Ky, f"view mismatch: {Kx} vs {Ky}"
        K = x.size(1) if x_multi else y.size(1)
        w = self._get_weights(K, x.device)

        loss = 0.0
        for k in range(K):
            # A side without a view axis is reused for every view k.
            q = x[:, k, :] if x_multi else x
            c = y[:, k, :] if y_multi else y
            loss = loss + w[k] * self._pair_loss(q, c, target, reduction)
        return loss
class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """SimpleContrastiveLoss evaluated on embeddings gathered from all ranks.

    Each call all-gathers ``x`` and ``y`` across the default process group,
    keeps this rank's slice as the original (autograd-connected) tensor, and
    evaluates the parent loss on the concatenated batch.
    """

    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02, alpha: float = 0.05, weights=None):
        """
        Args:
            n_target: accepted for interface compatibility; not used here.
            scale_loss: if True, the loss is multiplied by the world size.
            temperature, alpha, weights: forwarded to SimpleContrastiveLoss.

        Raises:
            AssertionError: if torch.distributed has not been initialized.
        """
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature, alpha=alpha, weights=weights)
        # NOTE: the attribute is spelled "word_size" (sic); kept as-is because
        # other code may read it.
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        """Gather x and y from all ranks, then compute the parent loss.

        ``kwargs`` (e.g. ``target``, ``reduction``) are forwarded unchanged;
        a supplied ``target`` must index into the *gathered* batch.
        """
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            # NOTE(review): presumably offsets gradient averaging across ranks
            # (only the local slice carries grad) — confirm with the trainer.
            loss = loss * self.word_size
        return loss

    def gather_tensor(self, t):
        """All-gather ``t`` from every rank and concatenate along dim 0.

        ``dist.all_gather`` outputs carry no gradient, so this rank's entry is
        replaced by ``t`` itself to keep gradients flowing into the local
        embeddings. Assumes every rank passes a tensor of identical shape —
        TODO confirm the data loader guarantees equal per-rank batch sizes.
        """
        # Collective backends such as NCCL reject non-contiguous tensors (e.g.
        # a sliced view); .contiguous() is a no-op when t is already contiguous.
        t = t.contiguous()
        gathered = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(gathered, t)
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)
class InExampleContrastiveLoss:
    """Contrastive loss where each query is scored only against its own candidates.

    Expected shapes: x=[bsz, hdim]; y=[bsz, num_label, hdim] (or any tensor
    that reshapes to it, e.g. [bsz * num_label, hdim]). Candidate 0 of each
    row is treated as the positive, so the target is all zeros.
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
        """
        Args:
            n_hard_negatives: number of hard negatives per query; stored as
                ``target_per_qry = n_hard_negatives + 1`` (not read by ``__call__``).
            temperature: logits are *multiplied* by this value (it acts as a
                logit scale), unlike SimpleContrastiveLoss which divides by it.
            ndim: if set, only the first ``ndim`` embedding dimensions are used.
            *args, **kwargs: accepted and ignored.
        """
        self.target_per_qry = n_hard_negatives + 1
        self.temperature = temperature
        self.ndim = ndim

    @staticmethod
    def _dist_gather(t: Tensor) -> Tensor:
        """All-gather ``t`` across ranks along dim 0, keeping the local slice differentiable."""
        # Collective backends such as NCCL reject non-contiguous tensors.
        t = t.contiguous()
        gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, t)
        # all_gather outputs carry no grad; reinsert the local tensor.
        gathered[dist.get_rank()] = t
        return torch.cat(gathered, dim=0)

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        """Return ``(loss, {"logits": ..., "labels": ..., "preds": ...})``.

        When torch.distributed is initialized, x and y are first gathered from
        all ranks (assumes equal shapes on every rank — TODO confirm).
        """
        if dist.is_initialized():
            x = self._dist_gather(x)
            y = self._dist_gather(y)
        bsz, ndim = x.size(0), x.size(1)
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            # Slice the last (embedding) axis so this works whether y is
            # [bsz, num_label, hdim] or flattened [bsz * num_label, hdim];
            # y[:, :ndim] would cut the label axis of the 3-D form.
            x = x[..., :ndim]
            y = y[..., :ndim]
        # reshape (not view) tolerates the non-contiguous slices above.
        logits = torch.einsum('bod,bsd->bs', x.reshape(bsz, 1, ndim), y.reshape(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail