Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,7 @@ from scipy.signal import resample
|
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
|
|
|
|
| 9 |
# =============================
|
| 10 |
# SIMPLE DACVAE WRAPPER
|
| 11 |
# =============================
|
|
@@ -30,11 +31,13 @@ class SimpleDACCodec:
|
|
| 30 |
|
| 31 |
@torch.inference_mode()
|
| 32 |
def encode(self, audio):
|
| 33 |
-
|
| 34 |
-
|
|
|
|
| 35 |
|
| 36 |
@torch.inference_mode()
|
| 37 |
def decode(self, latent):
|
|
|
|
| 38 |
z = latent.transpose(1, 2)
|
| 39 |
return self.model.decode(z)
|
| 40 |
|
|
@@ -52,6 +55,7 @@ codec = SimpleDACCodec.load(device=DEVICE)
|
|
| 52 |
def load_audio(path):
|
| 53 |
audio, sr = sf.read(path, dtype="float32")
|
| 54 |
|
|
|
|
| 55 |
if audio.ndim > 1:
|
| 56 |
audio = np.mean(audio, axis=1)
|
| 57 |
|
|
@@ -75,63 +79,72 @@ def to_tensor(audio):
|
|
| 75 |
# =============================
|
| 76 |
def encode_audio(file):
|
| 77 |
if file is None:
|
| 78 |
-
raise
|
| 79 |
|
| 80 |
audio, sr = load_audio(file)
|
| 81 |
audio = resample_audio(audio, sr, codec.sample_rate)
|
| 82 |
-
|
| 83 |
wav = to_tensor(audio).to(DEVICE)
|
| 84 |
-
latent = codec.encode(wav)
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
|
| 89 |
|
| 90 |
# =============================
|
| 91 |
# DECODE
|
| 92 |
# =============================
|
| 93 |
-
def decode_audio(
|
| 94 |
-
if
|
| 95 |
-
raise
|
| 96 |
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
if latent.ndim == 2:
|
| 100 |
latent = latent.unsqueeze(0)
|
| 101 |
|
| 102 |
audio = codec.decode(latent)
|
| 103 |
-
audio = audio.squeeze().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
return (codec.sample_rate, audio)
|
| 106 |
|
| 107 |
|
| 108 |
# =============================
|
| 109 |
-
# UI
|
| 110 |
# =============================
|
| 111 |
with gr.Blocks() as demo:
|
| 112 |
-
gr.Markdown("## 🎧 Simple DAC Audio Codec")
|
| 113 |
-
|
| 114 |
-
audio_in = gr.Audio(type="filepath", label="Upload Audio")
|
| 115 |
|
| 116 |
-
|
| 117 |
-
decode_btn = gr.Button("Decode")
|
| 118 |
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
# Encode → store in state
|
| 124 |
encode_btn.click(
|
| 125 |
fn=encode_audio,
|
| 126 |
inputs=audio_in,
|
| 127 |
-
outputs=latent_state
|
| 128 |
)
|
| 129 |
|
| 130 |
-
# Decode from state
|
| 131 |
decode_btn.click(
|
| 132 |
fn=decode_audio,
|
| 133 |
inputs=latent_state,
|
| 134 |
-
outputs=audio_out
|
| 135 |
)
|
| 136 |
|
| 137 |
|
|
|
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
|
| 9 |
+
|
| 10 |
# =============================
|
| 11 |
# SIMPLE DACVAE WRAPPER
|
| 12 |
# =============================
|
|
|
|
| 31 |
|
| 32 |
@torch.inference_mode()
def encode(self, audio):
    """Encode a waveform into a time-major latent.

    Args:
        audio: Waveform tensor, expected as (1, 1, T).

    Returns:
        The model's (B, D, T) encoding with its last two axes swapped,
        i.e. a (B, T, D) tensor.
    """
    # Swap the channel and time axes so callers receive (B, T, D).
    return self.model.encode(audio).transpose(1, 2)
|
| 37 |
|
| 38 |
@torch.inference_mode()
def decode(self, latent):
    """Decode a time-major latent back through the model.

    Args:
        latent: Latent tensor laid out as (B, T, D).

    Returns:
        Whatever ``self.model.decode`` produces for the latent after its
        axes 1 and 2 have been swapped back.
    """
    # Undo the axis swap applied on the encode side before decoding.
    channel_major = latent.transpose(1, 2)
    decoded = self.model.decode(channel_major)
    return decoded
|
| 43 |
|
|
|
|
| 55 |
def load_audio(path):
|
| 56 |
audio, sr = sf.read(path, dtype="float32")
|
| 57 |
|
| 58 |
+
# mono
|
| 59 |
if audio.ndim > 1:
|
| 60 |
audio = np.mean(audio, axis=1)
|
| 61 |
|
|
|
|
| 79 |
# =============================
|
| 80 |
def encode_audio(file):
    """Encode an uploaded audio file into a latent held as nested lists.

    Args:
        file: Path of the uploaded audio file, or ``None`` when nothing
            has been uploaded.

    Returns:
        A 2-tuple containing the same nested-list latent twice: one entry
        for display and one for the hidden state.

    Raises:
        ValueError: If ``file`` is ``None``.
    """
    if file is None:
        raise ValueError("Please upload an audio file first.")

    # Load, bring to the codec's sample rate, then move onto the device.
    samples, source_rate = load_audio(file)
    samples = resample_audio(samples, source_rate, codec.sample_rate)
    waveform = to_tensor(samples).to(DEVICE)

    # The codec returns a (B, T, D) tensor; convert it to plain Python lists.
    encoded = codec.encode(waveform)
    as_lists = encoded.detach().cpu().numpy().tolist()

    # The very same list object is handed to both outputs.
    return as_lists, as_lists
|
| 92 |
|
| 93 |
|
| 94 |
# =============================
|
| 95 |
# DECODE
|
| 96 |
# =============================
|
| 97 |
+
def decode_audio(latent_list):
    """Reconstruct audio from a latent stored as nested Python lists.

    Args:
        latent_list: Nested lists holding the latent, either (T, D) or
            (B, T, D), or ``None`` when nothing has been encoded yet.

    Returns:
        ``(codec.sample_rate, audio)`` where ``audio`` is a NumPy array
        with NaNs replaced and values clipped to [-1.0, 1.0].

    Raises:
        ValueError: If ``latent_list`` is ``None``, cannot be converted to
            a float32 tensor, or is empty / not 2-D or 3-D.
    """
    if latent_list is None:
        raise ValueError("No latent found. Click Encode first.")

    # Convert the nested list to a tensor; keep the original error as the
    # cause so the real conversion failure stays visible in tracebacks.
    try:
        latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
    except Exception as e:
        raise ValueError(f"Invalid latent data: {e}") from e

    # A single (T, D) latent gets a leading batch axis.
    if latent.ndim == 2:
        latent = latent.unsqueeze(0)

    # Reject empty or wrongly-ranked input here instead of letting the
    # codec fail with an unrelated error.
    if latent.ndim != 3 or latent.numel() == 0:
        raise ValueError(
            "Invalid latent data: expected a non-empty (T, D) or (B, T, D) "
            f"latent, got shape {tuple(latent.shape)}"
        )

    audio = codec.decode(latent)
    audio = audio.squeeze().detach().cpu().numpy()

    # Replace NaNs and clip, just in case the decoder overshoots.
    audio = np.nan_to_num(audio)
    audio = np.clip(audio, -1.0, 1.0)

    return (codec.sample_rate, audio)
|
| 118 |
|
| 119 |
|
| 120 |
# =============================
|
| 121 |
+
# UI
|
| 122 |
# =============================
|
| 123 |
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 Simple DAC Audio Codec (Single Window)")

    # Hidden state that carries the encoded latent from the Encode click
    # to the Decode click.
    latent_state = gr.State()

    with gr.Row():
        # Left column: input file and the two actions.
        with gr.Column(scale=1):
            audio_in = gr.Audio(type="filepath", label="Upload Audio")
            encode_btn = gr.Button("Encode")
            decode_btn = gr.Button("Decode")

        # Right column: latent preview and the reconstructed audio.
        with gr.Column(scale=1):
            latent_out = gr.JSON(label="Latent")
            audio_out = gr.Audio(label="Reconstructed Audio")

    # encode_audio returns the latent twice: once for the JSON preview and
    # once for the hidden state.
    encode_btn.click(
        fn=encode_audio,
        inputs=audio_in,
        outputs=[latent_out, latent_state],
    )

    # Decoding reads the latent back from the hidden state.
    decode_btn.click(
        fn=decode_audio,
        inputs=latent_state,
        outputs=audio_out,
    )
|
| 149 |
|
| 150 |
|