Instructions to use FormalZz/AudioX with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Stable Audio Tools
How to use FormalZz/AudioX with Stable Audio Tools:
```python
import torch
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download model
model, model_config = get_pretrained_model("FormalZz/AudioX")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

model = model.to(device)

# Set up text and timing conditioning
conditioning = [{
    "prompt": "128 BPM tech house drum loop",
}]

# Generate stereo audio
output = generate_diffusion_cond(
    model,
    conditioning=conditioning,
    sample_size=sample_size,
    device=device
)

# Rearrange audio batch to a single sequence
output = rearrange(output, "b d n -> d (b n)")

# Peak normalize, clip, convert to int16, and save to file
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, sample_rate)
```
- Notebooks
- Google Colab
- Kaggle
File size: 3,240 Bytes
"""The 1D discrete wavelet transform for PyTorch."""
from einops import rearrange
import pywt
import torch
from torch import nn
from torch.nn import functional as F
from typing import Literal
def get_filter_bank(wavelet):
    """Return the filter bank of a PyWavelets wavelet as a 2D tensor.

    Each row is one filter of ``pywt.Wavelet(wavelet).filter_bank``, in the
    order PyWavelets returns them.

    For wavelets whose name starts with ``"bior"``, the first tap column is
    dropped when it is zero in every filter.

    Args:
        wavelet: Name of the wavelet, as accepted by ``pywt.Wavelet``.

    Returns:
        A tensor with one row per filter and one column per tap.
    """
    bank = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
    is_biorthogonal = wavelet.startswith("bior")
    # Only inspect the leading column for "bior*" wavelets; strip it when it
    # carries no information (all zeros).
    if is_biorthogonal and torch.all(bank[:, 0] == 0):
        bank = bank[:, 1:]
    return bank
class WaveletEncode1d(nn.Module):
    """Multi-level 1D wavelet analysis built on a stride-2 convolution.

    At every level the first ``channels`` channels of the input are filtered
    with the first two filters of the wavelet's filter bank and downsampled
    by two, which yields ``2 * channels`` channels at half the length.  All
    remaining channels are folded (pairs of adjacent time steps moved into
    the channel axis) so that every channel keeps the same length.

    Args:
        channels: Number of channels of the signal being transformed.
        levels: Number of decomposition levels applied in ``forward``.
        wavelet: Name of the wavelet passed to ``get_filter_bank``.
    """

    def __init__(
        self,
        channels,
        levels,
        wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4",
    ):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels

        bank = get_filter_bank(wavelet)
        num_taps = bank.shape[-1]
        # The padding used in forward() assumes an odd number of taps.
        assert num_taps % 2 == 1

        # First two filters of the bank, reversed along the tap axis.
        analysis = torch.flip(bank[:2, None], dims=(-1,))

        # band_idx selects the filter (0 or 1), chan_idx the signal channel.
        band_idx = torch.repeat_interleave(torch.arange(2), channels)
        chan_idx = torch.tile(torch.arange(channels), (2,))

        # Sparse conv weight: output channel (band * channels + c) reads only
        # input channel c, so each channel is filtered independently.
        weight = torch.zeros(channels * 2, channels, num_taps)
        weight[band_idx * channels + chan_idx, chan_idx] = analysis[band_idx, 0]
        self.register_buffer("kernel", weight)

    def forward(self, x):
        """Apply ``self.levels`` decomposition levels to ``x``.

        Args:
            x: Tensor laid out as (batch, channel, length); the first
                ``self.channels`` channels are the ones decomposed further
                at each level.

        Returns:
            The transformed tensor, with the filtered channels first and the
            folded remaining channels after them.
        """
        half = self.kernel.shape[-1] // 2
        for _ in range(self.levels):
            head = x[:, : self.channels]
            tail = x[:, self.channels :]
            # Reflect-pad so the strided convolution halves the length.
            head = F.pad(head, (half, half), "reflect")
            head = F.conv1d(head, self.kernel, stride=2)
            # Fold time pairs of the untouched channels into the channel axis
            # so they match the halved length of ``head``.
            tail = rearrange(
                tail, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
            )
            x = torch.cat([head, tail], dim=1)
        return x
class WaveletDecode1d(nn.Module):
    """Multi-level 1D wavelet synthesis built on a stride-2 transposed convolution.

    Mirrors the layout produced by ``WaveletEncode1d``: at every level the
    first ``2 * channels`` channels are combined into ``channels`` channels
    at twice the length using the last filters of the wavelet's filter bank,
    and the remaining channels are unfolded with the reverse of the
    encoder's rearrangement.

    Args:
        channels: Number of channels of the reconstructed signal.
        levels: Number of synthesis levels applied in ``forward``.
        wavelet: Name of the wavelet passed to ``get_filter_bank``.
    """

    def __init__(
        self,
        channels,
        levels,
        wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4",
    ):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels

        bank = get_filter_bank(wavelet)
        num_taps = bank.shape[-1]
        # The padding/cropping used in forward() assumes an odd number of taps.
        assert num_taps % 2 == 1

        # Filters from index 2 onward; unlike the encoder they are not flipped.
        synthesis = bank[2:, None]

        # band_idx selects the filter (0 or 1), chan_idx the signal channel.
        band_idx = torch.repeat_interleave(torch.arange(2), channels)
        chan_idx = torch.tile(torch.arange(channels), (2,))

        # Sparse transposed-conv weight: input channel (band * channels + c)
        # contributes only to output channel c.
        weight = torch.zeros(channels * 2, channels, num_taps)
        weight[band_idx * channels + chan_idx, chan_idx] = synthesis[band_idx, 0]
        self.register_buffer("kernel", weight)

    def forward(self, x):
        """Apply ``self.levels`` synthesis levels to ``x``.

        Args:
            x: Tensor laid out as (batch, channel, length); the first
                ``2 * self.channels`` channels are the two bands merged at
                each level.

        Returns:
            The transformed tensor, with the merged channels first and the
            unfolded remaining channels after them.
        """
        half = self.kernel.shape[-1] // 2
        pad = half + 2
        for _ in range(self.levels):
            bands = x[:, : self.channels * 2]
            tail = x[:, self.channels * 2 :]
            # Interleave the two bands along time, reflect-pad the interleaved
            # sequence, then split it back into two bands.
            bands = rearrange(bands, "n (l2 c) l -> n c (l l2)", l2=2)
            bands = F.pad(bands, (pad, pad), "reflect")
            bands = rearrange(bands, "n c (l l2) -> n (l2 c) l", l2=2)
            bands = F.conv_transpose1d(bands, self.kernel, stride=2, padding=half)
            # Crop away the samples that exist only because of the padding.
            bands = bands[..., pad - 1 : -pad]
            # Unfold channel pairs of the untouched channels back into time.
            tail = rearrange(
                tail, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
            )
            x = torch.cat([bands, tail], dim=1)
        return x