File size: 5,909 Bytes
0a937d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from torch import Tensor
import torch.distributed as dist
import torch
import torch.nn.functional as F
import os  # newly added (currently unused in this module)

class SimpleContrastiveLoss:
    """In-batch contrastive (InfoNCE) loss with optional multi-view weighting.

    Supported input shapes (see ``__call__``):
      - x=[B, D],    y=[B, D]       -> single view
      - x=[B, K, D], y=[B, D]       -> K query views vs. one candidate view
      - x=[B, D],    y=[B, K, D]    -> one query view vs. K candidate views
      - x=[B, K, D], y=[B, K, D]    -> per-view pairing (k <-> k), weighted
    """

    def __init__(self, temperature: float = 0.02, alpha: float = 0.05, weights=None):
        """
        Args:
            temperature: logits are divided by this value before cross-entropy.
            alpha: weight of the first view when K == 2 and ``weights`` is None
                (the second view gets 1 - alpha).
            weights: optional list[float] of per-view CE weights; its length
                must equal the number of views K and it is re-normalized at
                call time. If None and K > 2, uniform weights are used.
        """
        self.temperature = temperature
        self.alpha = alpha
        self.weights = weights  # e.g. [0.1, 0.2, 0.7]

    @staticmethod
    def _normalize(w: Tensor, K: int) -> Tensor:
        """Clamp weights to >= 0, then normalize them to sum to 1.

        Falls back to uniform weights when the clamped sum is ~0, so a
        degenerate configuration never produces NaN/inf losses.
        """
        w = torch.clamp(w, min=0)
        s = w.sum()
        if s.item() <= 1e-8:
            return torch.ones(K, dtype=w.dtype, device=w.device) / K
        return w / s

    def _get_weights(self, K: int, device) -> Tensor:
        """Return the normalized per-view weight vector of length K on ``device``."""
        if self.weights is not None:
            assert len(self.weights) == K, f"weights length {len(self.weights)} != K={K}"
            w = torch.tensor(self.weights, dtype=torch.float32, device=device)
            return self._normalize(w, K)
        if K == 2:
            # Fix: clamp BEFORE summing. The previous code normalized by the
            # *unclamped* sum, so e.g. alpha > 1 produced weights that did not
            # sum to 1.
            w = torch.tensor([self.alpha, 1.0 - self.alpha], dtype=torch.float32, device=device)
            return self._normalize(w, K)
        # Default: uniform weights over the K views.
        return torch.ones(K, dtype=torch.float32, device=device) / K

    def _weighted_ce(self, x_views, y_views, target: Tensor, reduction: str) -> Tensor:
        """Weighted sum of per-view in-batch cross-entropy losses.

        ``x_views``/``y_views`` are equal-length lists of [B, D] tensors; view k
        of x is scored against view k of y.
        """
        K = len(x_views)
        w = self._get_weights(K, x_views[0].device)
        loss = 0.0
        for k in range(K):
            logits_k = torch.matmul(x_views[k], y_views[k].transpose(0, 1)) / self.temperature
            loss = loss + w[k] * F.cross_entropy(logits_k, target, reduction=reduction)
        return loss

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        """Compute the contrastive loss for any of the supported shape combos.

        Args:
            x: query embeddings, [B, D] or [B, K, D].
            y: candidate embeddings, [B, D] or [B, K, D].
            target: optional [B] long tensor of positive indices; defaults to
                the diagonal pairing (stride y_rows // B, which also covers the
                "hard negatives packed after each positive" layout).
            reduction: passed through to ``F.cross_entropy``.

        Raises:
            ValueError: if the (x, y) dimensionalities are not supported.
        """
        B = x.size(0)
        if target is None:
            target_per_qry = y.size(0) // B
            target = torch.arange(0, B * target_per_qry, target_per_qry, device=x.device, dtype=torch.long)

        # Single view on both sides.
        if x.dim() == 2 and y.dim() == 2:
            logits = torch.matmul(x, y.transpose(0, 1)) / self.temperature
            return F.cross_entropy(logits, target, reduction=reduction)

        # Multi-view query vs. single-view candidates.
        if x.dim() == 3 and y.dim() == 2:
            K = x.size(1)
            return self._weighted_ce([x[:, k, :] for k in range(K)], [y] * K, target, reduction)

        # Single-view query vs. multi-view candidates.
        if x.dim() == 2 and y.dim() == 3:
            K = y.size(1)
            return self._weighted_ce([x] * K, [y[:, k, :] for k in range(K)], target, reduction)

        # Paired multi-view (view k of x against view k of y).
        if x.dim() == 3 and y.dim() == 3:
            Kx, Ky = x.size(1), y.size(1)
            assert Kx == Ky, f"view mismatch: {Kx} vs {Ky}"
            return self._weighted_ce(
                [x[:, k, :] for k in range(Kx)],
                [y[:, k, :] for k in range(Kx)],
                target, reduction,
            )

        raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}")


class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """SimpleContrastiveLoss computed over embeddings gathered from all ranks.

    Requires ``torch.distributed`` to be initialized before construction.
    """

    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02, alpha: float = 0.05, weights=None):
        """
        Args:
            n_target: accepted for interface compatibility; currently unused.
            scale_loss: multiply the loss by the world size (compensates for
                gradient averaging across ranks, since all_gather does not
                backprop across ranks).
            temperature/alpha/weights: forwarded to SimpleContrastiveLoss.
        """
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature, alpha=alpha, weights=weights)
        # Fix: "word_size" was a typo for "world_size". The old attribute name
        # is kept as an alias in case external code reads it.
        self.world_size = dist.get_world_size()
        self.word_size = self.world_size
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        """Gather x and y across all ranks, then compute the contrastive loss."""
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            loss = loss * self.world_size
        return loss

    def gather_tensor(self, t: Tensor) -> Tensor:
        """All-gather ``t`` from every rank and concatenate along dim 0.

        The local rank's slot is replaced by the original tensor so that its
        gradient still flows (all_gather outputs are detached from autograd).
        """
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        gathered[self.rank] = t  # keep the local tensor so its gradient flows
        return torch.cat(gathered, dim=0)

class InExampleContrastiveLoss:
    """Per-example contrastive loss over each query's own candidate set.

    x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim]; the positive is
    assumed to be at label index 0 of every example.
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
        """
        Args:
            n_hard_negatives: number of hard negatives per query (stored as
                ``target_per_qry = n_hard_negatives + 1``; not used in __call__).
            temperature: logits are *multiplied* by this value (note: unlike
                SimpleContrastiveLoss, which divides) — preserved as-is.
            ndim: if set, truncate embeddings to the first ``ndim`` features.
        """
        self.target_per_qry = n_hard_negatives + 1
        self.temperature = temperature
        self.ndim = ndim

    @staticmethod
    def _dist_gather(t: Tensor) -> Tensor:
        """All-gather ``t`` along dim 0, keeping the local slice differentiable."""
        gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, t)
        gathered[dist.get_rank()] = t
        return torch.cat(gathered, dim=0)

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        """Return (loss, detail dict with "logits", "labels", "preds")."""
        if dist.is_initialized():
            # Fix: the original called dist_utils.dist_gather, but dist_utils
            # is never imported -> NameError under distributed training.
            x = self._dist_gather(x)
            y = self._dist_gather(y)
        bsz, ndim = x.size(0), x.size(1)
        # Positive candidate is at index 0 for every example.
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            x = x[:, :ndim]
            # Fix: y is [bsz, num_label, hdim]; the original sliced y[:, :ndim],
            # which truncates the *label* dimension, not the feature dimension.
            y = y[..., :ndim]
        # Score each query against its own candidates: logits[b, s] = <x_b, y_bs>.
        logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail