| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from .utils import nearest_power_of_two |
| | from flashfftconv import FlashFFTConv |
| |
|
| |
|
def convolve(u: torch.Tensor, v: torch.Tensor, n: int, use_approx: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Causally convolve ``u`` with the filters ``v`` using real FFTs of length ``n``.

    Two convolutions are computed with a single FFT/iFFT pair:

    * ``U_plus``  -- ``v`` convolved with ``u`` along the time axis;
    * ``U_minus`` -- ``v`` convolved with ``u * sgn``, then multiplied by ``sgn``
      again, where ``sgn`` is the alternating pattern ``+1, -1, +1, ...`` in time.
      This equals convolving ``u`` with the sign-alternated filter
      ``(-1)**s * v[s]``.

    The FFT work runs in float32. Both outputs are cast back to the dtype ``u``
    had on entry.

    Args:
        u: Input of shape ``(bsz, seq_len, d_in)``.
        v: 2-D filter bank whose first axis is time. With ``use_approx=True`` its
            second axis must broadcast against ``d_in`` (normally equal to it).
            With ``use_approx=False`` its second axis is the number of filters
            ``K``, and every filter is applied to every input channel.
        n: FFT length. Outputs are truncated to the first ``seq_len`` steps, so
            ``n`` should be at least ``seq_len + v.shape[0] - 1``. A smaller ``n``
            makes the FFT convolution circular and the result wraps around.
        use_approx: Selects the per-channel path (``True``) or the
            all-filters-by-all-channels path (``False``).

    Returns:
        ``(U_plus, U_minus)``. Each has shape ``(bsz, seq_len, d_in)`` when
        ``use_approx`` is true and ``(bsz, seq_len, K, d_in)`` otherwise. Both
        have the input dtype of ``u``.
    """
    bsz, seq_len, d_in = u.shape
    # Capture the caller's dtype *before* the float32 cast below. Reading
    # ``u.dtype`` after the cast always yields float32, so the final cast back
    # would be a no-op.
    in_dtype = u.dtype

    # Alternating +1, -1, +1, ... along the time axis.
    sgn = torch.full((1, seq_len, 1), 1, device=u.device, dtype=torch.float32)
    sgn[:, 1::2] *= -1

    u = u.to(torch.float32)
    v = v.to(torch.float32)

    if use_approx:
        _, d_out = v.shape
        # (1, time, d_out, 1) broadcasts against U's (bsz, freq, d_in, 2).
        v = v.view(1, -1, d_out, 1)
    else:
        _, K = v.shape
        sgn = sgn.unsqueeze(-1)
        v = v.view(1, -1, K, 1, 1)
        # Pair every input channel with every one of the K filters:
        # (bsz, seq_len, K, d_in) as a broadcast view.
        u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

    v = torch.fft.rfft(v, n=n, dim=1)
    # Stack the plain and sign-alternated signals on a trailing axis so both
    # convolutions share one forward and one inverse FFT.
    U = torch.stack([u, u * sgn], dim=-1)
    U = torch.fft.rfft(U, n=n, dim=1)
    U_conv = torch.fft.irfft(v * U, n=n, dim=1)[:, :seq_len]
    U_plus, U_minus = torch.unbind(U_conv, dim=-1)
    U_minus = U_minus * sgn

    return U_plus.to(in_dtype), U_minus.to(in_dtype)
| |
|
def flash_convolve(
    u: torch.Tensor,
    v: torch.Tensor,
    flash_fft: FlashFFTConv,
    use_approx: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convolve ``u`` with the filters ``v`` through a ``FlashFFTConv`` module.

    The plain signal ``u`` and its sign-alternated copy ``u * sgn`` are stacked
    along the batch axis and sent through ``flash_fft`` in one call. Here
    ``sgn`` is ``+1, -1, +1, ...`` along time. The second half of the result is
    multiplied by ``sgn`` again before returning.

    Both operands are cast to float32. They are moved to channel-major layout
    and right-padded with zeros from ``seq_len`` up to
    ``nearest_power_of_two(seq_len, round_up=True)``. That is presumably the
    smallest power of two >= ``seq_len``; confirm against ``utils``.

    Args:
        u: Input of shape ``(bsz, seq_len, d_in)``.
        v: 2-D filter bank ``(v_len, K)``. It is padded by the same amount as
            ``u``, so this assumes ``v_len == seq_len`` (TODO confirm). With
            ``use_approx=True`` the transposed filters are used directly as the
            per-channel kernels, which presumably requires ``K == d_in``.
        flash_fft: Callable invoked as ``flash_fft(signal, kernel)``.
        use_approx: Per-channel path (``True``) or the all-filters-by-all-channels
            path (``False``), where each of the ``d_in`` channels is paired with
            each of the ``K`` filters.

    Returns:
        ``(U_plus, U_minus)`` cast to ``u``'s original dtype. Their shape is
        ``(bsz, seq_len, d_in)`` when ``use_approx`` is true and
        ``(bsz, seq_len, K, d_in)`` otherwise.
    """
    out_dtype = u.dtype
    u32 = u.to(torch.float32)
    v32 = v.to(torch.float32)

    bsz, seq_len, d_in = u32.shape
    _, K = v32.shape

    padded_len = nearest_power_of_two(seq_len, round_up=True)
    pad_len = padded_len - seq_len

    # +1 on even time steps, -1 on odd ones; shape (1, 1, padded_len).
    sign = torch.ones((1, 1, padded_len), device=u32.device, dtype=torch.float32)
    sign[:, :, 1::2] = -1

    # Time-major -> channel-major, zero-padded on the right of the time axis.
    if use_approx:
        base = F.pad(u32.transpose(1, 2), (0, pad_len)).contiguous()
        kernel = F.pad(v32.transpose(0, 1), (0, pad_len)).contiguous()
        channels = d_in
    else:
        # Row index d * K + k holds (input channel d, filter k). repeat_interleave
        # duplicates each channel K times, and repeat tiles the K filters d_in times.
        base = F.pad(u32.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1).contiguous()
        kernel = F.pad(v32.transpose(0, 1), (0, pad_len)).repeat(d_in, 1).contiguous()
        channels = K * d_in

    # First bsz rows: plain signal. Last bsz rows: sign-alternated signal.
    signal = torch.stack((base, base * sign), dim=0).reshape(2 * bsz, channels, padded_len)

    convolved = flash_fft(signal, kernel)[..., :seq_len]
    plus, minus = torch.chunk(convolved, 2, dim=0)
    sign_t = sign[:, :, :seq_len]

    if not use_approx:
        # Undo the (d, k) row packing -> (bsz, seq_len, K, d_in), and broadcast
        # the sign as (1, seq_len, 1, 1).
        sign_t = sign_t.unsqueeze(-1).transpose(1, 2)
        U_plus = plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
        U_minus = minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sign_t
    else:
        U_plus = plus.transpose(1, 2)
        U_minus = (minus * sign_t).transpose(1, 2)

    return U_plus.to(out_dtype), U_minus.to(out_dtype)
| |
|