"""Gradio inspector for the DACVAE audio codec.

Encodes an uploaded audio file to a latent sequence, decodes it back, and
reports timing / shape / duration statistics for both directions.
"""

import json
import time
from dataclasses import dataclass, field

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from scipy.signal import resample as scipy_resample


# =============================
# DACVAE WRAPPER
# =============================

@dataclass
class SimpleDACCodec:
    """Minimal wrapper around a DACVAE model plus its probed timing constants."""

    model: torch.nn.Module
    sample_rate: int
    hop_size: int  # encoder stride in samples — probed at load time
    device: torch.device

    @classmethod
    def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
        """Download the weights, build the model and probe its hop size.

        Args:
            repo_id: Hugging Face repository holding ``weights.pth``.
            device: Device string the model is moved to.

        Returns:
            A ready-to-use ``SimpleDACCodec``.

        Raises:
            RuntimeError: If the probe yields no latent frames, or more
                frames than input samples (hop size would be zero).
        """
        from dacvae import DACVAE

        weights_path = hf_hub_download(repo_id=repo_id, filename="weights.pth")
        model = DACVAE.load(weights_path).eval().to(device)
        sr = int(model.sample_rate)

        # ── Probe the hop size ────────────────────────────────────────────
        # Feed exactly one second of silence and count the latent frames
        # that come out: hop = input_samples / output_frames.
        # NOTE: the floor division truncates, so if the encoder pads its
        # input the reported hop is an approximation of the true stride.
        probe_len = sr
        dummy = torch.zeros(
            1, 1, probe_len,
            device=device,
            dtype=next(model.parameters()).dtype,
        )
        with torch.inference_mode():
            z = model.encode(dummy)  # (1, D, T_latent)

        t_latent = z.shape[2]
        if t_latent <= 0:
            raise RuntimeError(
                f"Hop-size probe produced no latent frames for {probe_len} samples"
            )
        hop = probe_len // t_latent  # integer hop in samples
        if hop <= 0:
            raise RuntimeError(
                f"Hop-size probe produced {t_latent} frames for {probe_len} "
                "samples; cannot derive a positive hop size"
            )

        print(f"[codec] sample_rate={sr}  probe_frames={t_latent}  "
              f"hop={hop}  frame_rate={sr/hop:.4f} Hz", flush=True)

        return cls(
            model=model,
            sample_rate=sr,
            hop_size=hop,
            device=torch.device(device),
        )

    @property
    def frame_rate(self) -> float:
        """Latent frames per second."""
        return self.sample_rate / self.hop_size

    def frames_to_seconds(self, num_frames: int) -> float:
        """Convert latent frame count -> audio duration in seconds."""
        return num_frames * self.hop_size / self.sample_rate

    @torch.inference_mode()
    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        """audio: (1, 1, T) -> latent: (1, T_latent, D)"""
        z = self.model.encode(audio)  # (B, D, T)
        return z.transpose(1, 2)      # (B, T, D)

    @torch.inference_mode()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """latent: (B, T_latent, D) -> audio: (B, 1, T)"""
        return self.model.decode(latent.transpose(1, 2))


# =============================
# INIT
# =============================

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[init] Using device: {DEVICE}")

codec = SimpleDACCodec.load(device=DEVICE)
print(f"[init] Codec ready. Frame rate = {codec.frame_rate:.4f} Hz  "
      f"(hop={codec.hop_size}, sr={codec.sample_rate})")


# =============================
# AUDIO UTILS
# =============================

def load_audio(path: str) -> tuple[np.ndarray, int]:
    """Read an audio file as float32 and down-mix multi-channel audio to mono."""
    audio, sr = sf.read(path, dtype="float32")
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    return audio, sr


def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` (no-op if equal)."""
    if orig_sr == target_sr:
        return audio
    num_samples = int(len(audio) * target_sr / orig_sr)
    return scipy_resample(audio, num_samples)


def to_tensor(audio: np.ndarray) -> torch.Tensor:
    """Wrap a 1-D waveform as a float32 tensor of shape (1, 1, T)."""
    # Force float32 + contiguous memory so the tensor dtype does not depend
    # on what the resampler happened to return.
    samples = np.ascontiguousarray(audio, dtype=np.float32)
    return torch.from_numpy(samples).unsqueeze(0).unsqueeze(0)  # (1, 1, T)


def format_stats(stats: dict) -> str:
    """Render stats dict as a clean markdown table for display."""
    header = ["| Property | Value |", "|---|---|"]
    rows = [f"| {k} | `{v}` |" for k, v in stats.items()]
    return "\n".join(header + rows)


def _sync_device() -> None:
    """Wait for pending CUDA work so wall-clock timings are meaningful."""
    if DEVICE == "cuda":
        torch.cuda.synchronize()


def _timed(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_seconds)``."""
    _sync_device()
    start = time.perf_counter()
    result = fn(*args)
    _sync_device()
    return result, time.perf_counter() - start


# =============================
# ENCODE
# =============================

def encode_audio(file):
    """Encode an audio file into DACVAE latents and build the stats table.

    Args:
        file: Path to the uploaded audio file, or ``None``.

    Returns:
        ``(latent_list, latent_list, stats_md)`` on success, where
        ``latent_list`` is the (T, D) latent as nested Python lists; or
        ``(None, None, message)`` when there is nothing to encode.
    """
    if file is None:
        return None, None, "⚠️ Please upload an audio file first."

    # Load + resample
    try:
        audio_orig, sr_orig = load_audio(file)
    except (OSError, RuntimeError) as e:
        return None, None, f"⚠️ Could not read audio file: {e}"

    orig_samples = len(audio_orig)
    if orig_samples == 0:
        return None, None, "⚠️ The audio file contains no samples."
    orig_duration = orig_samples / sr_orig

    audio_resampled = resample_audio(audio_orig, sr_orig, codec.sample_rate)
    resampled_samples = len(audio_resampled)
    if resampled_samples == 0:
        return None, None, "⚠️ The audio file is too short to encode."
    resampled_duration = resampled_samples / codec.sample_rate

    wav = to_tensor(audio_resampled).to(DEVICE)

    # Encode — only the codec call is timed, so "Encode time" excludes
    # file loading and resampling.
    latent, t_enc = _timed(codec.encode, wav)  # (1, T_latent, D)

    num_frames = latent.shape[1]
    latent_dim = latent.shape[2]
    calc_dur = codec.frames_to_seconds(num_frames)

    latent_np = latent.squeeze(0).detach().cpu().numpy()  # (T, D)
    latent_list = latent_np.tolist()

    # Stats
    stats = {
        "📁 Original sample rate": f"{sr_orig} Hz",
        "🎵 Codec sample rate": f"{codec.sample_rate} Hz",
        "⏱ Original duration": f"{orig_duration:.4f} s  ({orig_samples:,} samples)",
        "⏱ Resampled duration": f"{resampled_duration:.4f} s  ({resampled_samples:,} samples)",
        "🔢 Latent frames (T)": f"{num_frames}",
        "📐 Latent dim (D)": f"{latent_dim}",
        "📏 Encoder hop size": f"{codec.hop_size} samples",
        "🔄 Latent frame rate": f"{codec.frame_rate:.4f} Hz",
        "⏳ Duration from latent": f"{calc_dur:.4f} s  (T × hop / sr = {num_frames} × {codec.hop_size} / {codec.sample_rate})",
        "✅ Duration match": f"{'✓ exact' if abs(calc_dur - resampled_duration) < 0.05 else '⚠ mismatch'}",
        "⚡ Encode time": f"{t_enc*1000:.1f} ms",
        "💾 Latent tensor size": f"{latent_np.nbytes / 1024:.1f} KB  (float32)",
        "📊 Latent value range": f"[{latent_np.min():.4f}, {latent_np.max():.4f}]",
        "📊 Latent mean / std": f"{latent_np.mean():.4f} / {latent_np.std():.4f}",
    }

    stats_md = format_stats(stats)
    return latent_list, latent_list, stats_md


# =============================
# DECODE
# =============================

DECODE_HEADING = "### Decode Stats"
NO_LATENT_WARNING = "⚠️ No latent found. Encode first."


def _without_decode_output(stats_md) -> str:
    """Return ``stats_md`` minus anything a previous decode click appended.

    The stats panel is both an input and an output of ``decode_audio``, so
    without this every click would stack another section onto the last.
    """
    text = stats_md or ""
    text = text.split(DECODE_HEADING, 1)[0]
    text = text.split(NO_LATENT_WARNING, 1)[0]
    return text.rstrip()


def decode_audio(latent_list, stats_md_current):
    """Decode stored latents back to audio and append decode statistics.

    Args:
        latent_list: Nested lists shaped (T, D) or (1, T, D), or ``None``.
        stats_md_current: Markdown currently shown in the stats panel.

    Returns:
        ``((sample_rate, waveform), markdown)`` on success, or
        ``(None, message)`` when the latent is missing or invalid.
    """
    base_md = _without_decode_output(stats_md_current)

    if latent_list is None:
        return None, base_md + "\n\n" + NO_LATENT_WARNING

    try:
        latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
    except (TypeError, ValueError, RuntimeError) as e:
        return None, f"⚠️ Invalid latent: {e}"

    if latent.ndim == 2:
        latent = latent.unsqueeze(0)  # (1, T, D)

    # Reject shapes the stats below cannot handle (empty latents would
    # otherwise divide by zero; batches > 1 would yield a 2-D waveform).
    if latent.ndim != 3 or latent.shape[0] != 1 or 0 in latent.shape:
        return None, (
            "⚠️ Invalid latent: expected shape (T, D) or (1, T, D), "
            f"got {tuple(latent.shape)}"
        )

    # Only the codec call is timed, so "Decode time" excludes list->tensor
    # conversion.
    audio, t_dec = _timed(codec.decode, latent)  # (B, 1, T_out)

    audio_np = audio[0].detach().cpu().numpy().reshape(-1)
    audio_np = np.nan_to_num(audio_np)
    audio_np = np.clip(audio_np, -1.0, 1.0)

    num_frames = latent.shape[1]
    out_samples = len(audio_np)
    actual_dur = out_samples / codec.sample_rate
    calc_dur = codec.frames_to_seconds(num_frames)
    actual_hop = out_samples // num_frames

    decode_stats = {
        "🔢 Latent frames decoded": f"{num_frames}",
        "🔊 Output samples": f"{out_samples:,}",
        "⏱ Reconstructed duration": f"{actual_dur:.4f} s",
        "⏳ Duration from latent": f"{calc_dur:.4f} s",
        "🔁 Actual output hop": f"{actual_hop} samples/frame  (expected {codec.hop_size})",
        "✅ Formula confirmation": f"T={num_frames} × hop={actual_hop} / sr={codec.sample_rate} = {num_frames * actual_hop / codec.sample_rate:.4f} s",
        "⚡ Decode time": f"{t_dec*1000:.1f} ms",
        "📊 Output value range": f"[{audio_np.min():.4f}, {audio_np.max():.4f}]",
    }

    decode_md = format_stats(decode_stats)
    combined = base_md + "\n\n" + DECODE_HEADING + "\n" + decode_md

    return (codec.sample_rate, audio_np), combined


# =============================
# UI
# =============================

css = """
body, .gradio-container {
    background: #0d0d0d !important;
    font-family: 'IBM Plex Mono', monospace !important;
    color: #e0e0e0 !important;
}
h1, h2, h3 {
    color: #00e5a0 !important;
    letter-spacing: 0.08em;
}
.gr-button {
    background: #00e5a0 !important;
    color: #000 !important;
    font-weight: 700 !important;
    border-radius: 2px !important;
    border: none !important;
    font-family: 'IBM Plex Mono', monospace !important;
    letter-spacing: 0.05em;
}
.gr-button:hover {
    background: #00ffa8 !important;
}
.gr-box, .gr-panel {
    background: #151515 !important;
    border: 1px solid #2a2a2a !important;
}
table {
    width: 100%;
    border-collapse: collapse;
    font-size: 0.82em;
}
th {
    color: #00e5a0;
    border-bottom: 1px solid #2a2a2a;
    padding: 4px 8px;
    text-align: left;
}
td {
    padding: 4px 8px;
    border-bottom: 1px solid #1a1a1a;
}
td code {
    background: #1e1e1e;
    padding: 2px 6px;
    border-radius: 2px;
    color: #a8ff78;
}
"""

with gr.Blocks(css=css, title="DACVAE Inspector") as demo:

    gr.HTML("""

◈ DACVAE CODEC INSPECTOR

Aratako/Semantic-DACVAE-Japanese-32dim  ·  sr={sr} Hz  ·  hop={hop}  ·  frame_rate={fr:.4f} Hz

""".format(sr=codec.sample_rate, hop=codec.hop_size, fr=codec.frame_rate))

    # Holds the full latent between the encode and decode clicks.
    latent_state = gr.State()

    with gr.Row():
        # ── Left column ───────────────────────────────
        with gr.Column(scale=1):
            audio_in = gr.Audio(type="filepath", label="Input Audio")

            with gr.Row():
                encode_btn = gr.Button("▶ ENCODE", variant="primary")
                decode_btn = gr.Button("◀ DECODE", variant="primary")

            audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)

        # ── Right column ──────────────────────────────
        with gr.Column(scale=1):
            stats_out = gr.Markdown(
                value="*Stats will appear here after encoding.*",
                label="Stats"
            )

            with gr.Accordion("Raw Latent JSON (first 3 frames)", open=False):
                latent_preview = gr.JSON(label="Latent preview")

    # ── Wire up ───────────────────────────────────────
    def encode_and_preview(file):
        """Encode ``file`` and return (full latent, 3-frame preview, stats)."""
        latent_list, _, stats_md = encode_audio(file)
        if latent_list is None:
            return None, None, stats_md
        preview = latent_list[:3] if latent_list else []
        return latent_list, preview, stats_md

    encode_btn.click(
        fn=encode_and_preview,
        inputs=audio_in,
        outputs=[latent_state, latent_preview, stats_out],
    )

    decode_btn.click(
        fn=decode_audio,
        inputs=[latent_state, stats_out],
        outputs=[audio_out, stats_out],
    )


# =============================
# RUN
# =============================

if __name__ == "__main__":
    demo.launch(share=True)