Humair332 committed on
Commit
21231d9
·
verified ·
1 Parent(s): 947815d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -24
app.py CHANGED
@@ -6,6 +6,7 @@ from scipy.signal import resample
6
  from dataclasses import dataclass
7
  from huggingface_hub import hf_hub_download
8
 
 
9
  # =============================
10
  # SIMPLE DACVAE WRAPPER
11
  # =============================
@@ -30,11 +31,13 @@ class SimpleDACCodec:
30
 
31
  @torch.inference_mode()
32
  def encode(self, audio):
33
- z = self.model.encode(audio) # (B, D, T)
34
- return z.transpose(1, 2) # (B, T, D)
 
35
 
36
  @torch.inference_mode()
37
  def decode(self, latent):
 
38
  z = latent.transpose(1, 2)
39
  return self.model.decode(z)
40
 
@@ -52,6 +55,7 @@ codec = SimpleDACCodec.load(device=DEVICE)
52
  def load_audio(path):
53
  audio, sr = sf.read(path, dtype="float32")
54
 
 
55
  if audio.ndim > 1:
56
  audio = np.mean(audio, axis=1)
57
 
@@ -75,63 +79,72 @@ def to_tensor(audio):
75
  # =============================
76
  def encode_audio(file):
77
  if file is None:
78
- raise gr.Error("Please upload an audio file first.")
79
 
80
  audio, sr = load_audio(file)
81
  audio = resample_audio(audio, sr, codec.sample_rate)
82
-
83
  wav = to_tensor(audio).to(DEVICE)
84
- latent = codec.encode(wav)
85
 
86
- # keep as numpy (NOT list → avoids huge lag)
87
- return latent.cpu().numpy()
 
 
88
 
89
 
90
  # =============================
91
  # DECODE
92
  # =============================
93
- def decode_audio(latent):
94
- if latent is None:
95
- raise gr.Error("No latent available. Click Encode first.")
96
 
97
- latent = torch.tensor(latent, dtype=torch.float32).to(DEVICE)
 
 
 
 
98
 
99
  if latent.ndim == 2:
100
  latent = latent.unsqueeze(0)
101
 
102
  audio = codec.decode(latent)
103
- audio = audio.squeeze().cpu().numpy()
 
 
 
 
104
 
105
  return (codec.sample_rate, audio)
106
 
107
 
108
  # =============================
109
- # UI (SINGLE WINDOW)
110
  # =============================
111
  with gr.Blocks() as demo:
112
- gr.Markdown("## 🎧 Simple DAC Audio Codec")
113
-
114
- audio_in = gr.Audio(type="filepath", label="Upload Audio")
115
 
116
- encode_btn = gr.Button("Encode")
117
- decode_btn = gr.Button("Decode")
118
 
119
- latent_state = gr.State() # 🔥 hidden storage (best practice)
 
 
 
 
120
 
121
- audio_out = gr.Audio(label="Reconstructed Audio")
 
 
122
 
123
- # Encode → store in state
124
  encode_btn.click(
125
  fn=encode_audio,
126
  inputs=audio_in,
127
- outputs=latent_state
128
  )
129
 
130
- # Decode from state
131
  decode_btn.click(
132
  fn=decode_audio,
133
  inputs=latent_state,
134
- outputs=audio_out
135
  )
136
 
137
 
 
6
  from dataclasses import dataclass
7
  from huggingface_hub import hf_hub_download
8
 
9
+
10
  # =============================
11
  # SIMPLE DACVAE WRAPPER
12
  # =============================
 
31
 
32
  @torch.inference_mode()
33
  def encode(self, audio):
34
+ # audio: (1, 1, T)
35
+ z = self.model.encode(audio) # (B, D, T)
36
+ return z.transpose(1, 2) # (B, T, D)
37
 
38
  @torch.inference_mode()
39
  def decode(self, latent):
40
+ # latent: (B, T, D)
41
  z = latent.transpose(1, 2)
42
  return self.model.decode(z)
43
 
 
55
  def load_audio(path):
56
  audio, sr = sf.read(path, dtype="float32")
57
 
58
+ # mono
59
  if audio.ndim > 1:
60
  audio = np.mean(audio, axis=1)
61
 
 
79
  # =============================
80
  def encode_audio(file):
81
  if file is None:
82
+ raise ValueError("Please upload an audio file first.")
83
 
84
  audio, sr = load_audio(file)
85
  audio = resample_audio(audio, sr, codec.sample_rate)
 
86
  wav = to_tensor(audio).to(DEVICE)
 
87
 
88
+ latent = codec.encode(wav) # (B, T, D)
89
+
90
+ latent_list = latent.detach().cpu().numpy().tolist()
91
+ return latent_list, latent_list # one for display, one for hidden state
92
 
93
 
94
  # =============================
95
  # DECODE
96
  # =============================
97
+ def decode_audio(latent_list):
98
+ if latent_list is None:
99
+ raise ValueError("No latent found. Click Encode first.")
100
 
101
+ # Convert nested list to tensor safely
102
+ try:
103
+ latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
104
+ except Exception as e:
105
+ raise ValueError(f"Invalid latent data: {e}")
106
 
107
  if latent.ndim == 2:
108
  latent = latent.unsqueeze(0)
109
 
110
  audio = codec.decode(latent)
111
+ audio = audio.squeeze().detach().cpu().numpy()
112
+
113
+ # clip just in case
114
+ audio = np.nan_to_num(audio)
115
+ audio = np.clip(audio, -1.0, 1.0)
116
 
117
  return (codec.sample_rate, audio)
118
 
119
 
120
  # =============================
121
+ # UI
122
  # =============================
123
  with gr.Blocks() as demo:
124
+ gr.Markdown("## 🎧 Simple DAC Audio Codec (Single Window)")
 
 
125
 
126
+ latent_state = gr.State()
 
127
 
128
+ with gr.Row():
129
+ with gr.Column(scale=1):
130
+ audio_in = gr.Audio(type="filepath", label="Upload Audio")
131
+ encode_btn = gr.Button("Encode")
132
+ decode_btn = gr.Button("Decode")
133
 
134
+ with gr.Column(scale=1):
135
+ latent_out = gr.JSON(label="Latent")
136
+ audio_out = gr.Audio(label="Reconstructed Audio")
137
 
 
138
  encode_btn.click(
139
  fn=encode_audio,
140
  inputs=audio_in,
141
+ outputs=[latent_out, latent_state],
142
  )
143
 
 
144
  decode_btn.click(
145
  fn=decode_audio,
146
  inputs=latent_state,
147
+ outputs=audio_out,
148
  )
149
 
150