# vae / app.py
# Humair332's picture
# Update app.py
# 7140878 verified
import gradio as gr
import torch
import numpy as np
import soundfile as sf
from scipy.signal import resample as scipy_resample
from dataclasses import dataclass, field
from huggingface_hub import hf_hub_download
import time
import json
# =============================
# DACVAE WRAPPER
# =============================
@dataclass
class SimpleDACCodec:
    """Convenience wrapper that pairs a DACVAE model with its timing metadata.

    Holds the model, its sample rate, the encoder hop size measured when the
    codec is loaded, and the device the model was moved to.
    """

    model: torch.nn.Module
    sample_rate: int
    hop_size: int  # encoder stride in samples, probed at load time
    device: torch.device

    @classmethod
    def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
        """Fetch ``weights.pth`` from *repo_id* and return a ready-to-use codec.

        The hop size is not hard-coded: one second of silence is encoded and
        the input length is integer-divided by the number of latent frames
        that come back.
        """
        from dacvae import DACVAE

        checkpoint = hf_hub_download(repo_id=repo_id, filename="weights.pth")
        net = DACVAE.load(checkpoint).eval().to(device)
        rate = int(net.sample_rate)

        # Probe signal: exactly one second (``rate`` samples) of silence,
        # created in the same dtype as the model's parameters.
        n_probe = rate
        param_dtype = next(net.parameters()).dtype
        silence = torch.zeros(1, 1, n_probe, device=device, dtype=param_dtype)
        with torch.inference_mode():
            probe_latent = net.encode(silence)  # presumably (1, D, T_latent)

        # The size of axis 2 is taken as the number of latent frames.
        n_frames = probe_latent.shape[2]
        stride = n_probe // n_frames  # integer hop in samples

        print(
            f"[codec] sample_rate={rate} probe_frames={n_frames} "
            f"hop={stride} frame_rate={rate/stride:.4f} Hz",
            flush=True,
        )
        return cls(
            model=net,
            sample_rate=rate,
            hop_size=stride,
            device=torch.device(device),
        )

    @property
    def frame_rate(self) -> float:
        """Number of latent frames per second of audio."""
        return self.sample_rate / self.hop_size

    def frames_to_seconds(self, num_frames: int) -> float:
        """Return the audio duration, in seconds, spanned by *num_frames* latent frames."""
        total_samples = num_frames * self.hop_size
        return total_samples / self.sample_rate

    @torch.inference_mode()
    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode audio and swap axes 1 and 2 of the model output.

        audio: (1, 1, T) -> latent: (1, T_latent, D)
        """
        # The model is expected to return (B, D, T); callers get (B, T, D).
        return self.model.encode(audio).transpose(1, 2)

    @torch.inference_mode()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode a latent after swapping axes 1 and 2 back to the model layout.

        latent: (B, T_latent, D) -> audio: (B, 1, T)
        """
        channels_first = latent.transpose(1, 2)
        return self.model.decode(channels_first)
# =============================
# INIT
# =============================
# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"[init] Using device: {DEVICE}")

# Build the module-wide codec once at import time; the handlers below share it.
codec = SimpleDACCodec.load(device=DEVICE)
print(
    f"[init] Codec ready. Frame rate = {codec.frame_rate:.4f} Hz "
    f"(hop={codec.hop_size}, sr={codec.sample_rate})"
)
# =============================
# AUDIO UTILS
# =============================
def load_audio(path: str) -> tuple[np.ndarray, int]:
    """Read an audio file as float32 and return ``(samples, sample_rate)``.

    Input with more than one dimension is collapsed by averaging over
    axis 1, so the returned samples are always one-dimensional for
    1-D or 2-D input.
    """
    samples, native_sr = sf.read(path, dtype="float32")
    is_multichannel = samples.ndim > 1
    mono = np.mean(samples, axis=1) if is_multichannel else samples
    return mono, native_sr
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* from *orig_sr* to *target_sr*.

    When both rates are equal the input array itself is returned (no copy).
    Otherwise the output length is ``len(audio) * target_sr / orig_sr``
    truncated to an integer, and ``scipy.signal.resample`` does the work.
    """
    if orig_sr != target_sr:
        target_len = int(len(audio) * target_sr / orig_sr)
        return scipy_resample(audio, target_len)
    return audio
def to_tensor(audio: np.ndarray) -> torch.Tensor:
    """Wrap a NumPy array as a tensor with two leading singleton axes.

    A 1-D array of length T becomes a tensor of shape (1, 1, T). The tensor
    is created with ``torch.from_numpy``, so it shares memory with *audio*.
    """
    waveform = torch.from_numpy(audio)
    return waveform[None, None]  # (1, 1, T)
def format_stats(stats: dict) -> str:
    """Turn a stats mapping into a two-column Markdown table.

    Each item becomes one row: the key in the first column and the value,
    wrapped in backticks, in the second. Rows follow the mapping's
    iteration order.
    """
    header = ["| Property | Value |", "|---|---|"]
    rows = [f"| {name} | `{value}` |" for name, value in stats.items()]
    return "\n".join(header + rows)
# =============================
# ENCODE
# =============================
def encode_audio(file):
    """Encode an uploaded audio file into latents and build a stats report.

    Args:
        file: Path of the audio file to encode, or ``None`` when nothing has
            been uploaded.

    Returns:
        A 3-tuple ``(latent_list, latent_list, stats_md)``. ``latent_list``
        is the latent with its leading axis squeezed away, converted to
        nested Python lists (returned twice); ``stats_md`` is a Markdown
        table. When there is no file, or the file holds no samples, the
        first two items are ``None`` and the third is a warning message.
    """
    if file is None:
        return None, None, "⚠️ Please upload an audio file first."

    # The timer spans loading, resampling and encoding.
    t0 = time.perf_counter()

    # Load + resample
    audio_orig, sr_orig = load_audio(file)
    orig_samples = len(audio_orig)
    if orig_samples == 0:
        # Nothing to encode: report it instead of failing inside the
        # encoder or in the min/max statistics further down.
        return None, None, "⚠️ The uploaded audio file contains no samples."
    orig_duration = orig_samples / sr_orig

    audio_resampled = resample_audio(audio_orig, sr_orig, codec.sample_rate)
    resampled_samples = len(audio_resampled)
    wav = to_tensor(audio_resampled).to(DEVICE)

    # Encode
    latent = codec.encode(wav)  # (1, T_latent, D)
    if DEVICE == "cuda":
        # CUDA kernels run asynchronously; wait for them to finish so the
        # timer below measures the actual work, not just the kernel launches.
        torch.cuda.synchronize()
    t_enc = time.perf_counter() - t0

    num_frames = latent.shape[1]
    latent_dim = latent.shape[2]
    calc_dur = codec.frames_to_seconds(num_frames)
    latent_np = latent.squeeze(0).detach().cpu().numpy()  # (T, D)
    latent_list = latent_np.tolist()

    # Stats
    # "Duration match" passes when the latent-derived duration is within
    # 0.05 s of the resampled audio duration.
    stats = {
        "πŸ“ Original sample rate": f"{sr_orig} Hz",
        "🎡 Codec sample rate": f"{codec.sample_rate} Hz",
        "⏱ Original duration": f"{orig_duration:.4f} s ({orig_samples:,} samples)",
        "⏱ Resampled duration": f"{resampled_samples / codec.sample_rate:.4f} s ({resampled_samples:,} samples)",
        "πŸ”’ Latent frames (T)": f"{num_frames}",
        "πŸ“ Latent dim (D)": f"{latent_dim}",
        "πŸ“ Encoder hop size": f"{codec.hop_size} samples",
        "πŸ”„ Latent frame rate": f"{codec.frame_rate:.4f} Hz",
        "⏳ Duration from latent": f"{calc_dur:.4f} s (T Γ— hop / sr = {num_frames} Γ— {codec.hop_size} / {codec.sample_rate})",
        "βœ… Duration match": f"{'βœ“ exact' if abs(calc_dur - resampled_samples / codec.sample_rate) < 0.05 else '⚠ mismatch'}",
        "⚑ Encode time": f"{t_enc*1000:.1f} ms",
        "πŸ’Ύ Latent tensor size": f"{latent_np.nbytes / 1024:.1f} KB (float32)",
        "πŸ“Š Latent value range": f"[{latent_np.min():.4f}, {latent_np.max():.4f}]",
        "πŸ“Š Latent mean / std": f"{latent_np.mean():.4f} / {latent_np.std():.4f}",
    }
    stats_md = format_stats(stats)
    return latent_list, latent_list, stats_md
# =============================
# DECODE
# =============================
def decode_audio(latent_list, stats_md_current):
    """Decode a stored latent back to audio and append decode statistics.

    Args:
        latent_list: Latent as nested lists shaped (T, D) or (B, T, D), or
            ``None`` when nothing has been encoded yet. The statistics
            assume a single batch entry.
        stats_md_current: Markdown currently shown in the stats panel; the
            decode table is appended to it. May be ``None`` or empty.

    Returns:
        ``((sample_rate, audio_np), markdown)`` on success, where
        ``audio_np`` is clipped to [-1, 1] with NaNs replaced. On missing or
        malformed input the first item is ``None`` and the second is a
        warning message.
    """
    if latent_list is None:
        return None, (stats_md_current or "") + "\n\n⚠️ No latent found. Encode first."

    t0 = time.perf_counter()

    # torch.tensor rejects ragged or non-numeric input; surface that as a
    # message in the UI rather than an unhandled error.
    try:
        latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
    except Exception as e:
        return None, f"⚠️ Invalid latent: {e}"

    if latent.ndim == 2:
        latent = latent.unsqueeze(0)  # (1, T, D)

    # An empty or 1-D latent would otherwise fail inside transpose() or in
    # the per-frame division below with an opaque error.
    if latent.ndim != 3 or latent.shape[1] == 0:
        return None, (
            "⚠️ Invalid latent: expected shape (T, D) or (B, T, D) with T > 0, "
            f"got {tuple(latent.shape)}"
        )

    audio = codec.decode(latent)  # (B, 1, T_out)
    if DEVICE == "cuda":
        # CUDA kernels run asynchronously; wait for them to finish so the
        # timer below measures the actual work, not just the kernel launches.
        torch.cuda.synchronize()
    t_dec = time.perf_counter() - t0

    audio_np = audio.squeeze().detach().cpu().numpy()
    audio_np = np.nan_to_num(audio_np)
    audio_np = np.clip(audio_np, -1.0, 1.0)

    num_frames = latent.shape[1]
    out_samples = len(audio_np)
    actual_dur = out_samples / codec.sample_rate
    calc_dur = codec.frames_to_seconds(num_frames)
    # Observed samples per latent frame, shown next to the probed hop size.
    actual_hop = out_samples // num_frames

    decode_stats = {
        "πŸ”’ Latent frames decoded": f"{num_frames}",
        "πŸ”Š Output samples": f"{out_samples:,}",
        "⏱ Reconstructed duration": f"{actual_dur:.4f} s",
        "⏳ Duration from latent": f"{calc_dur:.4f} s",
        "πŸ” Actual output hop": f"{actual_hop} samples/frame (expected {codec.hop_size})",
        "βœ… Formula confirmation": f"T={num_frames} Γ— hop={actual_hop} / sr={codec.sample_rate} = {num_frames * actual_hop / codec.sample_rate:.4f} s",
        "⚑ Decode time": f"{t_dec*1000:.1f} ms",
        "πŸ“Š Output value range": f"[{audio_np.min():.4f}, {audio_np.max():.4f}]",
    }
    decode_md = format_stats(decode_stats)
    combined = (stats_md_current or "") + "\n\n### Decode Stats\n" + decode_md
    return (codec.sample_rate, audio_np), combined
# =============================
# UI
# =============================
# Custom stylesheet passed to gr.Blocks below: dark background, monospace
# font, green accent colour for headings and buttons, and compact styling
# for the Markdown stats tables (th / td / td code rules).
css = """
body, .gradio-container {
background: #0d0d0d !important;
font-family: 'IBM Plex Mono', monospace !important;
color: #e0e0e0 !important;
}
h1, h2, h3 { color: #00e5a0 !important; letter-spacing: 0.08em; }
.gr-button {
background: #00e5a0 !important;
color: #000 !important;
font-weight: 700 !important;
border-radius: 2px !important;
border: none !important;
font-family: 'IBM Plex Mono', monospace !important;
letter-spacing: 0.05em;
}
.gr-button:hover { background: #00ffa8 !important; }
.gr-box, .gr-panel { background: #151515 !important; border: 1px solid #2a2a2a !important; }
table { width: 100%; border-collapse: collapse; font-size: 0.82em; }
th { color: #00e5a0; border-bottom: 1px solid #2a2a2a; padding: 4px 8px; text-align: left; }
td { padding: 4px 8px; border-bottom: 1px solid #1a1a1a; }
td code { background: #1e1e1e; padding: 2px 6px; border-radius: 2px; color: #a8ff78; }
"""
with gr.Blocks(css=css, title="DACVAE Inspector") as demo:
    # Page banner: web-font link, title, and the codec parameters that were
    # probed at start-up, filled in through str.format.
    _banner_html = """
<link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;700&display=swap" rel="stylesheet">
<div style="padding: 24px 0 8px 0;">
<h1 style="font-size:1.6em; margin:0; letter-spacing:0.12em;">
β—ˆ DACVAE CODEC INSPECTOR
</h1>
<p style="color:#666; margin:4px 0 0 0; font-size:0.78em; letter-spacing:0.06em;">
Aratako/Semantic-DACVAE-Japanese-32dim &nbsp;Β·&nbsp;
sr={sr} Hz &nbsp;Β·&nbsp; hop={hop} &nbsp;Β·&nbsp; frame_rate={fr:.4f} Hz
</p>
</div>
""".format(sr=codec.sample_rate, hop=codec.hop_size, fr=codec.frame_rate)
    gr.HTML(_banner_html)

    # Holds the full latent between the ENCODE and DECODE clicks.
    latent_state = gr.State()

    with gr.Row():
        # -- Left column: input, action buttons, reconstructed output --
        with gr.Column(scale=1):
            audio_in = gr.Audio(type="filepath", label="Input Audio")
            with gr.Row():
                encode_btn = gr.Button("β–Ά ENCODE", variant="primary")
                decode_btn = gr.Button("β—€ DECODE", variant="primary")
            audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)

        # -- Right column: stats table and a short latent preview --
        with gr.Column(scale=1):
            stats_out = gr.Markdown(
                value="*Stats will appear here after encoding.*",
                label="Stats",
            )
            with gr.Accordion("Raw Latent JSON (first 3 frames)", open=False):
                latent_preview = gr.JSON(label="Latent preview")

    # -- Wire up --
    def encode_and_preview(file):
        """Run encode_audio and return (full latent, first 3 frames, stats markdown)."""
        full_latent, _unused, report = encode_audio(file)
        if full_latent is None:
            return None, None, report
        # Slicing an empty list already yields [], so no separate empty case.
        return full_latent, full_latent[:3], report

    encode_btn.click(
        fn=encode_and_preview,
        inputs=audio_in,
        outputs=[latent_state, latent_preview, stats_out],
    )

    # Decoding reads the stored latent and appends its table to the stats panel.
    decode_btn.click(
        fn=decode_audio,
        inputs=[latent_state, stats_out],
        outputs=[audio_out, stats_out],
    )
# =============================
# RUN
# =============================
# Launch the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link in
    # addition to the local server -- confirm this is intended where it runs.
    demo.launch(share=True)