""" This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. See the original Unsloth repository at https://github.com/unslothai/unsloth. The idea of in-place backward pass is from Liger-Kernel. See the original Liger-Kernel repository at https://github.com/linkedin/Liger-Kernel. """ import torch import torch.nn as nn import triton import triton.language as tl # Reference implementation @torch.compile(dynamic=None) def _compute_loss(logits, target_p, position_mask): logits = logits.float() out_logp = nn.LogSoftmax(dim=2)(logits) plogp = target_p * out_logp loss = -torch.sum(position_mask * plogp, 2).mean() return loss def _calculate_settings(n): # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 MAX_FUSED_SIZE = 131072 BLOCK_SIZE = triton.next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: raise RuntimeError( f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}." ) num_warps = 4 if BLOCK_SIZE >= 32768: num_warps = 32 elif BLOCK_SIZE >= 8192: num_warps = 16 elif BLOCK_SIZE >= 2048: num_warps = 8 # AMD GPU (ROCm) if hasattr(torch.version, "hip") and torch.version.hip is not None: num_warps //= 2 return BLOCK_SIZE, num_warps @triton.jit def log_softmax_forward_kernel( logits_ptr, logits_stride, target_ptr, target_stride, position_mask_ptr, position_mask_stride, loss_ptr, loss_stride, m_ptr, d_ptr, n_cols, BLOCK_SIZE: tl.constexpr, ): program_id = tl.program_id(0).to(tl.int64) logits_ptr += program_id * logits_stride target_ptr += program_id * target_stride position_mask_ptr += program_id * position_mask_stride position_mask = tl.load(position_mask_ptr) if position_mask == 0: return m = float("-inf") d = 0.0 for i in range(0, n_cols, BLOCK_SIZE): offsets = i + tl.arange(0, BLOCK_SIZE) mask = offsets < n_cols logits_block = tl.load( logits_ptr + offsets, mask=mask, other=float("-inf") ).cast(tl.float32) block_max = tl.max(tl.where(mask, logits_block, float("-inf"))) m_new = tl.maximum(m, block_max) d = d * tl.exp(m - m_new) + tl.sum( tl.where(mask, tl.exp(logits_block - m_new), 0.0) ) m = m_new loss = 0.0 for i in range(0, n_cols, BLOCK_SIZE): offsets = i + tl.arange(0, BLOCK_SIZE) mask = offsets < n_cols logits_block = tl.load(logits_ptr + offsets, mask=mask, other=0.0).cast( tl.float32 ) target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast( tl.float32 ) # log-softmax: log(exp(x - max) / sum) = (x - max) - log(sum) normalized_logits = logits_block - m log_normalizer = tl.log(d) log_softmax_logits = normalized_logits - log_normalizer weighted_log_prob = target_block * log_softmax_logits loss += tl.sum(tl.where(mask, weighted_log_prob, 0.0)) loss_ptr += program_id * loss_stride m_ptr += program_id d_ptr += program_id tl.store(loss_ptr, -loss) tl.store(m_ptr, m.to(tl.float32)) tl.store(d_ptr, d.to(tl.float32)) @triton.jit def log_softmax_backward_kernel( logits_ptr, logits_stride, target_ptr, target_stride, position_mask_ptr, grad_output_ptr, scaling_factor, m_ptr, d_ptr, n_cols, BLOCK_SIZE: tl.constexpr, ): program_id = tl.program_id(0).to(tl.int64) logits_ptr += program_id * logits_stride target_ptr += program_id * target_stride position_mask_ptr += program_id position_mask = tl.load(position_mask_ptr) if position_mask == 0: for i in range(0, n_cols, BLOCK_SIZE): offsets = i + tl.arange(0, BLOCK_SIZE) mask = offsets < n_cols tl.store(logits_ptr + offsets, 0.0, mask=mask) return m_ptr += program_id d_ptr += program_id m = 


@triton.jit
def log_softmax_backward_kernel(
    logits_ptr,
    logits_stride,
    target_ptr,
    target_stride,
    position_mask_ptr,
    grad_output_ptr,
    scaling_factor,
    m_ptr,
    d_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    program_id = tl.program_id(0).to(tl.int64)
    logits_ptr += program_id * logits_stride
    target_ptr += program_id * target_stride
    position_mask_ptr += program_id

    # Masked-out rows receive zero gradient; overwrite the logits row in place.
    position_mask = tl.load(position_mask_ptr)
    if position_mask == 0:
        for i in range(0, n_cols, BLOCK_SIZE):
            offsets = i + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_cols
            tl.store(logits_ptr + offsets, 0.0, mask=mask)
        return

    # Reload the softmax statistics saved by the forward kernel.
    m_ptr += program_id
    d_ptr += program_id
    m = tl.load(m_ptr).to(tl.float32)
    d = tl.load(d_ptr).to(tl.float32)

    grad_output = tl.load(grad_output_ptr).to(tl.float32)
    grad_output = grad_output * scaling_factor

    # First pass: compute sum of (target * grad_output)
    target_grad_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast(
            tl.float32
        )
        target_grad_sum += tl.sum(tl.where(mask, target_block * grad_output, 0.0))

    # Second pass: compute log-softmax gradients and store them over the logits
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        logits_block = tl.load(logits_ptr + offsets, mask=mask, other=0.0).cast(
            tl.float32
        )
        target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast(
            tl.float32
        )
        softmax_prob = tl.exp(logits_block - m) / d
        normalized_grad = softmax_prob * target_grad_sum
        grad_block = -(target_block * grad_output - normalized_grad)
        tl.store(logits_ptr + offsets, grad_block.to(tl.float32), mask=mask)


class LogSoftmaxLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, target, position_mask):
        B, T, V = logits.shape
        loss = torch.zeros((B * T, 1), device=logits.device)
        logits_flat = logits.contiguous().view(B * T, V)
        target_flat = target.contiguous().view(B * T, V)
        position_mask_flat = position_mask.contiguous().view(B * T, 1).bool()
        grid = (B * T,)
        # Per-row softmax max (m) and denominator (d), saved for the backward pass.
        m = torch.zeros((B * T,), device=logits.device, dtype=torch.float32)
        d = torch.zeros((B * T,), device=logits.device, dtype=torch.float32)
        BLOCK_SIZE, num_warps = _calculate_settings(V)
        log_softmax_forward_kernel[grid](
            logits_flat,
            logits_flat.stride(0),
            target_flat,
            target_flat.stride(0),
            position_mask_flat,
            position_mask_flat.stride(0),
            loss,
            loss.stride(0),
            m,
            d,
            V,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        ctx.save_for_backward(logits.detach(), target, position_mask, m, d)
        return loss.squeeze(1).mean()

    @staticmethod
    def backward(ctx, grad_output):
        logits, target, position_mask, m, d = ctx.saved_tensors
        B, T, V = logits.shape
        # The forward pass returns a mean over B * T rows, hence this scaling.
        scaling_factor = 1.0 / (B * T)
        logits = logits.contiguous().view(B * T, V)
        target = target.contiguous().view(B * T, V)
        position_mask = position_mask.contiguous().view(B * T, 1).bool()
        grid = (B * T,)
        BLOCK_SIZE, num_warps = _calculate_settings(V)
        # The kernel overwrites `logits` with the gradient in place, avoiding a
        # separate (B * T, V) gradient allocation.
        log_softmax_backward_kernel[grid](
            logits,
            logits.stride(0),
            target,
            target.stride(0),
            position_mask,
            grad_output,
            scaling_factor,
            m,
            d,
            V,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        logits = logits.view(B, T, V)
        # Gradients for (logits, target, position_mask); only logits requires grad.
        return logits, None, None


if __name__ == "__main__":
    device = "cuda"
    B, T, V = 1, 1024, 16000
    logits = torch.randn(B, T, V, device=device, requires_grad=True)
    logits2 = logits.clone().detach().requires_grad_(True)
    target = torch.randn(B, T, V, device=device)
    position_mask = torch.randint(0, 2, (B, T, 1), dtype=torch.bool, device=device)
    position_mask = torch.ones((B, T, 1), dtype=torch.bool, device=device)

    output1 = LogSoftmaxLoss.apply(logits, target, position_mask)
    output2 = _compute_loss(logits2, target, position_mask)
    torch.testing.assert_close(output1, output2, rtol=1e-4, atol=1e-4)

    output1.backward()
    output2.backward()
    torch.testing.assert_close(logits.grad, logits2.grad, rtol=1e-4, atol=1e-4)
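
    # Extra sanity check (not in the original file): the all-ones mask above never
    # exercises the masked-row branch of either kernel, so repeat the comparison
    # with a random partial mask on fresh tensors.
    logits3 = torch.randn(B, T, V, device=device, requires_grad=True)
    logits4 = logits3.clone().detach().requires_grad_(True)
    partial_mask = torch.randint(0, 2, (B, T, 1), dtype=torch.bool, device=device)
    output3 = LogSoftmaxLoss.apply(logits3, target, partial_mask)
    output4 = _compute_loss(logits4, target, partial_mask)
    torch.testing.assert_close(output3, output4, rtol=1e-4, atol=1e-4)
    output3.backward()
    output4.backward()
    torch.testing.assert_close(logits3.grad, logits4.grad, rtol=1e-4, atol=1e-4)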