| | |
| | import os |
| | from enum import Enum |
| | from typing import Callable, Optional |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from accelerate.utils import gather_object |
| | from torch import nn |
| | from torch.nn import CrossEntropyLoss, MSELoss |
| | from transformers.utils import strtobool |
| |
|
| |
|
class LossType:
    """Names of the registered loss functions (plain string constants)."""

    loss_scale: str = 'loss_scale'
    cosine_similarity: str = 'cosine_similarity'
    contrastive: str = 'contrastive'
    online_contrastive: str = 'online_contrastive'
    infonce: str = 'infonce'
| |
|
| |
|
# Global registry: loss-type name -> {'loss_func': callable}.
# Populated by register_loss_func (direct call or decorator form),
# read back via get_loss_func.
LOSS_MAPPING = {}
| |
|
| |
|
def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None):
    """Register a loss function in ``LOSS_MAPPING`` under ``loss_type``.

    Usable two ways: call directly with ``loss_func`` (returns None), or
    call with only ``loss_type`` to obtain a decorator that registers the
    decorated function and returns it unchanged.
    """
    if loss_func is not None:
        # Direct registration path.
        LOSS_MAPPING[loss_type] = {'loss_func': loss_func}
        return

    def _register_loss_func(fn: Callable) -> Callable:
        # Decorator path: record the function, hand it back untouched.
        LOSS_MAPPING[loss_type] = {'loss_func': fn}
        return fn

    return _register_loss_func
| |
|
| |
|
def ce_loss_func(outputs, labels):
    """Unreduced next-token cross-entropy over non-ignored label positions.

    Args:
        outputs: Model output object exposing ``logits`` of shape
            (..., seq_len, vocab).
        labels: Token ids aligned with ``logits``; -100 marks ignored slots.

    Returns:
        loss: 1-D tensor of per-token CE values (one per kept position).
        masks: Boolean tensor over the shifted labels, True where != -100.
    """
    logits = outputs.logits

    # Align logits at step t with the label at step t + 1.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:].to(logits.device)

    # Keep only positions whose label is not the ignore index.
    masks = shift_labels != -100
    valid_logits = shift_logits[masks]
    valid_labels = shift_labels[masks]

    loss = F.cross_entropy(valid_logits, valid_labels, reduction='none')
    return loss, masks
| |
|
| |
|
| | |
@register_loss_func(LossType.loss_scale)
def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """Cross-entropy loss with optional per-token loss scaling.

    Args:
        outputs: The model outputs (must expose ``logits``).
        labels: The labels; positions equal to -100 are ignored.
        loss_scale: Optional per-token weights aligned with ``labels``.
        num_items_in_batch: Number of tokens in the labels of the gradient
            accumulation round that are not -100; when given, the summed
            loss is divided by it instead of taking the mean.

    Returns:
        A scalar loss tensor.
    """
    token_loss, masks = ce_loss_func(outputs, labels)
    if loss_scale is not None:
        # Shift the scales exactly like the labels, then keep valid slots.
        scale = loss_scale[..., 1:].to(masks.device)[masks]
        token_loss = scale * token_loss
    if num_items_in_batch is None:
        return token_loss.mean()
    return token_loss.sum() / num_items_in_batch
| |
|
| |
|
| | def _parse_pair_sentence(outputs): |
| | if isinstance(outputs, dict): |
| | last_hidden_state = outputs['last_hidden_state'] |
| | else: |
| | last_hidden_state = outputs |
| | batch_size = last_hidden_state.shape[0] |
| | shape_len = len(last_hidden_state.shape) |
| | first_sentence = list(range(0, batch_size, 2)) |
| | second_sentence = list(range(1, batch_size, 2)) |
| | if shape_len == 3: |
| | sentence1 = last_hidden_state[first_sentence][:, 0].squeeze(dim=1) |
| | sentence2 = last_hidden_state[second_sentence][:, 0].squeeze(dim=1) |
| | else: |
| | sentence1 = last_hidden_state[first_sentence] |
| | sentence2 = last_hidden_state[second_sentence] |
| | return sentence1, sentence2 |
| |
|
| |
|
| | |
class SiameseDistanceMetric(Enum):
    """The metric for the contrastive loss"""

    # NOTE: functions (including lambdas) defined in an Enum body are treated
    # as methods, not enum members, so these are plain callables: they are
    # invoked as SiameseDistanceMetric.COSINE_DISTANCE(x, y), returning a
    # per-pair distance tensor.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)
| |
|
| |
|
@register_loss_func(LossType.cosine_similarity)
def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """MSE between pairwise cosine similarity and the target scores."""
    emb1, emb2 = _parse_pair_sentence(outputs)
    # nn.Identity() is kept as a hook point for an optional score transform.
    transform = nn.Identity()
    predictions = transform(torch.cosine_similarity(emb1, emb2))
    targets = labels.to(predictions.dtype).view(-1)
    return MSELoss()(predictions, targets)
| |
|
| |
|
@register_loss_func(LossType.contrastive)
def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """Siamese contrastive loss (cosine distance, fixed 0.5 margin).

    Positive pairs (label 1) are penalized by squared distance; negative
    pairs (label 0) by squared hinge on (margin - distance).
    """
    margin = 0.5
    emb1, emb2 = _parse_pair_sentence(outputs)
    distances = SiameseDistanceMetric.COSINE_DISTANCE(emb1, emb2)
    labels = labels.to(emb1.dtype)
    positive_term = labels * distances.pow(2)
    negative_term = (1 - labels) * F.relu(margin - distances).pow(2)
    return (0.5 * (positive_term + negative_term)).mean()
| |
|
| |
|
def calculate_paired_metrics(embeddings, labels):
    """Correlations between pair-similarity scores and ground-truth labels.

    Args:
        embeddings: Interleaved pair embeddings (first sentence at even rows,
            second at odd rows), as accepted by ``_parse_pair_sentence``.
        labels: Ground-truth similarity score per pair.

    Returns:
        Dict of pearson/spearman correlations for cosine, euclidean,
        manhattan and dot-product similarities.
    """
    from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, \
        paired_manhattan_distances
    from scipy.stats import pearsonr, spearmanr

    embeddings1, embeddings2 = _parse_pair_sentence(embeddings)
    cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
    # Distances are negated so larger values mean "more similar".
    manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)
    euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)
    dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]

    eval_pearson_cosine, _ = pearsonr(labels, cosine_scores)
    eval_spearman_cosine, _ = spearmanr(labels, cosine_scores)

    eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances)
    eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances)

    eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances)
    eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances)

    eval_pearson_dot, _ = pearsonr(labels, dot_products)
    eval_spearman_dot, _ = spearmanr(labels, dot_products)

    # Fix: the original mapping swapped the euclidean and manhattan values
    # under each other's keys (for both pearson and spearman).
    return {
        'pearson_cosine': eval_pearson_cosine,
        'pearson_euclidean': eval_pearson_euclidean,
        'pearson_manhattan': eval_pearson_manhattan,
        'pearson_dot_product': eval_pearson_dot,
        'spearman_cosine': eval_spearman_cosine,
        'spearman_euclidean': eval_spearman_euclidean,
        'spearman_manhattan': eval_spearman_manhattan,
        'spearman_dot_product': eval_spearman_dot,
    }
| |
|
| |
|
def calculate_infonce_metrics(embeddings, labels):
    """Similarity diagnostics (margin, mean pos/neg score) for InfoNCE eval.

    Mirrors the grouping/scoring logic of ``infonce_loss`` in numpy:
    embeddings are split into [anchor, positive, negatives...] groups and
    scored by dot product between anchors and candidates.

    Args:
        embeddings: Array-like of interleaved sentence embeddings.
        labels: Array-like of 0/1 markers; 1 starts a new group (see
            ``_parse_multi_negative_sentences``).

    Returns:
        Dict with 'margin' (mean of positive minus max-negative score),
        'mean_neg' and 'mean_pos'.
    """
    hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None)
    if hard_negatives is not None:
        # Fix: env values are strings; _parse_multi_negative_sentences
        # compares and slices with this value, so it must be an int
        # (infonce_loss already does this conversion).
        hard_negatives = int(hard_negatives)
    use_batch = strtobool(os.environ.get('INFONCE_USE_BATCH', 'True'))
    split_tensors = _parse_multi_negative_sentences(torch.tensor(embeddings), torch.tensor(labels), hard_negatives)
    split_tensors = [t.numpy() for t in split_tensors]
    # A single batched matmul is possible when all groups have equal size.
    can_batched = hard_negatives is not None
    if hard_negatives is None and len(set([s.shape[0] for s in split_tensors])) == 1:
        can_batched = True
    all_similarity_matrix = []
    all_labels = []
    pos_neg_margins = []
    if not use_batch:
        # Each anchor scored only against its own group's candidates.
        if can_batched:
            sentences = np.stack(split_tensors, axis=0)
            # (G, 1, D) @ (G, D, C) -> (G, C)
            similarity_matrix = np.matmul(sentences[:, 0:1], sentences[:, 1:].transpose((0, 2, 1))).squeeze(1)
            all_similarity_matrix.append(similarity_matrix)
            labels = np.zeros_like(similarity_matrix)
            labels[:, 0] = 1  # the positive is always the first candidate
            all_labels.append(labels)
        else:
            for tensor in split_tensors:
                similarity_matrix = np.matmul(tensor[0], tensor[1:].T)
                all_similarity_matrix.append(similarity_matrix)
                labels = np.zeros_like(similarity_matrix)
                labels[0] = 1
                all_labels.append(labels)
                max_neg_scores = np.max(similarity_matrix[labels == 0], axis=-1)
                pos_neg_margins.append(np.mean(similarity_matrix[labels == 1] - max_neg_scores).item())
    else:
        # Each anchor scored against the candidates of every group.
        if can_batched:
            sentences = np.stack(split_tensors, axis=0)
            # (G, D) @ (D, G*C) -> (G, G*C)
            similarity_matrix = np.matmul(sentences[:, 0], sentences[:, 1:].reshape(-1, sentences.shape[2]).T)
            all_similarity_matrix.append(similarity_matrix)
            labels = np.zeros_like(similarity_matrix)
            # Positive for anchor g is at column g * C, with C = group_size - 1.
            for row, col in enumerate(range(0, sentences.shape[0] * (sentences.shape[1] - 1), sentences.shape[1] - 1)):
                labels[row, col] = 1
            all_labels.append(labels)
        else:
            all_tensors = []
            for tensor in split_tensors:
                all_tensors.append(tensor[1:])
            sentences = np.concatenate(all_tensors, axis=0)
            length = 0
            for tensor in split_tensors:
                similarity_matrix = np.matmul(tensor[0], sentences.T)
                all_similarity_matrix.append(similarity_matrix)
                labels = np.zeros_like(similarity_matrix)
                labels[length] = 1
                all_labels.append(labels)
                length += tensor.shape[0] - 1
                max_neg_scores = np.max(similarity_matrix[labels == 0], axis=-1)
                pos_neg_margins.append(np.mean(similarity_matrix[labels == 1] - max_neg_scores).item())

    similarity_matrix = np.concatenate(all_similarity_matrix, axis=0)
    labels = np.concatenate(all_labels, axis=0)
    if can_batched:
        pos_scores = similarity_matrix[labels == 1].reshape(similarity_matrix.shape[0], -1)
        neg_scores = similarity_matrix[labels == 0].reshape(similarity_matrix.shape[0], -1)
        max_neg_scores = np.max(neg_scores, axis=-1)
        pos_neg_margin = np.mean(pos_scores - max_neg_scores).item()
    else:
        pos_scores = similarity_matrix[labels == 1]
        neg_scores = similarity_matrix[labels == 0]
        pos_neg_margin = np.mean(pos_neg_margins)

    mean_neg = np.mean(neg_scores)
    mean_pos = np.mean(pos_scores)
    return {'margin': pos_neg_margin, 'mean_neg': mean_neg, 'mean_pos': mean_pos}
| |
|
| |
|
| | def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None): |
| | split_indices = torch.nonzero(labels, as_tuple=False).squeeze().tolist() |
| | if isinstance(split_indices, int): |
| | split_indices = [split_indices] |
| | split_indices.append(len(labels)) |
| | split_indices = np.array(split_indices) + np.array(list(range(len(split_indices)))) |
| | split_tensors = [] |
| |
|
| | for i in range(len(split_indices) - 1): |
| | start = split_indices[i] |
| | end = split_indices[i + 1] |
| | split_part = sentences[start:end] |
| | if hard_negatives is not None: |
| | negatives = len(split_part) - 2 |
| | assert negatives > 0 |
| | if negatives > hard_negatives: |
| | split_part = split_part[:hard_negatives + 2] |
| | elif negatives < hard_negatives: |
| | selected = np.random.choice(list(range(negatives)), size=hard_negatives - negatives, replace=True) |
| | selected += 1 |
| | split_part = torch.cat((split_part, split_part[selected]), dim=0) |
| | split_tensors.append(split_part) |
| | return split_tensors |
| |
|
| |
|
@register_loss_func(LossType.infonce)
def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """InfoNCE (contrastive) loss over [anchor, positive, negatives...] groups.

    Behavior is controlled by environment variables:
        INFONCE_TEMPERATURE: softmax temperature (default '0.01').
        INFONCE_USE_BATCH: when true (default), candidates from the whole
            (cross-rank gathered) batch serve as negatives; otherwise only
            in-group negatives are used.
        INFONCE_HARD_NEGATIVES: fixed number of hard negatives per group.
        INFONCE_MASK_FAKE_NEGATIVE: mask negatives that score noticeably
            higher than the positive (default false).

    Args:
        outputs: Mapping with 'last_hidden_state' holding interleaved
            embeddings (anchors + candidates), as split by
            _parse_multi_negative_sentences.
        labels: 1 marks the positive starting each group, 0 elsewhere.
        loss_scale: Unused; kept for the registered-loss signature.
        num_items_in_batch: Unused; kept for the registered-loss signature.

    Returns:
        Scalar loss tensor. In the cross-batch path it is divided by
        world_size -- NOTE(review): presumably to compensate for gradient
        averaging across ranks; confirm against the trainer.
    """
    temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01'))
    use_batch = strtobool(os.environ.get('INFONCE_USE_BATCH', 'True'))
    hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None)
    infonce_mask_fake_negative = strtobool(os.environ.get('INFONCE_MASK_FAKE_NEGATIVE', 'False'))
    if hard_negatives is not None:
        hard_negatives = int(hard_negatives)
    from swift.utils import get_dist_setting
    rank, _, world_size, _ = get_dist_setting()

    sentences = outputs['last_hidden_state']

    if world_size > 1 and use_batch:
        # Gather embeddings from all ranks so each rank sees the global batch.
        all_sentences = gather_object(sentences.unsqueeze(0))
        labels = gather_object(labels)
        # Keep the local tensor (with its autograd graph); remote tensors are
        # detached -- their gradients are computed on their own ranks.
        all_sentences[rank] = sentences
        for idx in range(len(all_sentences)):
            if idx == rank:
                continue
            all_sentences[idx] = all_sentences[idx].detach().to(sentences.device)
        sentences = torch.cat(all_sentences, dim=0)
        labels = [tensor.to(sentences.device) for tensor in labels]
        labels = torch.stack(labels, dim=0)

    # Split into [anchor, positive, negatives...] groups.
    split_tensors = _parse_multi_negative_sentences(sentences, labels, hard_negatives)
    loss = 0
    # Equal group sizes allow a single batched matmul.
    can_batched = hard_negatives is not None
    if hard_negatives is None and len(set([s.shape[0] for s in split_tensors])) == 1:
        can_batched = True
    if not use_batch:
        # In-group negatives only.
        if can_batched:
            sentences = torch.stack(split_tensors, dim=0)
            # (G, 1, D) x (G, D, C) -> (G, 1, C): anchor vs its own candidates.
            similarity_matrix = torch.matmul(sentences[:, 0:1], sentences[:, 1:].transpose(1, 2)) / temperature
            # The positive is always candidate 0 of each group.
            labels = torch.zeros(len(split_tensors), dtype=torch.int64).to(sentences.device)
            loss = nn.CrossEntropyLoss()(similarity_matrix.squeeze(1), labels)
        else:
            # Variable group sizes: score each group separately, then average.
            for tensor in split_tensors:
                similarity_matrix = torch.matmul(tensor[0], tensor[1:].T) / temperature
                labels = torch.tensor(0).to(tensor.device)
                loss += nn.CrossEntropyLoss()(similarity_matrix, labels)
            loss /= len(split_tensors)
    else:

        def mask_fake_negative(sim_matrix, sim_labels):
            # Candidates scoring above (positive score + 0.1) are presumably
            # unlabeled positives from other groups; exclude them from the
            # softmax by setting their logits to -inf (in place).
            thresholds = sim_matrix[torch.arange(sim_matrix.size(0)), sim_labels].view(-1, 1) + 0.1
            thresholds = thresholds.detach()
            mask = sim_matrix > thresholds
            sim_matrix[mask] = float('-inf')

        if can_batched:
            sentences = torch.stack(split_tensors, dim=0)
            # (G, D) x (D, G*C) -> (G, G*C): every anchor vs all candidates.
            similarity_matrix = torch.matmul(sentences[:, 0].squeeze(1), sentences[:,
                                                                                   1:].reshape(-1, sentences.size(2)).T)
            # Positive for anchor g sits at column g * C, C = group_size - 1.
            labels = torch.tensor(range(0,
                                        sentences.size(0) * (sentences.size(1) - 1),
                                        sentences.size(1) - 1)).view(-1).to(sentences.device)
            if infonce_mask_fake_negative:
                mask_fake_negative(similarity_matrix, labels)
            similarity_matrix = similarity_matrix / temperature
            loss = nn.CrossEntropyLoss()(similarity_matrix, labels) / world_size
        else:
            # Variable group sizes: concatenate all candidates into one pool,
            # score each anchor against it, tracking the positive's offset.
            all_tensors = []
            for tensor in split_tensors:
                all_tensors.append(tensor[1:])
            sentences = torch.cat(all_tensors, dim=0)
            length = 0
            for idx, tensor in enumerate(split_tensors):
                similarity_matrix = torch.matmul(tensor[0], sentences.T) / temperature
                labels = torch.tensor(length).to(tensor.device)
                loss += nn.CrossEntropyLoss()(similarity_matrix, labels)
                length += tensor.size(0) - 1
            loss /= len(split_tensors)
            loss /= world_size
    return loss
| |
|
| |
|
@register_loss_func(LossType.online_contrastive)
def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """Online contrastive loss: only "hard" pairs contribute.

    A negative pair is hard when its distance falls below the largest
    positive distance; a positive pair is hard when its distance exceeds
    the smallest negative distance. When one side has too few samples to
    define a threshold, the other side's mean is used instead.
    """
    margin = 0.5
    emb1, emb2 = _parse_pair_sentence(outputs)
    distances = SiameseDistanceMetric.COSINE_DISTANCE(emb1, emb2)
    neg_distances = distances[labels == 0]
    pos_distances = distances[labels == 1]

    neg_cutoff = pos_distances.max() if len(pos_distances) > 1 else neg_distances.mean()
    pos_cutoff = neg_distances.min() if len(neg_distances) > 1 else pos_distances.mean()
    hard_negatives = neg_distances[neg_distances < neg_cutoff]
    hard_positives = pos_distances[pos_distances > pos_cutoff]

    positive_loss = hard_positives.pow(2).sum()
    negative_loss = F.relu(margin - hard_negatives).pow(2).sum()
    return positive_loss + negative_loss
| |
|
| |
|
def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
    """Return the loss function registered under ``loss_type``.

    Passes None through unchanged; raises KeyError for an unknown type.
    """
    return None if loss_type is None else LOSS_MAPPING[loss_type]['loss_func']
| |
|