Coiland commited on
Commit
96b2a8c
·
verified ·
1 Parent(s): 07aaa94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -81
app.py CHANGED
@@ -1,90 +1,76 @@
1
- import torch
2
  import gradio as gr
 
3
  import random
4
- import traceback
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # -------------------------------------------------
7
- # FORCE CPU (FREE TIER SAFE)
8
- # -------------------------------------------------
9
  device = "cpu"
10
  print("Running in CPU mode (Free tier safe)")
11
 
12
- # -------------------------------------------------
13
- # IMPORT YOUR ORIGINAL UTILITIES
14
- # -------------------------------------------------
15
- from diffrhythm.infer.infer_utils import prepare_model, infer_music
16
-
17
- # -------------------------------------------------
18
- # LOAD MODEL (REDUCED MEMORY MODE)
19
- # -------------------------------------------------
20
- try:
21
- cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(
22
- max_frames=2048, # Reduced from 6144
23
- device=device
 
 
 
 
 
 
 
 
24
  )
25
- print("Model loaded successfully.")
26
- except Exception as e:
27
- print("Model loading failed:")
28
- print(traceback.format_exc())
29
- raise e
30
-
31
- # -------------------------------------------------
32
- # SAFE GENERATION FUNCTION
33
- # -------------------------------------------------
34
- def generate_music(
35
- prompt,
36
- duration,
37
- guidance_scale,
38
- seed
39
- ):
40
- try:
41
- if seed == -1:
42
- seed = random.randint(0, 999999)
43
-
44
- torch.manual_seed(int(seed))
45
-
46
- audio = infer_music(
47
- cfm=cfm,
48
- tokenizer=tokenizer,
49
- muq=muq,
50
- vae=vae,
51
- eval_model=eval_model,
52
- eval_muq=eval_muq,
53
- prompt=prompt,
54
- duration=int(duration),
55
- guidance_scale=float(guidance_scale),
56
- device=device
57
- )
58
-
59
- return audio
60
-
61
- except Exception as e:
62
- return f"Error during generation:\n{traceback.format_exc()}"
63
-
64
- # -------------------------------------------------
65
- # GRADIO UI
66
- # -------------------------------------------------
67
- with gr.Blocks() as demo:
68
-
69
- gr.Markdown("# 🎵 DiffRhythm (CPU Safe Version)")
70
- gr.Markdown("Running on Free CPU mode. Generation will be slow.")
71
-
72
- with gr.Row():
73
- prompt = gr.Textbox(label="Music Prompt", placeholder="Describe the music...")
74
-
75
- with gr.Row():
76
- duration = gr.Slider(5, 30, value=10, step=1, label="Duration (seconds)")
77
- guidance = gr.Slider(1, 10, value=5, step=0.5, label="Guidance Scale")
78
- seed = gr.Number(value=-1, label="Seed (-1 = random)")
79
-
80
- generate_btn = gr.Button("Generate Music")
81
-
82
- output_audio = gr.Audio(label="Generated Audio")
83
-
84
- generate_btn.click(
85
- fn=generate_music,
86
- inputs=[prompt, duration, guidance, seed],
87
- outputs=output_audio
88
  )
89
 
90
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
  import random
4
+ import numpy as np
5
+
6
+ from diffrhythm.infer.infer_utils import (
7
+ prepare_model,
8
+ get_lrc_token,
9
+ get_audio_style_prompt,
10
+ get_text_style_prompt,
11
+ get_negative_style_prompt,
12
+ get_reference_latent
13
+ )
14
+
15
+ from diffrhythm.infer.infer import inference
16
 
17
+
18
+ # FORCE CPU FOR FREE TIER
 
19
  device = "cpu"
20
  print("Running in CPU mode (Free tier safe)")
21
 
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+
24
+ cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(
25
+ max_frames=2048,
26
+ device=device
27
+ )
28
+
29
+
30
+ def generate_music(lrc_text):
31
+ torch.manual_seed(0)
32
+
33
+ lrc_prompt, start_time, end_frame, song_duration = get_lrc_token(
34
+ 2048, lrc_text, tokenizer, 95, device
35
+ )
36
+
37
+ style_prompt = get_text_style_prompt(muq, "emotional piano")
38
+ negative_style_prompt = get_negative_style_prompt(device)
39
+
40
+ latent_prompt, pred_frames = get_reference_latent(
41
+ device, 2048, False, None, None, vae
42
  )
43
+
44
+ song = inference(
45
+ cfm_model=cfm,
46
+ vae_model=vae,
47
+ eval_model=eval_model,
48
+ eval_muq=eval_muq,
49
+ cond=latent_prompt,
50
+ text=lrc_prompt,
51
+ duration=end_frame,
52
+ style_prompt=style_prompt,
53
+ negative_style_prompt=negative_style_prompt,
54
+ steps=10,
55
+ cfg_strength=3.0,
56
+ sway_sampling_coef=None,
57
+ start_time=start_time,
58
+ file_type="mp3",
59
+ vocal_flag=False,
60
+ odeint_method="euler",
61
+ pred_frames=pred_frames,
62
+ batch_infer_num=1,
63
+ song_duration=song_duration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  )
65
 
66
+ return song
67
+
68
+
69
+ demo = gr.Interface(
70
+ fn=generate_music,
71
+ inputs=gr.Textbox(lines=10, label="LRC Lyrics"),
72
+ outputs=gr.Audio(type="filepath")
73
+ )
74
+
75
+ if __name__ == "__main__":
76
+ demo.launch()