import os  # added ("新增"); not used in this chunk — presumably used elsewhere in the project

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor

# (Removed: a large commented-out legacy copy of SimpleContrastiveLoss /
# DistributedContrastiveLoss that duplicated the active implementations below.)
class SimpleContrastiveLoss:
    """InfoNCE (contrastive cross-entropy) loss over query/passage embeddings.

    Supported shapes:
      * x=[B, D],    y=[B*k, D]  -> standard one-directional CE (InfoNCE)
      * x=[B, 2, D], y=[B*k, D]  -> two query views (per the original comments:
                                    layer 20 and the last layer); per-view CE
                                    combined with weights (alpha, 1 - alpha)
      * x=[B, D],    y=[B, 2, D] -> two passage views, combined the same way
      * x=[B, 2, D], y=[B, 2, D] -> matched views (0<->0, 1<->1), combined
    """

    def __init__(self, temperature: float = 0.02, alpha: float = 0.05):
        self.temperature = temperature
        # Weight of view 0; view 1 gets 1 - alpha.
        self.alpha = alpha

    def _ce(self, q: Tensor, p: Tensor, target: Tensor, reduction: str) -> Tensor:
        """Temperature-scaled cross-entropy between one query view and one passage view."""
        logits = torch.matmul(q, p.transpose(0, 1)) / self.temperature
        return F.cross_entropy(logits, target, reduction=reduction)

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None,
                 reduction: str = 'mean') -> Tensor:
        """Compute the (possibly multi-view) contrastive loss; see class docstring.

        When target is None, y may hold target_per_qry candidates per query and
        the positive for query i is assumed to sit at index i * target_per_qry.
        Raises ValueError for any shape combination not listed above.
        """
        B = x.size(0)
        if target is None:
            target_per_qry = y.size(0) // B
            target = torch.arange(
                0, B * target_per_qry, target_per_qry, device=x.device, dtype=torch.long
            )

        # Single-view x and y: plain InfoNCE.
        if x.dim() == 2 and y.dim() == 2:
            return self._ce(x, y, target, reduction)

        w0, w1 = self.alpha, 1.0 - self.alpha

        # Dual-view x, single-view y.
        if x.dim() == 3 and y.dim() == 2:
            assert x.size(1) == 2, f"Expect x=[B,2,D], got {tuple(x.size())}"
            return (w0 * self._ce(x[:, 0, :], y, target, reduction)
                    + w1 * self._ce(x[:, 1, :], y, target, reduction))

        # Single-view x, dual-view y.
        if x.dim() == 2 and y.dim() == 3:
            assert y.size(1) == 2, f"Expect y=[B,2,D], got {tuple(y.size())}"
            return (w0 * self._ce(x, y[:, 0, :], target, reduction)
                    + w1 * self._ce(x, y[:, 1, :], target, reduction))

        # Dual-view x and y: matched views only (0<->0, 1<->1).
        if x.dim() == 3 and y.dim() == 3:
            assert x.size(1) == y.size(1) == 2, f"Expect x,y=[B,2,D], got {tuple(x.size())}, {tuple(y.size())}"
            return (w0 * self._ce(x[:, 0, :], y[:, 0, :], target, reduction)
                    + w1 * self._ce(x[:, 1, :], y[:, 1, :], target, reduction))

        raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}")


class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """SimpleContrastiveLoss computed over embeddings all-gathered from every rank."""

    def __init__(self, n_target: int = 0, scale_loss: bool = True,
                 temperature: float = 0.02, alpha: float = 0.05):
        # n_target is accepted for interface compatibility; it is unused here.
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature, alpha=alpha)
        # NOTE(review): "word_size" is a historical typo for "world_size";
        # kept because external code may read this attribute.
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        """Gather x/y across ranks, then delegate to SimpleContrastiveLoss."""
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            # Counteracts the 1/world_size factor from DDP gradient averaging.
            loss = loss * self.word_size
        return loss

    def gather_tensor(self, t):
        """All-gather t along dim 0, keeping the local shard on the autograd graph."""
        gathered = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(gathered, t)
        gathered[self.rank] = t  # all_gather outputs are detached; restore local grad path
        return torch.cat(gathered, dim=0)


class InExampleContrastiveLoss:
    """Per-example contrastive loss: each query scores only its own candidate set.

    x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim]; label 0 is the positive.
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0,
                 ndim: int = None, *args, **kwargs):
        self.target_per_qry = n_hard_negatives + 1
        self.temperature = temperature
        self.ndim = ndim  # optional: truncate embeddings to the first ndim features

    @staticmethod
    def _gather(t: Tensor) -> Tensor:
        # FIX: the original called dist_utils.dist_gather, but dist_utils is never
        # imported in this module -> NameError whenever distributed is initialized.
        # Reuse the same grad-preserving all-gather as
        # DistributedContrastiveLoss.gather_tensor (presumed-equivalent semantics
        # — confirm against the original dist_utils helper).
        gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, t)
        gathered[dist.get_rank()] = t
        return torch.cat(gathered, dim=0)

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        """Return (loss, detail) where detail holds logits/labels/preds."""
        if dist.is_initialized():
            x = self._gather(x)
            y = self._gather(y)
        bsz, ndim = x.size(0), x.size(1)
        # The positive candidate is always at index 0 of each example's set.
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            x = x[:, :ndim]
            # FIX: was y[:, :ndim], which truncates the num_label axis of the
            # documented [bsz, num_label, hdim] tensor instead of the feature axis.
            y = y[..., :ndim]
        # reshape (not view): the ndim slices above are non-contiguous.
        # NOTE(review): logits are MULTIPLIED by temperature here while the other
        # losses divide — preserved as-is per the "keep unchanged" comment, but
        # confirm this asymmetry is intended.
        logits = torch.einsum('bod,bsd->bs',
                              x.reshape(bsz, 1, ndim),
                              y.reshape(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail