# ACE-Step 1.5 CPU — HuggingFace Spaces application entry point (app.py)
| """ | |
| ACE-Step 1.5 Music Generation + LoRA Training (CPU) | |
| Runs on HuggingFace Spaces free CPU tier. | |
| """ | |
| import os | |
| import sys | |
| import gc | |
| import time | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| # Force CPU, no CUDA | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "" | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| os.environ["TORCHAUDIO_USE_BACKEND"] = "ffmpeg" | |
| os.environ["ACESTEP_DISABLE_TQDM"] = "1" | |
| import torch | |
| torch.set_default_dtype(torch.float32) | |
| import numpy as np | |
| import gradio as gr | |
| import soundfile as sf | |
# ---------------------------------------------------------------------------
# Clone ACE-Step repo if not present
# ---------------------------------------------------------------------------
REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ace-step-source")
if not os.path.isdir(REPO_DIR):
    print("[Setup] Cloning ACE-Step 1.5 repository...")
    # subprocess.run with an argument list instead of os.system: no shell
    # string interpolation, and failures stay best-effort (check=False
    # preserves the original os.system behavior of not raising).
    import subprocess
    subprocess.run(
        ["git", "clone", "--depth", "1",
         "https://github.com/ace-step/ACE-Step-1.5", REPO_DIR],
        check=False,
    )
# Make the cloned package (acestep.*) importable.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)
# ---------------------------------------------------------------------------
# Lazy-load handler (downloads model on first use)
# ---------------------------------------------------------------------------
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Cached AceStepHandler instance and the status string from its last init.
_dit_handler = None
_init_status = None

# Model weights are downloaded here; trained LoRA adapters are written here.
CHECKPOINT_DIR = os.path.join(_APP_DIR, "checkpoints")
LORA_OUTPUT_DIR = os.path.join(_APP_DIR, "lora_output")

CURRENT_LM_SIZE = "1.7B"  # Track current LM size
def get_handler():
    """Get or initialize the ACE-Step handler (lazy, first call downloads model).

    Returns:
        tuple: (handler, status) — the initialized AceStepHandler and the
        last init status message, or (None, status) when the model download
        or service initialization failed. `status` is None until the first
        initialization attempt.
    """
    global _dit_handler, _init_status
    # Reuse the cached handler only if its model actually loaded.
    if _dit_handler is not None and _dit_handler.model is not None:
        return _dit_handler, _init_status
    # Imported lazily: the cloned repo is only on sys.path after module setup.
    from acestep.handler import AceStepHandler
    from acestep.model_downloader import ensure_main_model
    print("[Init] Ensuring model is downloaded...")
    success, msg = ensure_main_model(
        checkpoints_dir=Path(CHECKPOINT_DIR),
        prefer_source="huggingface",
    )
    print(f"[Init] Model download: {msg}")
    if not success:
        _init_status = f"Model download failed: {msg}"
        return None, _init_status
    _dit_handler = AceStepHandler()
    project_root = os.path.dirname(os.path.abspath(__file__))
    os.environ["ACESTEP_PROJECT_ROOT"] = project_root
    # CPU-only initialization: no flash attention, no torch.compile,
    # no offloading, no quantization, no MLX backend.
    status, ok = _dit_handler.initialize_service(
        project_root=project_root,
        config_path="acestep-v15-turbo",
        device="cpu",
        use_flash_attention=False,
        compile_model=False,
        offload_to_cpu=False,
        offload_dit_to_cpu=False,
        quantization=None,
        use_mlx_dit=False,
    )
    _init_status = status
    if not ok:
        print(f"[Init] FAILED: {status}")
        # Drop the half-initialized handler so the next call retries cleanly.
        _dit_handler = None
        return None, _init_status
    # Force float32 on everything — guards against any submodule that was
    # loaded in half precision, which is slow/unsupported on CPU.
    _dit_handler.dtype = torch.float32
    if _dit_handler.model is not None:
        _dit_handler.model = _dit_handler.model.float().to("cpu")
    if _dit_handler.vae is not None:
        _dit_handler.vae = _dit_handler.vae.float().to("cpu")
    if _dit_handler.text_encoder is not None:
        _dit_handler.text_encoder = _dit_handler.text_encoder.float().to("cpu")
    print(f"[Init] OK: {status}")
    return _dit_handler, _init_status
def get_trained_loras(base_dir=None):
    """List available trained LoRAs.

    Scans *base_dir* for subdirectories that contain at least one weight
    file (``.safetensors``, ``.pt`` or ``.bin``).

    Args:
        base_dir: Directory to scan; defaults to the module-level
            LORA_OUTPUT_DIR (backward compatible — existing callers pass
            no arguments).

    Returns:
        list[str]: dropdown choices, always starting with "None (no LoRA)".
    """
    if base_dir is None:
        base_dir = LORA_OUTPUT_DIR
    loras = ["None (no LoRA)"]
    if os.path.isdir(base_dir):
        for name in sorted(os.listdir(base_dir)):
            lora_dir = os.path.join(base_dir, name)
            if not os.path.isdir(lora_dir):
                continue
            # A directory counts as a LoRA if it holds any weight file.
            if any(
                f.endswith((".safetensors", ".pt", ".bin"))
                for f in os.listdir(lora_dir)
            ):
                loras.append(name)
    return loras
| # --------------------------------------------------------------------------- | |
| # Generate Tab | |
| # --------------------------------------------------------------------------- | |
def generate_music(
    caption,
    lyrics,
    instrumental,
    bpm,
    duration,
    seed,
    inference_steps,
    lm_size,
    lora_choice,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate music from a text prompt on CPU.

    Args:
        caption: Free-text music description (falls back to a default prompt).
        lyrics: Lyrics text; "[Instrumental]" yields no vocals.
        instrumental: Checkbox value from the UI; currently unused here —
            vocals are governed by the lyrics text. Kept for UI wiring.
        bpm: Beats per minute; 0 or None means auto.
        duration: Requested duration in seconds (clamped to 10-120 for CPU).
        seed: Random seed; None or negative means random.
        inference_steps: Diffusion steps (clamped to 1-32).
        lm_size: Requested LM size label (switching not implemented yet).
        lora_choice: Name of a trained LoRA to apply, or "None (no LoRA)".
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        tuple: (path to generated WAV or None, status message string).
    """
    t0 = time.time()
    handler, status = get_handler()
    if handler is None:
        return None, f"Model not ready: {status}"
    # Apply trained LoRA if selected (best-effort: generation continues
    # with the base model when loading fails).
    if lora_choice and lora_choice != "None (no LoRA)":
        lora_dir = os.path.join(LORA_OUTPUT_DIR, lora_choice)
        if os.path.isdir(lora_dir):
            try:
                handler.load_lora(lora_dir)
                print(f"[Gen] Loaded LoRA: {lora_choice}")
            except Exception as e:
                print(f"[Gen] LoRA load failed: {e}")
    # TODO: LM size switching requires re-downloading the LM model.
    # For now, log the selected size.
    if lm_size != CURRENT_LM_SIZE:
        print(f"[Gen] LM size {lm_size} requested (current: {CURRENT_LM_SIZE})")
    # Clamp values to CPU-friendly ranges.
    duration = max(10, min(float(duration), 120))  # cap at 120s for CPU
    inference_steps = max(1, min(int(inference_steps), 32))
    bpm_val = int(bpm) if bpm is not None and int(bpm) > 0 else None
    # BUGFIX: the previous truthiness test (`if seed and ...`) was falsy for
    # seed == 0, silently turning a valid deterministic seed into a random one.
    seed_val = int(seed) if seed is not None and int(seed) >= 0 else -1
    try:
        result = handler.generate_music(
            captions=caption or "upbeat electronic dance music",
            lyrics=lyrics or "[Instrumental]",
            bpm=bpm_val,
            audio_duration=duration,
            inference_steps=inference_steps,
            guidance_scale=1.0,  # turbo model, no CFG needed
            use_random_seed=(seed_val < 0),
            seed=str(seed_val) if seed_val >= 0 else "",
            batch_size=1,
            task_type="text2music",
            vocal_language="en",
            shift=1.0,
            infer_method="ode",
            progress=None,
        )
        elapsed = time.time() - t0
        if not result.get("success", False):
            error = result.get("error", result.get("status_message", "Unknown error"))
            return None, f"Generation failed: {error}"
        audios = result.get("audios", [])
        if not audios:
            return None, "No audio generated"
        audio_tensor = audios[0].get("tensor")
        sample_rate = audios[0].get("sample_rate", 48000)
        if audio_tensor is None:
            return None, "Audio tensor is None"
        # Convert to numpy float32 for soundfile.
        if isinstance(audio_tensor, torch.Tensor):
            audio_np = audio_tensor.cpu().float().numpy()
        else:
            audio_np = np.array(audio_tensor, dtype=np.float32)
        # BUGFIX: mkstemp + close instead of NamedTemporaryFile(delete=False),
        # whose file descriptor was never closed (sf.write reopens by path).
        fd, wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        # soundfile expects (samples, channels)
        if audio_np.ndim == 2:
            audio_np = audio_np.T  # (channels, samples) -> (samples, channels)
        sf.write(wav_path, audio_np, sample_rate)
        status_msg = (
            f"Generated in {elapsed:.1f}s | "
            f"Duration: {duration}s | Steps: {inference_steps} | "
            f"Seed: {seed_val}"
        )
        return wav_path, status_msg
    except Exception as e:
        import traceback
        return None, f"Error: {e}\n{traceback.format_exc()}"
    finally:
        # Free intermediate tensors promptly on the memory-constrained CPU tier.
        gc.collect()
| # --------------------------------------------------------------------------- | |
| # Train LoRA Tab | |
| # --------------------------------------------------------------------------- | |
def train_lora(
    audio_files,
    lora_name,
    epochs,
    learning_rate,
    lora_rank,
    progress=gr.Progress(track_tqdm=True),
):
    """Train a LoRA adapter from uploaded audio files on CPU.

    Pipeline: copy uploads -> preprocess audio into training tensors ->
    fine-tune a LoRA adapter with the project trainer.

    Args:
        audio_files: Uploaded files (Gradio File objects or plain paths).
        lora_name: User-chosen adapter name; sanitized into a safe dir name.
        epochs: Training epochs (clamped to 1-10 for CPU).
        learning_rate: Optimizer learning rate.
        lora_rank: LoRA rank (clamped to 1-64).
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        str: Multi-line training log (including errors, if any).
    """
    if not audio_files:
        return "No audio files uploaded."
    handler, status = get_handler()
    if handler is None:
        return f"Model not ready: {status}"
    # SECURITY FIX: the name becomes a filesystem path component; the old
    # code used it verbatim, so "../../x" could escape lora_output/.
    # Keep only alphanumerics, dash, underscore and space.
    cleaned = "".join(
        c for c in (lora_name or "").strip() if c.isalnum() or c in "-_ "
    ).strip()
    lora_name = cleaned or "my_lora"
    epochs = max(1, min(int(epochs), 10))
    lr = float(learning_rate)
    rank = max(1, min(int(lora_rank), 64))
    output_dir = os.path.join(LORA_OUTPUT_DIR, lora_name)
    os.makedirs(output_dir, exist_ok=True)
    # Stage uploaded files into a dedicated input directory.
    audio_dir = os.path.join(output_dir, "audio_input")
    os.makedirs(audio_dir, exist_ok=True)
    for f in audio_files:
        src = f.name if hasattr(f, "name") else str(f)
        dst = os.path.join(audio_dir, os.path.basename(src))
        shutil.copy2(src, dst)
    log_lines = [
        f"LoRA Training: '{lora_name}'",
        f"Audio files: {len(audio_files)}",
        f"Epochs: {epochs}, LR: {lr}, Rank: {rank}",
        f"Output: {output_dir}",
        "",
    ]
    try:
        # Step 1: encode audio files into training tensors.
        log_lines.append("[Step 1/2] Preprocessing audio files...")
        tensor_dir = os.path.join(output_dir, "preprocessed_tensors")
        os.makedirs(tensor_dir, exist_ok=True)
        from acestep.training_v2.preprocess import preprocess_audio_files
        preprocess_result = preprocess_audio_files(
            audio_dir=audio_dir,
            output_dir=tensor_dir,
            checkpoint_dir=CHECKPOINT_DIR,
            variant="turbo",
            max_duration=60.0,
            device="cpu",
            precision="float32",
        )
        processed = preprocess_result.get("processed", 0)
        total = preprocess_result.get("total", 0)
        failed = preprocess_result.get("failed", 0)
        log_lines.append(f" Preprocessed: {processed}/{total} (failed: {failed})")
        if processed == 0:
            log_lines.append("ERROR: No files were preprocessed successfully.")
            return "\n".join(log_lines)
        # Step 2: run the LoRA trainer.
        log_lines.append("[Step 2/2] Training LoRA adapter...")
        from acestep.training_v2.model_loader import load_decoder_for_training
        from acestep.training_v2.trainer_fixed import FixedLoRATrainer
        from acestep.training_v2.fixed_lora_module import AdapterConfig
        from acestep.training_v2.configs import TrainingConfigV2
        model = load_decoder_for_training(
            checkpoint_dir=CHECKPOINT_DIR,
            variant="turbo",
            device="cpu",
            precision="float32",
        )
        adapter_cfg = AdapterConfig(
            rank=rank,
            alpha=rank,  # alpha == rank -> LoRA scale factor of 1.0
            dropout=0.0,
            adapter_type="lora",
        )
        train_cfg = TrainingConfigV2(
            checkpoint_dir=CHECKPOINT_DIR,
            model_variant="turbo",
            dataset_dir=tensor_dir,
            output_dir=output_dir,
            max_epochs=epochs,
            batch_size=1,
            learning_rate=lr,
            device="cpu",
            precision="float32",
            seed=42,
            num_workers=0,
            pin_memory=False,
        )
        trainer = FixedLoRATrainer(model, adapter_cfg, train_cfg)
        step_count = 0
        last_loss = 0.0
        for update in trainer.train():
            # The trainer may yield progress objects (.step/.loss) or
            # (step, loss) tuples; normalize both, log every 5th step.
            if hasattr(update, "step"):
                step_count = update.step
                last_loss = update.loss
            elif isinstance(update, tuple) and len(update) >= 2:
                step_count = update[0]
                last_loss = update[1]
            else:
                continue
            if step_count % 5 == 0:
                log_lines.append(f" Step {step_count}: loss={last_loss:.4f}")
        log_lines.append(f"Training complete! Final step: {step_count}, loss: {last_loss:.4f}")
        log_lines.append(f"LoRA saved to: {output_dir}")
        # Release model memory before returning to the UI.
        del model, trainer
        gc.collect()
    except Exception as e:
        import traceback
        log_lines.append(f"ERROR: {e}")
        log_lines.append(traceback.format_exc())
    return "\n".join(log_lines)
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
def build_ui():
    """Build and return the Gradio Blocks app (Generate + Train LoRA tabs)."""
    theme = gr.themes.Default()
    try:
        # Optional community theme; silently fall back to the default
        # when the hub is unreachable (e.g. offline builds).
        theme = gr.Theme.from_hub("NoCrypt/miku")
    except Exception:
        pass
    with gr.Blocks(
        theme=theme,
        title="ACE-Step 1.5 CPU",
        css="""
        .main-title { text-align: center; margin-bottom: 0.5em; }
        .status-box { font-family: monospace; font-size: 0.85em; }
        """,
    ) as demo:
        gr.HTML("<h1 class='main-title'>ACE-Step 1.5 Music Generation (CPU)</h1>")
        gr.HTML(
            "<p style='text-align:center;'>Text-to-music generation and LoRA training, "
            "running entirely on CPU. Based on "
            "<a href='https://github.com/ace-step/ACE-Step-1.5'>ACE-Step 1.5</a>.</p>"
        )
        with gr.Tabs():
            # ---- Generate Tab ----
            with gr.Tab("Generate Music"):
                with gr.Row():
                    with gr.Column(scale=2):
                        caption_input = gr.Textbox(
                            label="Music Description",
                            placeholder="e.g. upbeat electronic dance music, 120 BPM",
                            lines=3,
                            value="upbeat electronic dance music, energetic synth leads, driving bassline",
                        )
                        lyrics_input = gr.Textbox(
                            label="Lyrics (use [Instrumental] for no vocals)",
                            placeholder="[Instrumental]",
                            lines=3,
                            value="[Instrumental]",
                        )
                        # NOTE(review): this checkbox is passed to generate_music
                        # but never read there — vocals are governed by the
                        # lyrics text only. Confirm whether it should be wired up.
                        instrumental_cb = gr.Checkbox(
                            label="Instrumental (no vocals)",
                            value=True,
                        )
                    with gr.Column(scale=1):
                        bpm_input = gr.Number(
                            label="BPM (0 = auto)",
                            value=120,
                            minimum=0,
                            maximum=300,
                        )
                        duration_input = gr.Slider(
                            label="Duration (seconds)",
                            minimum=10,
                            maximum=120,
                            value=10,
                            step=5,
                        )
                        seed_input = gr.Number(
                            label="Seed (-1 = random)",
                            value=-1,
                        )
                        steps_input = gr.Slider(
                            label="Inference Steps (fewer = faster)",
                            minimum=1,
                            maximum=32,
                            value=8,
                            step=1,
                        )
                        # NOTE(review): these labels ("1.7B (balanced)") never
                        # equal CURRENT_LM_SIZE ("1.7B"), so generate_music
                        # always logs a size mismatch — harmless but noisy.
                        lm_size_input = gr.Dropdown(
                            label="LM Model Size",
                            choices=["0.6B (fast)", "1.7B (balanced)", "4B (best quality)"],
                            value="1.7B (balanced)",
                            info="Language model for music understanding",
                        )
                        lora_select = gr.Dropdown(
                            label="Use Trained LoRA",
                            choices=get_trained_loras(),
                            value="None (no LoRA)",
                            info="Select a LoRA you trained to apply it",
                        )
                generate_btn = gr.Button("Generate Music", variant="primary")
                with gr.Row():
                    audio_output = gr.Audio(
                        label="Generated Audio",
                        type="filepath",
                    )
                    gen_status = gr.Textbox(
                        label="Status",
                        interactive=False,
                        elem_classes="status-box",
                    )
                generate_btn.click(
                    fn=generate_music,
                    inputs=[
                        caption_input,
                        lyrics_input,
                        instrumental_cb,
                        bpm_input,
                        duration_input,
                        seed_input,
                        steps_input,
                        lm_size_input,
                        lora_select,
                    ],
                    outputs=[audio_output, gen_status],
                )
            # ---- Train LoRA Tab ----
            with gr.Tab("Train LoRA"):
                gr.Markdown(
                    "### Train a LoRA adapter on your audio files\n"
                    "Upload WAV/MP3/FLAC files to fine-tune the model. "
                    "Training runs on CPU so keep epochs low and files short."
                )
                with gr.Row():
                    with gr.Column():
                        audio_upload = gr.File(
                            label="Upload Audio Files",
                            file_count="multiple",
                            file_types=["audio"],
                        )
                        lora_name_input = gr.Textbox(
                            label="LoRA Name",
                            value="my_lora",
                        )
                    with gr.Column():
                        epochs_input = gr.Slider(
                            label="Epochs",
                            minimum=1,
                            maximum=10,
                            value=1,
                            step=1,
                        )
                        lr_input = gr.Number(
                            label="Learning Rate",
                            value=1e-4,
                        )
                        rank_input = gr.Slider(
                            label="LoRA Rank",
                            minimum=1,
                            maximum=64,
                            value=8,
                            step=1,
                        )
                train_btn = gr.Button("Start Training", variant="primary")
                train_log = gr.Textbox(
                    label="Training Log",
                    interactive=False,
                    lines=15,
                    elem_classes="status-box",
                )
                def train_and_refresh(*args):
                    # Run training, then refresh the LoRA dropdown on the
                    # Generate tab so a newly trained adapter is selectable
                    # (and pre-selected) without restarting the app.
                    log = train_lora(*args)
                    new_loras = get_trained_loras()
                    return log, gr.update(choices=new_loras, value=new_loras[-1] if len(new_loras) > 1 else "None (no LoRA)")
                train_btn.click(
                    fn=train_and_refresh,
                    inputs=[
                        audio_upload,
                        lora_name_input,
                        epochs_input,
                        lr_input,
                        rank_input,
                    ],
                    outputs=[train_log, lora_select],
                )
    return demo
| if __name__ == "__main__": | |
| demo = build_ui() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| show_error=True, | |
| ssr_mode=False, | |
| ) | |