multimodalart HF Staff commited on
Commit
b7bca64
·
verified ·
1 Parent(s): d3ab0d1

[Admin maintenance] Migrate to ZeroGPU

Browse files
Files changed (2) hide show
  1. app.py +24 -6
  2. requirements.txt +7 -9
app.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import gradio as gr
2
  import json
 
3
  import torch
4
  import wavio
5
  import numpy as np
@@ -15,7 +17,6 @@ sys.path.insert(0, "diffusers/src")
15
  from diffusers import DDPMScheduler
16
  from models import MusicAudioDiffusion
17
  from gradio import Markdown
18
- import spaces
19
 
20
 
21
  # Automatic device detection
@@ -139,16 +140,26 @@ class MusicFeaturePredictor:
139
  num_return_sequences=1,
140
  )
141
 
142
- generated_chords = self.chords_tokenizer.decode(
 
 
 
 
 
 
 
 
143
  generated_chords[0],
144
- skip_special_tokens=True,
145
  clean_up_tokenization_spaces=True,
146
- ).split(" n ")
 
 
147
 
148
  predicted_chords, predicted_chords_times = [], []
149
- for item in generated_chords:
150
  c, ct = item.split(" at ")
151
- predicted_chords.append(c)
152
  predicted_chords_times.append(float(ct))
153
 
154
  return predicted_beats, predicted_chords, predicted_chords_times
@@ -240,6 +251,13 @@ class Mustango:
240
  return wave[0]
241
 
242
 
 
 
 
 
 
 
 
243
  # Initialize Mustango
244
  mustango = Mustango(device="cpu")
245
  mustango.vae.to(device_type)
 
1
+ import spaces
2
  import gradio as gr
3
  import json
4
+ import re
5
  import torch
6
  import wavio
7
  import numpy as np
 
17
  from diffusers import DDPMScheduler
18
  from models import MusicAudioDiffusion
19
  from gradio import Markdown
 
20
 
21
 
22
  # Automatic device detection
 
140
  num_return_sequences=1,
141
  )
142
 
143
+ # Mustango's chord T5 was finetuned with '\n' between chord entries.
144
+ # SentencePiece T5 has no '\n' in its vocab, so the trainer encoded each
145
+ # separator as <unk> (id 2) and the model learned to emit <unk>. With
146
+ # skip_special_tokens=True, <unk> is silently stripped, collapsing every
147
+ # entry onto one line and breaking the ' at ' parse. Decode with specials
148
+ # preserved and split on the separator the model emits: '<unk>' or
149
+ # '<unk>n' (an extra 'n' token sometimes follows the <unk>), and on the
150
+ # legacy ' n ' literal that the torch-2.0.1 reference Space produced.
151
+ decoded = self.chords_tokenizer.decode(
152
  generated_chords[0],
153
+ skip_special_tokens=False,
154
  clean_up_tokenization_spaces=True,
155
+ )
156
+ decoded = decoded.replace("<pad>", "").replace("</s>", "").strip()
157
+ items = [p.strip() for p in re.split(r"\s*<unk>\s*n?\s+|\s+n\s+", decoded) if p.strip()]
158
 
159
  predicted_chords, predicted_chords_times = [], []
160
+ for item in items:
161
  c, ct = item.split(" at ")
162
+ predicted_chords.append(c.strip())
163
  predicted_chords_times.append(float(ct))
164
 
165
  return predicted_beats, predicted_chords, predicted_chords_times
 
251
  return wave[0]
252
 
253
 
254
+ # Disable TF32 / nondeterministic cuDNN kernels so chord/beats T5 sampling
255
+ # stays as close as possible to the torch 2.0.1 reference Space's output.
256
+ torch.backends.cuda.matmul.allow_tf32 = False
257
+ torch.backends.cudnn.allow_tf32 = False
258
+ torch.backends.cudnn.deterministic = True
259
+ torch.backends.cudnn.benchmark = False
260
+
261
  # Initialize Mustango
262
  mustango = Mustango(device="cpu")
263
  mustango.vae.to(device_type)
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- torch==2.0.1
2
- torchaudio==2.0.2
3
- torchvision==0.15.2
4
  transformers==4.31.0
5
  accelerate==0.21.0
6
  datasets==2.1.0
@@ -10,10 +10,10 @@ huggingface_hub>=0.33.5
10
  importlib_metadata==6.3.0
11
  librosa==0.9.2
12
  matplotlib==3.5.2
13
- numpy==1.23.0
14
  omegaconf==2.3.0
15
  packaging==23.1
16
- pandas==1.4.1
17
  progressbar33==2.4
18
  protobuf==3.20.*
19
  resampy==0.4.2
@@ -23,10 +23,8 @@ scikit_image==0.19.3
23
  scikit_learn==1.2.2
24
  scipy==1.8.0
25
  soundfile==0.12.1
26
- #ssr_eval==0.0.6
27
  torchlibrosa==0.1.0
28
  tqdm==4.63.1
29
- wandb==0.12.14
30
- #ipython==8.12.0
31
  wavio==0.0.7
32
- hf-xet==1.1.8
 
 
1
+ torch==2.8.0
2
+ torchaudio==2.8.0
3
+ torchvision==0.23.0
4
  transformers==4.31.0
5
  accelerate==0.21.0
6
  datasets==2.1.0
 
10
  importlib_metadata==6.3.0
11
  librosa==0.9.2
12
  matplotlib==3.5.2
13
+ numpy<2
14
  omegaconf==2.3.0
15
  packaging==23.1
16
+ pandas
17
  progressbar33==2.4
18
  protobuf==3.20.*
19
  resampy==0.4.2
 
23
  scikit_learn==1.2.2
24
  scipy==1.8.0
25
  soundfile==0.12.1
 
26
  torchlibrosa==0.1.0
27
  tqdm==4.63.1
 
 
28
  wavio==0.0.7
29
+ hf-xet==1.1.8
30
+ diffusers==0.18.2