# code_SAS_VLM2Vec / src / loss_layer_prune.py
# (upload-page metadata: MgGladys — "Add files using upload-large-folder tool", commit 0a937d7 verified)
from torch import Tensor
import torch.distributed as dist
import torch
import torch.nn.functional as F
import os  # newly added; appears unused in this module — TODO confirm before removing
# class SimpleContrastiveLoss:
# def __init__(self, temperature: float = 0.02, alpha: float = 0.2):
# self.temperature = temperature
# self.alpha = alpha
# def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
# """
# - 常规:x=[B, D], y=[B, D],做单向 CE(InfoNCE)
# - 新增:x=[B, 2, D], y=[B, D],视为 query 的第20层与最后一层,分别算 CE 再加权
# """
# # 计算 target(允许 y 比 x 多倍,用于多正样本的情况)
# B = x.size(0) if x.dim() >= 2 else x.size(0)
# if target is None:
# target_per_qry = y.size(0) // B
# target = torch.arange(
# 0, B * target_per_qry, target_per_qry, device=x.device, dtype=torch.long
# )
# if x.dim() == 2:
# # 原行为
# logits = torch.matmul(x, y.transpose(0, 1))
# loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
# return loss
# # 新行为:x 为 [B, 2, D],分别计算两次 CE
# assert x.dim() == 3 and x.size(1) == 2, f"Expect x=[B,2,D], got {tuple(x.size())}"
# q20, qlast = x[:, 0, :], x[:, 1, :]
# logits20 = torch.matmul(q20, y.transpose(0, 1)) / self.temperature
# logits_last = torch.matmul(qlast, y.transpose(0, 1)) / self.temperature
# loss20 = F.cross_entropy(logits20, target, reduction=reduction)
# loss_last = F.cross_entropy(logits_last, target, reduction=reduction)
# # print('loss20:', loss20)
# # print('loss_last:', loss_last)
# # print('self.alpha:', self.alpha)
# loss = self.alpha * loss20 + (1.0 - self.alpha) * loss_last
# # print('loss', loss)
# # exit()
# return loss
# class DistributedContrastiveLoss(SimpleContrastiveLoss):
# def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02, alpha: float = 0.2):
# assert dist.is_initialized(), "Distributed training has not been properly initialized."
# super().__init__(temperature=temperature, alpha=alpha)
# self.word_size = dist.get_world_size()
# self.rank = dist.get_rank()
# self.scale_loss = scale_loss
# def __call__(self, x: Tensor, y: Tensor, **kwargs):
# dist_x = self.gather_tensor(x)
# dist_y = self.gather_tensor(y)
# loss = super().__call__(dist_x, dist_y, **kwargs)
# if self.scale_loss:
# loss = loss * self.word_size
# return loss
# def gather_tensor(self, t):
# gathered = [torch.empty_like(t) for _ in range(self.word_size)]
# dist.all_gather(gathered, t)
# gathered[self.rank] = t # 保留本rank的梯度
# return torch.cat(gathered, dim=0)
class SimpleContrastiveLoss:
    """InfoNCE-style contrastive loss with optional two-view inputs.

    A "two-view" tensor of shape [B, 2, D] carries two embeddings per example
    (per the original comments: view 0 is an intermediate layer, e.g. layer 20,
    and view 1 is the last layer). The per-view losses are combined as
    ``alpha * loss_view0 + (1 - alpha) * loss_view1``.
    """

    def __init__(self, temperature: float = 0.02, alpha: float = 0.05):
        # temperature: softmax temperature dividing the similarity logits.
        # alpha: weight of view 0; view 1 gets weight 1 - alpha.
        self.temperature = temperature
        self.alpha = alpha

    def _ce(self, q: Tensor, p: Tensor, target: Tensor, reduction: str) -> Tensor:
        """Single-view InfoNCE: cross-entropy over q @ p^T similarity logits."""
        logits = torch.matmul(q, p.transpose(0, 1)) / self.temperature
        return F.cross_entropy(logits, target, reduction=reduction)

    @staticmethod
    def _split_views(t: Tensor, name: str):
        """Return (view0, view1) for ``t``.

        A [B, 2, D] tensor splits into its two views; a [B, D] tensor is reused
        as both views. Returns None for any other rank (caller raises).
        """
        if t.dim() == 2:
            return t, t
        if t.dim() == 3:
            assert t.size(1) == 2, f"Expect {name}=[B,2,D], got {tuple(t.size())}"
            return t[:, 0, :], t[:, 1, :]
        return None

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        """Compute the (possibly multi-view) contrastive loss.

        Supported shapes:
          * x=[B, D],    y=[B*k, D]    -> plain one-directional CE (InfoNCE)
          * x=[B, 2, D], y=[B, D]      -> CE per query view, alpha-weighted
          * x=[B, D],    y=[B, 2, D]   -> CE per passage view, alpha-weighted
          * x=[B, 2, D], y=[B, 2, D]   -> matched views (0<->0, 1<->1), alpha-weighted

        Raises:
            ValueError: for any other shape combination.
        """
        B = x.size(0)
        if target is None:
            # y may hold several candidates per query (positive first);
            # targets then point at every target_per_qry-th row of y.
            target_per_qry = y.size(0) // B
            target = torch.arange(
                0, B * target_per_qry, target_per_qry, device=x.device, dtype=torch.long
            )
        # Single-view fast path (kept separate to preserve the exact original value,
        # avoiding the alpha-weighted recombination round-off).
        if x.dim() == 2 and y.dim() == 2:
            return self._ce(x, y, target, reduction)
        qs = self._split_views(x, "x")
        ps = self._split_views(y, "y")
        if qs is None or ps is None:
            raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}")
        loss0 = self._ce(qs[0], ps[0], target, reduction)
        loss1 = self._ce(qs[1], ps[1], target, reduction)
        return self.alpha * loss0 + (1.0 - self.alpha) * loss1
class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """Contrastive loss with cross-rank in-batch negatives.

    Embeddings are all-gathered from every rank before the similarity matrix
    is built, so each query sees the negatives of the whole distributed batch.
    """

    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02, alpha: float = 0.05):
        # n_target is accepted for interface compatibility; it is not read here.
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature, alpha=alpha)
        # NOTE: "word_size" is the world size (historical name, kept for compatibility).
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        gathered_x = self.gather_tensor(x)
        gathered_y = self.gather_tensor(y)
        loss = super().__call__(gathered_x, gathered_y, **kwargs)
        # Re-scale so the effective magnitude matches single-process training
        # when the trainer averages gradients across ranks.
        return loss * self.word_size if self.scale_loss else loss

    def gather_tensor(self, t):
        """All-gather ``t`` across ranks, keeping autograd for the local shard."""
        shards = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(shards, t)
        # all_gather outputs carry no gradient; re-insert the local tensor so
        # this rank's slice stays connected to the graph.
        shards[self.rank] = t
        return torch.cat(shards, dim=0)
class InExampleContrastiveLoss:
    """In-example contrastive loss.

    Each query is scored only against its own candidate set (positive assumed
    at index 0, followed by hard negatives):
    x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim]
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
        # One positive plus n hard negatives per query.
        self.target_per_qry = n_hard_negatives + 1
        # NOTE(review): temperature MULTIPLIES the logits here, while the other
        # losses in this file divide by it; kept as-is for backward compatibility.
        self.temperature = temperature
        # Optional truncation of the embedding dimension before scoring.
        self.ndim = ndim

    @staticmethod
    def _dist_gather(t: Tensor) -> Tensor:
        """All-gather ``t`` across ranks, keeping autograd for the local shard.

        Replaces the previous call to `dist_utils.dist_gather`, a name never
        imported in this module (it raised NameError under distributed training).
        """
        world_size = dist.get_world_size()
        gathered = [torch.empty_like(t) for _ in range(world_size)]
        dist.all_gather(gathered, t)
        gathered[dist.get_rank()] = t  # re-insert local tensor to retain gradients
        return torch.cat(gathered, dim=0)

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        """Return (loss, loss_detail) where loss_detail holds logits/labels/preds."""
        if torch.distributed.is_initialized():
            x = self._dist_gather(x)
            y = self._dist_gather(y)
        bsz, ndim = x.size(0), x.size(1)
        # Positive candidate sits at index 0 of every per-query candidate set.
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            x = x[:, :ndim]
            # Fix: truncate the LAST (hidden) dim. The previous `y[:, :ndim]`
            # sliced the num_label axis of the 3-D y, and the view() below then
            # re-chunked full-width embeddings into garbage rows.
            y = y[..., :ndim]
        logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail