from typing import Optional, List, Tuple

import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep only the top_k highest logits per row; set the rest to -inf.

    logits: [batch_size, vocab_size]
    """
    batch_size, vocab_size = logits.shape
    top_k = min(top_k, vocab_size)
    top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
    # Start from all -inf and scatter the surviving logits back in place.
    filtered_logits = torch.full_like(logits, float("-inf"))
    batch_indices = torch.arange(batch_size, device=logits.device).unsqueeze(-1)
    filtered_logits[batch_indices, top_k_indices] = top_k_values
    return filtered_logits
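

# Illustrative sketch (not part of the original module): a minimal sanity
# check of apply_top_k; the helper name _demo_top_k is hypothetical.
def _demo_top_k():
    logits = torch.randn(2, 10)
    filtered = apply_top_k(logits, top_k=3)
    # Exactly 3 finite logits should survive per row.
    assert (filtered > float("-inf")).sum(dim=-1).eq(3).all()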


def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering, reference implementation with a Python loop.

    Keeps the smallest set of tokens whose cumulative probability exceeds
    top_p; all other logits are set to -inf.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the mask right by one so the first token that crosses the
    # threshold is kept; this also guarantees at least one token survives.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    batch_size = logits.shape[0]
    filtered_logits = logits.clone()
    for i in range(batch_size):
        indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
        filtered_logits[i, indices_to_remove] = float("-inf")
    return filtered_logits


def apply_top_p_optimized(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Vectorized nucleus (top-p) filtering.

    Equivalent to apply_top_p but replaces the per-row Python loop with a
    single scatter_. Note: this version modifies logits in place.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token above the threshold is kept,
    # guaranteeing at least one token per row.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Scatter the removal mask from sorted order back to vocabulary order.
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )

    logits[indices_to_remove] = float("-inf")
    return logits
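

# Illustrative sketch (not part of the original module): the loop and
# vectorized top-p filters should agree; _check_top_p_equivalence is a
# hypothetical helper for a quick sanity check.
def _check_top_p_equivalence():
    logits = torch.randn(4, 50)
    reference = apply_top_p(logits, top_p=0.9)
    # Clone first: apply_top_p_optimized mutates its input in place.
    optimized = apply_top_p_optimized(logits.clone(), top_p=0.9)
    assert torch.equal(reference, optimized)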


def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """
    logits: [B, H, V] or [N, V]
    prev_tokens: [B, T, H] or [N, T] or [B, H]

    Apply the repetition penalty independently for each H (VQ head).

    Note: this modifies logits in place, and the set of penalized tokens
    is pooled across the batch (per head).
    """
    if penalty == 1.0 or prev_tokens is None:
        return logits

    # Flat case: a single vocabulary axis, no per-head structure.
    if logits.dim() == 2:
        prev_tokens_flat = prev_tokens.reshape(-1)
        unique_tokens = torch.unique(prev_tokens_flat)

        # CTRL-style penalty: divide positive logits, multiply negative
        # ones, so previously seen tokens become less likely.
        token_logits = logits[:, unique_tokens]
        pos_mask = token_logits > 0
        token_logits[pos_mask] /= penalty
        token_logits[~pos_mask] *= penalty
        logits[:, unique_tokens] = token_logits
        return logits

    assert logits.dim() == 3, "Delay Pattern audio logits must be [B, H, V]"
    B, H, V = logits.shape

    for h in range(H):
        # Collect the tokens this head has already emitted.
        prev_tokens_h = prev_tokens[..., h].reshape(-1)
        unique_tokens = torch.unique(prev_tokens_h)

        if unique_tokens.numel() == 0:
            continue

        token_logits = logits[:, h, unique_tokens]
        pos_mask = token_logits > 0
        token_logits[pos_mask] /= penalty
        token_logits[~pos_mask] *= penalty
        logits[:, h, unique_tokens] = token_logits

    return logits
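

# Illustrative sketch (not part of the original module): with penalty > 1, a
# previously emitted token's positive logit shrinks and its negative logit
# grows more negative; _demo_repetition_penalty is a hypothetical helper.
def _demo_repetition_penalty():
    logits = torch.tensor([[2.0, -1.0, 0.5]])  # [N=1, V=3]
    prev = torch.tensor([[0, 0, 1]])           # token ids 0 and 1 were emitted
    out = apply_repetition_penalty_delay_pattern(logits, prev, penalty=2.0)
    # Token 0: 2.0 / 2.0 = 1.0; token 1: -1.0 * 2.0 = -2.0; token 2 untouched.
    assert torch.allclose(out, torch.tensor([[1.0, -2.0, 0.5]]))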


def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
):
    """Sample (or greedily pick) the next token from logits.

    Applies the repetition penalty first, then top-k and top-p filtering,
    and finally multinomial sampling over the remaining distribution.
    """
    vocab_size = logits.size(-1)

    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(
            logits,
            prev_tokens,
            repetition_penalty,
        )

    # Greedy decoding: no filtering needed, just take the argmax.
    if not do_sample:
        return torch.argmax(logits, dim=-1)

    # Flatten any leading dims so the filters see [rows, vocab_size].
    original_shape = logits.shape
    reshaped_logits = logits.view(-1, vocab_size)

    if top_k is not None and top_k > 0:
        reshaped_logits = apply_top_k(reshaped_logits, top_k)

    if top_p is not None and top_p < 1.0:
        reshaped_logits = apply_top_p_optimized(reshaped_logits, top_p)

    probs = F.softmax(reshaped_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

    return next_tokens.view(original_shape[:-1])
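

# Illustrative sketch (not part of the original module): sampling one step of
# per-head audio tokens. Shapes follow the conventions documented above;
# _demo_sample_token and its dimensions are hypothetical.
def _demo_sample_token():
    B, H, V = 2, 4, 1024                   # batch, VQ heads, vocab
    logits = torch.randn(B, H, V)
    prev = torch.randint(0, V, (B, 8, H))  # 8 previously emitted steps
    next_tokens = sample_token(
        logits, prev_tokens=prev, repetition_penalty=1.2, top_p=0.9, top_k=50
    )
    assert next_tokens.shape == (B, H)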


def find_last_equal_C(tensor, C):
    """
    tensor: torch.Tensor of shape [batch_size, seq_len]
    C: scalar value to match
    Returns: torch.Tensor of shape [batch_size] with the last index where
    tensor == C in each row, or -1 if the row contains no match.
    """
    mask = (tensor == C).int()
    # argmax on the flipped mask finds the first match from the right,
    # i.e. the last match in the original order.
    flipped_mask = mask.flip(dims=[1])
    flipped_indices = flipped_mask.argmax(dim=1)
    seq_len = tensor.shape[1]
    last_indices = (seq_len - 1) - flipped_indices

    # argmax returns 0 on an all-zero row, so verify there was a real match
    # and report -1 otherwise.
    actual_values = tensor[torch.arange(tensor.shape[0], device=tensor.device), last_indices]
    no_match = actual_values != C
    last_indices[no_match] = -1

    return last_indices
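

# Illustrative sketch (not part of the original module): locating the last
# occurrence of an EOS-like token per row, with -1 for rows without one.
def _demo_find_last_equal_C():
    seq = torch.tensor([[5, 7, 5, 2],
                        [1, 1, 1, 1]])
    last = find_last_equal_C(seq, 5)
    assert last.tolist() == [2, -1]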