MOSS-SoundEffect / inference_utils.py
YWMditto's picture
OpenMOSS, MOSI.AI, and MOSS-TTS have been officially open-sourced!
5b37e97
import torch
import torchaudio
import torch.nn.functional as F
from typing import Optional, List, Tuple
from tqdm import tqdm
def apply_top_k(logits, top_k):
    """Keep the ``top_k`` largest logits along the last dim; mask the rest.

    Args:
        logits: Float tensor whose last dimension is the vocabulary, e.g.
            ``[N, V]`` or ``[B, H, V]``.
        top_k: Number of entries to keep per row. Values larger than the
            vocabulary size are clamped. ``None`` or a non-positive value
            disables filtering.

    Returns:
        A new tensor with the same shape as ``logits`` in which every entry
        outside the per-row top-k is ``-inf``. The input is not modified.
    """
    vocab_size = logits.size(-1)
    # Nothing to filter: return a copy so the "new tensor" contract holds.
    if top_k is None or top_k <= 0 or top_k >= vocab_size:
        return logits.clone()

    kept_values, kept_indices = torch.topk(logits, top_k, dim=-1)
    filtered_logits = torch.full_like(logits, float("-inf"))
    # scatter_ works on any device and any number of leading dims, unlike
    # fancy indexing with an explicitly built (CPU) row-index tensor.
    filtered_logits.scatter_(-1, kept_indices, kept_values)
    return filtered_logits
def apply_top_p(logits, top_p):
    """Nucleus (top-p) filtering that leaves the input tensor untouched.

    Tokens are ranked by probability; a token is dropped when the cumulative
    probability of the tokens ranked strictly before it already exceeds
    ``top_p``. The most probable token is therefore always kept.

    Args:
        logits: Float tensor whose last dimension is the vocabulary, e.g.
            ``[N, V]`` or ``[B, H, V]``.
        top_p: Cumulative-probability threshold.

    Returns:
        A new tensor shaped like ``logits`` with dropped tokens set to
        ``-inf``.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    drop_sorted = cumulative_probs > top_p
    # Shift right by one so the token that crosses the threshold is kept.
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    drop_sorted[..., 0] = False

    # Map the mask from rank order back to vocabulary order in one
    # vectorised step (works for any number of leading dimensions).
    drop_mask = torch.zeros_like(drop_sorted).scatter_(-1, sorted_indices, drop_sorted)
    return logits.masked_fill(drop_mask, float("-inf"))
def apply_top_p_optimized(logits, top_p):
    """In-place nucleus (top-p) filtering.

    A token is masked when the cumulative probability of all tokens ranked
    before it already exceeds ``top_p``; the top-ranked token always stays.

    NOTE: ``logits`` is modified in place and the same tensor object is
    returned.

    Args:
        logits: Float tensor with the vocabulary on the last dimension.
        top_p: Cumulative-probability threshold.

    Returns:
        ``logits`` itself, with masked entries set to ``-inf``.
    """
    token_probs = F.softmax(logits, dim=-1)
    ranked_probs, ranked_ids = torch.sort(token_probs, descending=True, dim=-1)

    over_budget = torch.cumsum(ranked_probs, dim=-1) > top_p
    # Shift the mask one rank to the right: the token that pushes the
    # cumulative mass over the threshold is still kept.
    over_budget[..., 1:] = over_budget[..., :-1].clone()
    over_budget[..., 0] = False

    # Translate the rank-ordered mask back to vocabulary order.
    drop_mask = torch.zeros_like(logits, dtype=torch.bool)
    drop_mask.scatter_(dim=-1, index=ranked_ids, src=over_budget)

    logits.masked_fill_(drop_mask, float("-inf"))
    return logits
def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """Apply a repetition penalty to already-generated tokens, in place.

    For every token id found in the history, a positive logit is divided by
    ``penalty`` and a non-positive logit is multiplied by it. Each id is
    penalised once no matter how often it occurs.

    Supported layouts:
        * ``logits`` ``[N, V]`` with ``prev_tokens`` ``[N, T]`` (text layer).
        * ``logits`` ``[B, H, V]`` with ``prev_tokens`` ``[B, T, H]`` or
          ``[B, H]`` (delay-pattern audio, one vocabulary per VQ head ``H``).

    The penalty is applied independently per batch row (and per head): a
    sample is only penalised for tokens in its *own* history. If the leading
    dimension of ``prev_tokens`` does not match the batch dimension of
    ``logits``, the whole history is treated as shared by every row.

    Args:
        logits: Float logits; modified in place.
        prev_tokens: Integer token ids, or ``None`` to skip the penalty.
            Ids are expected to be valid indices into the vocabulary.
        penalty: Penalty factor; ``1.0`` is a no-op.

    Returns:
        ``logits`` (the same tensor object that was passed in).

    Raises:
        AssertionError: If ``logits`` is neither 2-D nor 3-D.
        IndexError: If 3-D ``logits`` are given but ``prev_tokens`` has fewer
            than ``H`` entries on its last dimension.
    """
    if penalty == 1.0 or prev_tokens is None:
        return logits

    if logits.dim() != 2:
        assert logits.dim() == 3, "Delay Pattern audio logits must be [B, H, V]"

    # Empty history (or empty logits): nothing to penalise.
    if prev_tokens.numel() == 0 or logits.numel() == 0:
        return logits

    if logits.dim() == 2:
        # Case 1: regular [N, V] (text layer) -> history [N, T]
        num_rows = logits.size(0)
        if prev_tokens.dim() >= 2 and prev_tokens.size(0) == num_rows:
            history = prev_tokens.reshape(num_rows, -1)
        else:
            # Batch dims do not line up: share the flat history across rows.
            history = prev_tokens.reshape(1, -1).expand(num_rows, -1)
    else:
        # Case 2: delay-pattern audio [B, H, V] -> history [B, H, T]
        batch, heads, _ = logits.shape
        if prev_tokens.dim() == 0 or prev_tokens.size(-1) < heads:
            raise IndexError(
                "prev_tokens must have at least H entries on its last dimension"
            )
        tokens = prev_tokens[..., :heads]
        if tokens.dim() >= 2 and tokens.size(0) == batch:
            history = tokens.reshape(batch, -1, heads).transpose(1, 2)
        else:
            # Batch dims do not line up: share the per-head history.
            history = (
                tokens.reshape(-1, heads).t().unsqueeze(0).expand(batch, -1, -1)
            )

    history = history.to(device=logits.device, dtype=torch.long)

    # Gather the logits of seen tokens, rescale them, and write them back.
    # Duplicate ids gather the same value and therefore scatter the same
    # value, so each token is effectively penalised exactly once.
    seen_logits = logits.gather(-1, history)
    seen_logits = torch.where(
        seen_logits > 0, seen_logits / penalty, seen_logits * penalty
    )
    logits.scatter_(-1, history, seen_logits)
    return logits
def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
):
    """Pick the next token id(s) from ``logits``.

    Processing order: repetition penalty -> (greedy argmax if
    ``do_sample`` is false) -> top-k -> top-p -> multinomial sampling.

    Args:
        logits: Float tensor with the vocabulary on the last dimension,
            e.g. ``[N, V]`` or ``[B, H, V]``. It does not need to be
            contiguous. NOTE: the repetition-penalty and top-p helpers work
            in place, so the caller's tensor may be modified.
        prev_tokens: Token history for the repetition penalty, or ``None``.
        repetition_penalty: Penalty factor; ``1.0`` disables the penalty.
        top_p: Nucleus threshold; ignored when ``None`` or ``>= 1.0``.
        top_k: Number of candidates to keep; ignored when ``None`` or ``<= 0``.
        do_sample: If false, return the argmax instead of sampling.

    Returns:
        Integer tensor of token ids with shape ``logits.shape[:-1]``.
    """
    vocab_size = logits.size(-1)

    # The repetition penalty needs the original [B, H, V] layout, so it has
    # to run before the logits are flattened.
    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(
            logits,
            prev_tokens,
            repetition_penalty,
        )

    if not do_sample:
        return torch.argmax(logits, dim=-1)

    # Flatten to [N, V] for top-k / top-p / multinomial. reshape (not view)
    # so that non-contiguous logits, e.g. a transposed tensor, are accepted.
    leading_shape = logits.shape[:-1]
    flat_logits = logits.reshape(-1, vocab_size)

    if top_k is not None and top_k > 0:
        flat_logits = apply_top_k(flat_logits, top_k)
    if top_p is not None and top_p < 1.0:
        flat_logits = apply_top_p_optimized(flat_logits, top_p)

    probs = F.softmax(flat_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)
    return next_tokens.reshape(leading_shape)
def find_last_equal_C(tensor, C):
    """Locate the last occurrence of ``C`` in every row.

    Args:
        tensor: torch.Tensor of shape [batch_size, seq_len].
        C: Scalar value to search for.

    Returns:
        torch.Tensor of shape [batch_size] (int64) holding, per row, the
        largest column index whose value equals ``C``, or ``-1`` if the row
        does not contain ``C``.
    """
    num_cols = tensor.shape[1]
    hits = tensor == C  # [batch_size, seq_len], True where the value matches

    # Label every matching cell with its column index and every other cell
    # with -1; the row-wise maximum is then the last matching position.
    positions = torch.arange(num_cols, device=tensor.device)
    labelled = torch.where(hits, positions, torch.full_like(positions, -1))
    return labelled.max(dim=1).values