| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from typing import Tuple, Optional, List |
| | from dataclasses import dataclass |
| |
|
| | from transformers import PreTrainedModel |
| | from transformers.utils import ModelOutput |
| |
|
| | from .configuration_compression import CompressionConfig |
| |
|
def cosine_pairwise(embeddings):
    """Return the matrix of pairwise cosine similarities between rows.

    Args:
        embeddings: Tensor holding one vector per row, shape ``(n, d)``.

    Returns:
        Tensor of shape ``(n, n)`` whose ``[i, j]`` entry is the cosine
        similarity between row ``i`` and row ``j``. All-zero rows yield a
        similarity of 0 with every other row.
    """
    if embeddings.dim() != 2:
        # Inputs that are not matrices keep the historical broadcasting call
        # so existing callers see no change in behaviour.
        return F.cosine_similarity(
            embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2
        )

    # Normalise each row once and take a single matrix product. The
    # broadcasted call above works on operands expanded to (n, n, d); this
    # form only ever holds (n, d) and (n, n) tensors.
    # eps=1e-8 mirrors the default of F.cosine_similarity.
    unit_rows = F.normalize(embeddings, p=2, dim=-1, eps=1e-8)
    return unit_rows @ unit_rows.transpose(0, 1)
| |
|
def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix (np.cov).

    Args:
        tensor: Tensor with at least two dimensions. With ``rowvar=True`` the
            last dimension indexes observations and the second-to-last
            indexes variables; with ``rowvar=False`` the two are swapped.
        rowvar: Whether variables are laid out along rows (see ``tensor``).
        bias: If false (default) normalise by ``N - 1``; if true normalise
            by ``N``, where ``N`` is the number of observations.

    Returns:
        Tensor of shape ``(..., variables, variables)``.

    Raises:
        ZeroDivisionError: If the normaliser is zero, i.e. a single
            observation with ``bias=False`` or no observations with
            ``bias=True``.
    """
    if not rowvar:
        tensor = tensor.transpose(-1, -2)

    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    n_obs = centered.shape[-1]
    dof = n_obs if bias else n_obs - 1
    if dof == 0:
        # Same exception type the bare ``1 / 0`` used to produce, but with a
        # message that tells the caller what is wrong with the input.
        raise ZeroDivisionError(
            f"cov() cannot normalise {n_obs} observation(s) with bias={bool(bias)}"
        )

    factor = 1 / dof
    # conj() makes the product Hermitian for complex input; for real input
    # it leaves the values unchanged.
    return (factor * centered) @ centered.transpose(-1, -2).conj()
| |
|
def remove_diag(x):
    """Drop the main diagonal of a square matrix, keeping row order.

    Args:
        x: Tensor of shape ``(n, n)`` or a batch of shape ``(..., n, n)``.

    Returns:
        Tensor of shape ``(..., n, n - 1)`` where row ``i`` holds the entries
        of the original row ``i`` with element ``[i, i]`` removed.

    Raises:
        ValueError: If ``x`` has fewer than two dimensions or its last two
            dimensions differ. Such inputs would otherwise be broadcast
            against the diagonal mask and yield meaningless values.
    """
    if x.dim() < 2 or x.shape[-1] != x.shape[-2]:
        raise ValueError(
            f"remove_diag expects a square matrix (..., n, n), got shape {tuple(x.shape)}"
        )

    n = x.shape[-1]
    off_diagonal = ~torch.eye(n, dtype=torch.bool, device=x.device)
    # masked_select broadcasts the (n, n) mask over any leading batch dims and
    # returns the kept entries in row-major order, so they can be viewed back
    # as n rows of n - 1 entries each.
    return x.masked_select(off_diagonal).view(*x.shape[:-2], n, n - 1)
| |
|
def corrcoef(tensor, rowvar=True):
    """Get Pearson product-moment correlation coefficients (np.corrcoef).

    The covariance matrix from ``cov`` is divided by the standard deviation
    of each variable along both of its last two axes, then clipped to
    ``[-1, 1]`` (real and imaginary parts separately for complex input).

    Args:
        tensor: Input forwarded to ``cov``.
        rowvar: Forwarded to ``cov``.

    Returns:
        Correlation matrix with the same shape as the covariance matrix.
        A variable with zero variance divides by zero here, so its row and
        column are not finite.
    """
    c = cov(tensor, rowvar=rowvar)

    # The diagonal of the covariance matrix holds the per-variable variances.
    var = c.diagonal(0, -1, -2)
    std = (var.real if var.is_complex() else var).sqrt()

    corr = c / std.unsqueeze(-1) / std.unsqueeze(-2)

    # Rounding can push coefficients marginally outside [-1, 1]; clip the
    # freshly created result in place.
    if not corr.is_complex():
        corr.clip_(-1, 1)
        return corr

    corr.real.clip_(-1, 1)
    corr.imag.clip_(-1, 1)
    return corr
| |
|
def compute_correlation(base_sims, compressed_sims, rm_diag=True):
    """Return the mean of ``1 - correlation`` between matching rows.

    Args:
        base_sims: 2-D tensor of similarities for the base embeddings.
        compressed_sims: 2-D tensor of similarities for the compressed
            embeddings, same shape as ``base_sims``.
        rm_diag: When true, the main diagonal of both matrices is removed
            with ``remove_diag`` before correlating.

    Returns:
        Scalar tensor: one minus the per-row Pearson correlation between
        ``base_sims`` and ``compressed_sims``, averaged over rows.
    """
    pair = (base_sims, compressed_sims)
    if rm_diag:
        pair = tuple(remove_diag(sims) for sims in pair)

    # Shape (rows, 2, cols): each row becomes a two-variable problem whose
    # 2x2 correlation matrix carries the base/compressed coefficient at [0, 1].
    stacked = torch.stack(pair, dim=1)
    row_corr = corrcoef(stacked)[:, 0, 1]
    return (1 - row_corr).mean()
| |
|
def loss_function(base_sims, compressed_sims, k_vals):
    """Build the correlation loss terms for one compressed embedding.

    The first term compares the full similarity matrices (diagonal removed).
    For every ``k`` in ``k_vals`` a further term compares only the ``k``
    columns that rank highest in ``base_sims`` for each row.

    Args:
        base_sims: Square similarity matrix of the base embeddings.
        compressed_sims: Square similarity matrix of the compressed
            embeddings, same shape as ``base_sims``.
        k_vals: Iterable of neighbourhood sizes; ``None`` or empty disables
            the top-k terms. Each ``k`` must select at least two columns,
            because the unbiased covariance in ``cov`` divides by the number
            of columns minus one.

    Returns:
        Tensor of shape ``(1, 1 + len(k_vals))`` holding the loss terms.
    """
    terms = [compute_correlation(base_sims, compressed_sims)]

    if k_vals:
        # Rank columns by base similarity and drop the top-ranked one --
        # presumably each row's similarity with itself; TODO confirm.
        neighbour_order = base_sims.argsort(-1, descending=True)[:, 1:]
        for k in k_vals:
            top_k = neighbour_order[:, :k]
            # Both matrices are gathered with the *base* ranking so the same
            # neighbours are compared on each side.
            terms.append(
                compute_correlation(
                    torch.gather(base_sims, 1, top_k),
                    torch.gather(compressed_sims, 1, top_k),
                    rm_diag=False,
                )
            )

    return torch.stack(terms).unsqueeze(0)
| |
|
class FeedForward(nn.Module):
    """Gated feed-forward projection from ``d_in`` to ``d_out`` features.

    ``fc1`` produces ``2 * d_out`` features that are split into two halves;
    the SiLU of the first half multiplies the second half, and ``fc2`` maps
    the product to the final ``d_out`` features.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, 2 * d_out)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        """Apply the gated projection to the last dimension of ``x``."""
        gate, value = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(F.silu(gate) * value)
| |
|
class CompressionHead(nn.Module):
    """Project an embedding from ``d_in`` to ``d_out`` features.

    Dropout is applied to the input, and the result is fed to both a
    ``FeedForward`` branch and a linear skip branch whose outputs are summed.
    """

    def __init__(self, d_in, d_out, dropout=0.1):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the sum of the feed-forward and skip projections."""
        # Both branches see the same dropped-out input.
        dropped = self.dropout(x)
        return self.ff(dropped) + self.skip(dropped)
| |
|
@dataclass
class CompressionModelOutput(ModelOutput):
    """Container for the values produced by ``CompressionModel.forward``.

    Attributes:
        loss: Sum of all entries in ``losses``.
        losses: One loss tensor per compression head.
        base_embedding: The embedding that was passed to the model.
        compressed_embeddings: One compressed embedding per head.
    """

    # Field order is part of the interface and matches the tuple returned by
    # ``forward`` when ``return_dict`` is false.
    loss: Optional[torch.FloatTensor] = None
    losses: Optional[List[torch.FloatTensor]] = None
    base_embedding: Optional[torch.FloatTensor] = None
    compressed_embeddings: Optional[List[torch.FloatTensor]] = None
| |
|
class CompressionModel(PreTrainedModel):
    """Set of compression heads that map one embedding to several sizes.

    One ``CompressionHead`` is created per entry of
    ``config.compression_sizes``; every head projects from
    ``config.input_size`` and uses ``config.dropout``.
    """

    config_class = CompressionConfig

    def __init__(self, config):
        super().__init__(config)
        self.heads = nn.ModuleList(
            [
                CompressionHead(config.input_size, size, config.dropout)
                for size in config.compression_sizes
            ]
        )

    def forward(self, embedding, compute_loss=True, return_dict=True):
        """Compress ``embedding`` with every head and optionally score it.

        Args:
            embedding: Batch of base embeddings, one vector per row.
            compute_loss: When true, each head's output is compared with the
                base embedding through ``loss_function`` on the pairwise
                cosine-similarity matrices. When false, no similarity
                matrices are built and ``loss``/``losses`` are ``None``.
            return_dict: When true, return a ``CompressionModelOutput``;
                otherwise return a plain tuple.

        Returns:
            ``CompressionModelOutput`` or the tuple
            ``(loss, losses, embedding, outputs)``, where ``outputs`` is the
            list of compressed embeddings in head order.
        """
        outputs = []
        losses = [] if compute_loss else None
        # The base similarities are shared by all heads, so build them once.
        emb_sims = cosine_pairwise(embedding) if compute_loss else None

        for head in self.heads:
            compressed_embedding = head(embedding)
            outputs.append(compressed_embedding)

            if compute_loss:
                comp_sims = cosine_pairwise(compressed_embedding)
                losses.append(
                    loss_function(emb_sims, comp_sims, self.config.loss_k_vals)
                )

        # Only aggregate when losses were collected. Previously this line ran
        # unconditionally and called torch.cat(None) when compute_loss was
        # false, which raised TypeError.
        loss = torch.cat(losses).sum() if compute_loss else None

        if not return_dict:
            return (loss, losses, embedding, outputs)

        return CompressionModelOutput(
            loss=loss,
            losses=losses,
            base_embedding=embedding,
            compressed_embeddings=outputs,
        )
| |
|
| |
|
| |
|
| |
|
| |
|