| import gradio as gr |
| import torch |
| import numpy as np |
| import soundfile as sf |
| from scipy.signal import resample as scipy_resample |
| from dataclasses import dataclass, field |
| from huggingface_hub import hf_hub_download |
| import time |
| import json |
|
|
| |
| |
| |
|
|
@dataclass
class SimpleDACCodec:
    """Small wrapper around a DACVAE model with a time-major latent API.

    ``encode`` returns latents shaped (batch, T_latent, D) and ``decode``
    accepts that same layout; the wrapped model itself works on
    (batch, D, T_latent), so both methods transpose axes 1 and 2.
    """

    model: torch.nn.Module  # underlying codec; `load` puts it in eval mode on `device`
    sample_rate: int  # audio sample rate the codec operates at, in Hz
    hop_size: int  # audio samples per latent frame (estimated by `load`)
    device: torch.device  # device that inputs are moved to before calling the model

    @classmethod
    def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
        """Download ``weights.pth`` from *repo_id* and build a ready codec.

        The hop size is estimated empirically: one second of silence is
        encoded and the probe length is floor-divided by the number of latent
        frames produced. A warning is printed when that division is inexact.

        Args:
            repo_id: Hugging Face Hub repository holding ``weights.pth``.
            device: torch device string or ``torch.device`` to load onto.

        Raises:
            RuntimeError: if the probe yields no latent frames, or more frames
                than probe samples (the derived hop would be zero).
        """
        from dacvae import DACVAE  # imported lazily: only needed when loading

        weights_path = hf_hub_download(repo_id=repo_id, filename="weights.pth")
        model = DACVAE.load(weights_path).eval().to(device)
        sr = int(model.sample_rate)

        # Probe with one second of silence in the model's own dtype.
        probe_len = sr
        dummy = torch.zeros(
            1, 1, probe_len, device=device, dtype=next(model.parameters()).dtype
        )
        with torch.inference_mode():
            z = model.encode(dummy)
        t_latent = z.shape[2]  # raw model latents are (batch, D, T_latent)
        hop = probe_len // t_latent if t_latent > 0 else 0
        if hop <= 0:
            raise RuntimeError(
                f"[codec] cannot derive hop size: {probe_len}-sample probe "
                f"produced {t_latent} latent frames"
            )
        if probe_len % t_latent:
            # The floor-divided hop is then only approximate, so
            # frames_to_seconds() will drift on long inputs.
            print(
                f"[codec] WARNING: probe length {probe_len} is not a multiple "
                f"of {t_latent} frames; hop={hop} is approximate",
                flush=True,
            )

        print(f"[codec] sample_rate={sr} probe_frames={t_latent} "
              f"hop={hop} frame_rate={sr/hop:.4f} Hz", flush=True)

        return cls(
            model=model,
            sample_rate=sr,
            hop_size=hop,
            device=torch.device(device),
        )

    @property
    def frame_rate(self) -> float:
        """Latent frames per second (``sample_rate / hop_size``)."""
        return self.sample_rate / self.hop_size

    def frames_to_seconds(self, num_frames: int) -> float:
        """Convert a latent frame count to audio duration in seconds."""
        return num_frames * self.hop_size / self.sample_rate

    def _prepare(self, x: torch.Tensor) -> torch.Tensor:
        """Move *x* to the codec device and cast it to the model's parameter dtype.

        Falls back to *x*'s own dtype when the model has no parameters.
        ``Tensor.to`` is a no-op when nothing needs to change.
        """
        param = next(self.model.parameters(), None)
        dtype = x.dtype if param is None else param.dtype
        return x.to(device=self.device, dtype=dtype)

    @torch.inference_mode()
    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode a waveform into a time-major latent.

        audio: (1, 1, T) -> latent: (1, T_latent, D)

        The input is first moved/cast to the model's device and dtype, so
        e.g. float64 CPU tensors are accepted.
        """
        z = self.model.encode(self._prepare(audio))
        return z.transpose(1, 2)

    @torch.inference_mode()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode a time-major latent back to audio.

        latent: (B, T_latent, D) -> audio: (B, 1, T)
        """
        return self.model.decode(self._prepare(latent).transpose(1, 2))
|
|
|
|
| |
| |
| |
|
|
# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"[init] Using device: {DEVICE}")

# The codec is loaded once at import time and shared by the Gradio handlers.
codec = SimpleDACCodec.load(device=DEVICE)
print(
    "[init] Codec ready. Frame rate = "
    f"{codec.frame_rate:.4f} Hz (hop={codec.hop_size}, sr={codec.sample_rate})"
)
|
|
|
|
| |
| |
| |
|
|
def load_audio(path: str) -> tuple[np.ndarray, int]:
    """Read *path* as float32 samples, averaging channels down to mono.

    Returns:
        ``(samples, sample_rate)``. Multi-channel data (``ndim > 1``,
        channels on axis 1) is averaged across channels; 1-D data is
        returned untouched.
    """
    data, rate = sf.read(path, dtype="float32")
    mono = data.mean(axis=1) if data.ndim > 1 else data
    return mono, rate
|
|
|
|
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample *audio* from *orig_sr* Hz to *target_sr* Hz.

    Uses scipy's FFT-based ``resample``. The output length is
    ``int(len(audio) * target_sr / orig_sr)``.

    Floating-point input keeps its dtype, and any other dtype comes back as
    float32. This keeps float32 audio float32 even if SciPy computes in
    float64 internally, which matters because the result is fed to a
    float32 model.

    Returns *audio* itself when the rates already match. Returns an empty
    array when the target length is zero, instead of calling scipy with
    ``num=0``.

    Raises:
        ValueError: if either sample rate is not positive.
    """
    if orig_sr == target_sr:
        return audio
    if orig_sr <= 0 or target_sr <= 0:
        raise ValueError(
            f"sample rates must be positive, got {orig_sr} -> {target_sr}"
        )

    out_dtype = audio.dtype if np.issubdtype(audio.dtype, np.floating) else np.float32
    num_samples = int(len(audio) * target_sr / orig_sr)
    if num_samples == 0:
        # Empty input, or a clip shorter than one target-rate sample.
        return np.zeros(0, dtype=out_dtype)
    return scipy_resample(audio, num_samples).astype(out_dtype, copy=False)
|
|
|
|
def to_tensor(audio: np.ndarray) -> torch.Tensor:
    """Wrap a waveform as a tensor with two leading unit axes (batch, channel).

    The result shares memory with *audio*; a 1-D input of length T becomes
    shape (1, 1, T).
    """
    tensor = torch.from_numpy(audio)
    # Prepend both singleton axes in a single indexing step.
    return tensor[None, None]
|
|
|
|
def format_stats(stats: dict) -> str:
    """Render *stats* as a two-column GitHub-flavoured markdown table.

    Keys go in the first column and values are shown as inline code in the
    second. Literal ``|`` characters are backslash-escaped so they cannot
    split a row into extra columns.

    >>> print(format_stats({"rate": "48000 Hz"}))
    | Property | Value |
    |---|---|
    | rate | `48000 Hz` |
    """
    def cell(text):
        # GFM tables treat `\|` as a literal pipe, even inside code spans.
        return str(text).replace("|", "\\|")

    rows = [f"| {cell(key)} | `{cell(value)}` |" for key, value in stats.items()]
    return "\n".join(["| Property | Value |", "|---|---|", *rows])
|
|
|
|
| |
| |
| |
|
|
def encode_audio(file):
    """Encode an uploaded audio file with the codec and summarise the result.

    Args:
        file: filesystem path of the uploaded audio, or None when nothing
            has been uploaded.

    Returns:
        ``(latent_list, latent_list, stats_md)``:
            - ``latent_list`` is the latent as nested Python lists shaped
              (T_latent, D). It is returned twice so it can feed two outputs.
            - ``stats_md`` is a markdown table of statistics.
        When the file is missing, unreadable, or too short to yield a latent
        frame, returns ``(None, None, warning_markdown)`` instead.
    """
    if file is None:
        return None, None, "⚠️ Please upload an audio file first."

    # soundfile raises RuntimeError (or a subclass) for unreadable or
    # unsupported files; show that in the UI instead of a traceback.
    try:
        audio_orig, sr_orig = load_audio(file)
    except RuntimeError as err:
        return None, None, f"⚠️ Could not read audio file: {err}"

    orig_samples = len(audio_orig)
    if orig_samples == 0:
        return None, None, "⚠️ The uploaded audio file contains no samples."
    orig_duration = orig_samples / sr_orig

    audio_resampled = resample_audio(audio_orig, sr_orig, codec.sample_rate)
    resampled_samples = len(audio_resampled)
    if resampled_samples == 0:
        return None, None, "⚠️ Audio is too short to encode."
    resampled_duration = resampled_samples / codec.sample_rate

    # Match the model's parameter dtype: resampling may return float64.
    model_dtype = next(codec.model.parameters()).dtype
    wav = to_tensor(audio_resampled).to(DEVICE, dtype=model_dtype)

    # Time only the codec call. CUDA kernels run asynchronously, so wait for
    # them before reading the clock.
    t0 = time.perf_counter()
    latent = codec.encode(wav)  # (1, T_latent, D)
    if latent.is_cuda:
        torch.cuda.synchronize(latent.device)
    t_enc = time.perf_counter() - t0

    num_frames = latent.shape[1]
    latent_dim = latent.shape[2]
    if num_frames == 0:
        return None, None, "⚠️ Audio is too short to produce a latent frame."
    calc_dur = codec.frames_to_seconds(num_frames)

    latent_np = latent.squeeze(0).detach().cpu().numpy()  # (T_latent, D)
    latent_list = latent_np.tolist()

    # Loose 50 ms tolerance, presumably to absorb frame rounding at the edges.
    duration_ok = abs(calc_dur - resampled_duration) < 0.05

    stats = {
        "📊 Original sample rate": f"{sr_orig} Hz",
        "🎵 Codec sample rate": f"{codec.sample_rate} Hz",
        "⏱ Original duration": f"{orig_duration:.4f} s ({orig_samples:,} samples)",
        "⏱ Resampled duration": f"{resampled_duration:.4f} s ({resampled_samples:,} samples)",
        "🔢 Latent frames (T)": f"{num_frames}",
        "📐 Latent dim (D)": f"{latent_dim}",
        "🔁 Encoder hop size": f"{codec.hop_size} samples",
        "📈 Latent frame rate": f"{codec.frame_rate:.4f} Hz",
        "⏳ Duration from latent": (
            f"{calc_dur:.4f} s (T × hop / sr = "
            f"{num_frames} × {codec.hop_size} / {codec.sample_rate})"
        ),
        "✅ Duration match": "✓ exact" if duration_ok else "✗ mismatch",
        "⚡ Encode time": f"{t_enc * 1000:.1f} ms",
        "💾 Latent tensor size": f"{latent_np.nbytes / 1024:.1f} KB ({latent_np.dtype})",
        "📏 Latent value range": f"[{latent_np.min():.4f}, {latent_np.max():.4f}]",
        "📊 Latent mean / std": f"{latent_np.mean():.4f} / {latent_np.std():.4f}",
    }

    return latent_list, latent_list, format_stats(stats)
|
|
|
|
| |
| |
| |
|
|
def _strip_previous_decode(stats_md):
    """Return *stats_md* minus anything an earlier ``decode_audio`` call appended.

    The decode button feeds the stats panel's markdown back into
    ``decode_audio`` and writes the result to the same panel. Without this,
    every extra click would stack another "Decode Stats" section or warning.
    """
    base = stats_md or ""
    for marker in ("\n\n### Decode Stats\n", "\n\n⚠️ "):
        base = base.split(marker, 1)[0]
    return base


def decode_audio(latent_list, stats_md_current):
    """Decode a latent back to audio and append decode stats to the markdown.

    Args:
        latent_list: nested lists shaped (T_latent, D) or (1, T_latent, D),
            or None if nothing has been encoded yet.
        stats_md_current: markdown currently shown in the stats panel.

    Returns:
        On success, ``((sample_rate, waveform), markdown)``. Any decode
        section left by an earlier click is replaced, not duplicated.
        When the latent is missing or malformed,
        ``(None, markdown_with_warning)``.
    """
    base_md = _strip_previous_decode(stats_md_current)
    if latent_list is None:
        return None, base_md + "\n\n⚠️ No latent found. Encode first."

    try:
        latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
    except (TypeError, ValueError, RuntimeError) as e:
        return None, base_md + f"\n\n⚠️ Invalid latent: {e}"

    if latent.ndim == 2:
        latent = latent.unsqueeze(0)
    # Only a single non-empty latent maps onto one playable waveform.
    if latent.ndim != 3 or latent.shape[0] != 1 or latent.shape[1] == 0:
        return None, base_md + (
            "\n\n⚠️ Invalid latent: expected shape (T, D) or (1, T, D) "
            f"with T > 0, got {tuple(latent.shape)}"
        )

    # Time only the codec call; sync so pending CUDA work is included.
    t0 = time.perf_counter()
    audio = codec.decode(latent)  # (1, 1, T_audio)
    if audio.is_cuda:
        torch.cuda.synchronize(audio.device)
    t_dec = time.perf_counter() - t0

    # Flatten to 1-D explicitly; squeeze() would yield a 0-d array for T == 1.
    audio_np = audio[0].reshape(-1).detach().cpu().numpy()
    audio_np = np.nan_to_num(audio_np)
    audio_np = np.clip(audio_np, -1.0, 1.0)

    num_frames = latent.shape[1]
    out_samples = len(audio_np)
    actual_dur = out_samples / codec.sample_rate
    calc_dur = codec.frames_to_seconds(num_frames)
    actual_hop = out_samples // num_frames
    value_range = (
        f"[{audio_np.min():.4f}, {audio_np.max():.4f}]" if out_samples else "n/a"
    )

    decode_stats = {
        "🔢 Latent frames decoded": f"{num_frames}",
        "🔊 Output samples": f"{out_samples:,}",
        "⏱ Reconstructed duration": f"{actual_dur:.4f} s",
        "⏳ Duration from latent": f"{calc_dur:.4f} s",
        "🔁 Actual output hop": f"{actual_hop} samples/frame (expected {codec.hop_size})",
        "✅ Formula confirmation": (
            f"T={num_frames} × hop={actual_hop} / sr={codec.sample_rate} = "
            f"{num_frames * actual_hop / codec.sample_rate:.4f} s"
        ),
        "⚡ Decode time": f"{t_dec * 1000:.1f} ms",
        "📏 Output value range": value_range,
    }

    combined = base_md + "\n\n### Decode Stats\n" + format_stats(decode_stats)
    return (codec.sample_rate, audio_np), combined
|
|
|
|
| |
| |
| |
|
|
# Dark, monospace theme passed to gr.Blocks(css=...) below. The IBM Plex Mono
# font it names is fetched by the <link> tag in the page header HTML.
# NOTE(review): `.gr-button`, `.gr-box` and `.gr-panel` look like class names
# from older Gradio releases. Confirm they still match the installed Gradio
# version's DOM; otherwise those rules silently have no effect.
css = """
body, .gradio-container {
    background: #0d0d0d !important;
    font-family: 'IBM Plex Mono', monospace !important;
    color: #e0e0e0 !important;
}
h1, h2, h3 { color: #00e5a0 !important; letter-spacing: 0.08em; }
.gr-button {
    background: #00e5a0 !important;
    color: #000 !important;
    font-weight: 700 !important;
    border-radius: 2px !important;
    border: none !important;
    font-family: 'IBM Plex Mono', monospace !important;
    letter-spacing: 0.05em;
}
.gr-button:hover { background: #00ffa8 !important; }
.gr-box, .gr-panel { background: #151515 !important; border: 1px solid #2a2a2a !important; }
table { width: 100%; border-collapse: collapse; font-size: 0.82em; }
th { color: #00e5a0; border-bottom: 1px solid #2a2a2a; padding: 4px 8px; text-align: left; }
td { padding: 4px 8px; border-bottom: 1px solid #1a1a1a; }
td code { background: #1e1e1e; padding: 2px 6px; border-radius: 2px; color: #a8ff78; }
"""
|
|
with gr.Blocks(css=css, title="DACVAE Inspector") as demo:
    # Header: title plus the codec parameters measured when the codec loaded.
    gr.HTML("""
<link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;700&display=swap" rel="stylesheet">
<div style="padding: 24px 0 8px 0;">
<h1 style="font-size:1.6em; margin:0; letter-spacing:0.12em;">
◈ DACVAE CODEC INSPECTOR
</h1>
<p style="color:#666; margin:4px 0 0 0; font-size:0.78em; letter-spacing:0.06em;">
Aratako/Semantic-DACVAE-Japanese-32dim ·
sr={sr} Hz · hop={hop} · frame_rate={fr:.4f} Hz
</p>
</div>
""".format(sr=codec.sample_rate, hop=codec.hop_size, fr=codec.frame_rate))

    # Holds the most recent latent (nested lists) between encode and decode clicks.
    latent_state = gr.State()

    with gr.Row():
        # Left column: input audio, action buttons, reconstructed output.
        with gr.Column(scale=1):
            audio_in = gr.Audio(type="filepath", label="Input Audio")
            with gr.Row():
                encode_btn = gr.Button("▶ ENCODE", variant="primary")
                decode_btn = gr.Button("◀ DECODE", variant="primary")
            audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)

        # Right column: markdown stats table and a peek at the raw latent.
        with gr.Column(scale=1):
            stats_out = gr.Markdown(
                value="*Stats will appear here after encoding.*",
                label="Stats",
            )
            with gr.Accordion("Raw Latent JSON (first 3 frames)", open=False):
                latent_preview = gr.JSON(label="Latent preview")

    def encode_and_preview(file):
        """Run ``encode_audio`` and also return its first three latent frames.

        Returns ``(latent_list, preview, stats_md)`` for the state, the JSON
        preview and the stats panel; both latent outputs are None on failure.
        """
        latent_list, _, stats_md = encode_audio(file)
        if latent_list is None:
            return None, None, stats_md
        preview = latent_list[:3] if latent_list else []
        return latent_list, preview, stats_md

    encode_btn.click(
        fn=encode_and_preview,
        inputs=audio_in,
        outputs=[latent_state, latent_preview, stats_out],
    )

    # The stats markdown is both input and output so decode output is added
    # beneath the encode stats.
    decode_btn.click(
        fn=decode_audio,
        inputs=[latent_state, stats_out],
        outputs=[audio_out, stats_out],
    )
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Start the Gradio server; share=True also asks Gradio for a public share link.
    demo.launch(
        share=True,
    )