# ACE-Step-CPU / app.py
# Hugging Face Space file uploaded by Nekochu
# Commit message: "Upload folder using huggingface_hub"
# Commit c5c37c1 (verified)
"""
ACE-Step 1.5 Music Generation + LoRA Training (CPU)
Runs on HuggingFace Spaces free CPU tier.
"""
import os
import sys
import gc
import time
import tempfile
import shutil
from pathlib import Path
# Force CPU, no CUDA
# These variables are written before `import torch` below, so they are already
# present in the environment when torch is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # empty device list: no GPU is exposed to CUDA
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # NOTE(review): presumably read by Hugging Face tokenizers -- confirm
os.environ["TORCHAUDIO_USE_BACKEND"] = "ffmpeg"  # NOTE(review): presumably selects torchaudio's ffmpeg backend -- confirm
os.environ["ACESTEP_DISABLE_TQDM"] = "1"  # NOTE(review): presumably read by the ACE-Step package to hide tqdm bars -- confirm
import torch
# Use float32 as the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float32)
import numpy as np
import gradio as gr
import soundfile as sf
# ---------------------------------------------------------------------------
# Clone ACE-Step repo if not present
# ---------------------------------------------------------------------------
# Local checkout of the ACE-Step sources, stored next to this script.
REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ace-step-source")
if not os.path.isdir(REPO_DIR):
    # Local import: subprocess is only needed for this one-off bootstrap step.
    import subprocess

    print("[Setup] Cloning ACE-Step 1.5 repository...")
    try:
        # Argument-list form (no shell), so a REPO_DIR containing spaces or
        # shell metacharacters is handed to git verbatim.
        _clone_rc = subprocess.run(
            [
                "git",
                "clone",
                "--depth",
                "1",
                "https://github.com/ace-step/ACE-Step-1.5",
                REPO_DIR,
            ],
            check=False,
        ).returncode
    except OSError as exc:  # e.g. the git executable is not installed
        print(f"[Setup] Could not run git: {exc}")
        _clone_rc = -1
    if _clone_rc != 0:
        print(f"[Setup] git clone failed (exit code {_clone_rc}).")
        # Drop a partial checkout; otherwise the isdir() test above would
        # skip the clone on every later start.
        shutil.rmtree(REPO_DIR, ignore_errors=True)
# Add repo to path
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)
# ---------------------------------------------------------------------------
# Lazy-load handler (downloads model on first use)
# ---------------------------------------------------------------------------
# Directory that contains this script; the data folders below live beside it.
_APP_ROOT = os.path.dirname(os.path.abspath(__file__))

# Process-wide cache filled in by get_handler() on first use.
_dit_handler = None  # initialized handler instance, or None
_init_status = None  # status message of the latest initialization attempt

CHECKPOINT_DIR = os.path.join(_APP_ROOT, "checkpoints")
LORA_OUTPUT_DIR = os.path.join(_APP_ROOT, "lora_output")
CURRENT_LM_SIZE = "1.7B"  # Track current LM size
def get_handler():
    """Return the shared ACE-Step handler, creating it on the first call.

    A cached handler that already has a model attached is reused as is.
    Otherwise the main model is fetched through ``ensure_main_model``, a new
    ``AceStepHandler`` is initialized for CPU, and its model, VAE and text
    encoder are cast to float32.

    Returns:
        A ``(handler, status)`` tuple. ``handler`` is ``None`` when the
        download or the initialization failed; ``status`` is the message
        describing the outcome.
    """
    global _dit_handler, _init_status

    already_loaded = _dit_handler is not None and _dit_handler.model is not None
    if already_loaded:
        return _dit_handler, _init_status

    # Deferred imports: only needed when the handler actually has to be built.
    from acestep.handler import AceStepHandler
    from acestep.model_downloader import ensure_main_model

    print("[Init] Ensuring model is downloaded...")
    downloaded, download_msg = ensure_main_model(
        checkpoints_dir=Path(CHECKPOINT_DIR),
        prefer_source="huggingface",
    )
    print(f"[Init] Model download: {download_msg}")
    if not downloaded:
        _init_status = f"Model download failed: {download_msg}"
        return None, _init_status

    _dit_handler = AceStepHandler()
    app_root = os.path.dirname(os.path.abspath(__file__))
    os.environ["ACESTEP_PROJECT_ROOT"] = app_root

    init_options = {
        "project_root": app_root,
        "config_path": "acestep-v15-turbo",
        "device": "cpu",
        "use_flash_attention": False,
        "compile_model": False,
        "offload_to_cpu": False,
        "offload_dit_to_cpu": False,
        "quantization": None,
        "use_mlx_dit": False,
    }
    status, ok = _dit_handler.initialize_service(**init_options)
    _init_status = status
    if not ok:
        print(f"[Init] FAILED: {status}")
        _dit_handler = None
        return None, _init_status

    # Force float32 on everything
    _dit_handler.dtype = torch.float32
    for component in ("model", "vae", "text_encoder"):
        module = getattr(_dit_handler, component)
        if module is not None:
            setattr(_dit_handler, component, module.float().to("cpu"))

    print(f"[Init] OK: {status}")
    return _dit_handler, _init_status
def get_trained_loras():
    """Return the choices for the LoRA dropdown.

    The list always starts with "None (no LoRA)". It is followed by the sorted
    names of the sub-directories of LORA_OUTPUT_DIR that directly contain at
    least one .safetensors, .pt or .bin file.
    """
    weight_suffixes = (".safetensors", ".pt", ".bin")
    choices = ["None (no LoRA)"]
    if not os.path.isdir(LORA_OUTPUT_DIR):
        return choices
    for entry in sorted(os.listdir(LORA_OUTPUT_DIR)):
        candidate = os.path.join(LORA_OUTPUT_DIR, entry)
        if not os.path.isdir(candidate):
            continue
        # One matching weight file is enough to list the directory.
        if any(fname.endswith(weight_suffixes) for fname in os.listdir(candidate)):
            choices.append(entry)
    return choices
# ---------------------------------------------------------------------------
# Generate Tab
# ---------------------------------------------------------------------------
def _coerce_int(value, default):
    """Return ``int(value)``, or ``default`` when value is missing or not numeric."""
    if value is None or value == "":
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def generate_music(
    caption,
    lyrics,
    instrumental,
    bpm,
    duration,
    seed,
    inference_steps,
    lm_size,
    lora_choice,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate music from text prompt on CPU.

    Args:
        caption: Text description of the music; a built-in default is used
            when it is empty.
        lyrics: Lyrics text. Ignored when ``instrumental`` is set.
        instrumental: When true, "[Instrumental]" is sent instead of the lyrics.
        bpm: Tempo; values that are missing or not positive mean "auto".
        duration: Requested length in seconds, clamped to 10..120.
        seed: Non-negative value for a fixed seed (0 included); a negative or
            missing value requests a random seed.
        inference_steps: Number of steps, clamped to 1..32.
        lm_size: Dropdown label such as "1.7B (balanced)". Only logged.
        lora_choice: Name of a directory under LORA_OUTPUT_DIR, or
            "None (no LoRA)".
        progress: Gradio progress tracker (not passed on to the handler).

    Returns:
        ``(wav_path, status_message)``; ``wav_path`` is ``None`` on failure.
    """
    t0 = time.time()
    handler, status = get_handler()
    if handler is None:
        return None, f"Model not ready: {status}"

    notes = []  # extra fragments appended to the status line

    # Apply trained LoRA if selected
    # NOTE(review): choosing "None (no LoRA)" after a LoRA has been loaded does
    # not unload it from the shared handler; the handler's unload API is not
    # visible from this file -- confirm and add the call.
    if lora_choice and lora_choice != "None (no LoRA)":
        lora_dir = os.path.join(LORA_OUTPUT_DIR, lora_choice)
        if os.path.isdir(lora_dir):
            try:
                handler.load_lora(lora_dir)
                print(f"[Gen] Loaded LoRA: {lora_choice}")
                notes.append(f"LoRA: {lora_choice}")
            except Exception as e:
                # Best effort: generation continues without the adapter, but
                # the user is told about it in the status line.
                print(f"[Gen] LoRA load failed: {e}")
                notes.append(f"LoRA load failed: {e}")
        else:
            notes.append(f"LoRA not found: {lora_choice}")

    # TODO: LM size switching requires re-downloading the LM model
    # For now, log the selected size
    # The dropdown label looks like "1.7B (balanced)"; compare its first token.
    lm_tokens = str(lm_size).split() if lm_size else []
    requested_lm = lm_tokens[0] if lm_tokens else CURRENT_LM_SIZE
    if requested_lm != CURRENT_LM_SIZE:
        print(f"[Gen] LM size {requested_lm} requested (current: {CURRENT_LM_SIZE})")

    try:
        # Clamp values (inside the try so bad input is reported, not raised)
        duration = max(10, min(float(duration), 120))  # cap at 120s for CPU
        inference_steps = max(1, min(int(inference_steps), 32))

        bpm_number = _coerce_int(bpm, 0)
        bpm_val = bpm_number if bpm_number > 0 else None
        # `seed is 0` must stay a fixed seed, so no truthiness test here.
        seed_number = _coerce_int(seed, -1)
        seed_val = seed_number if seed_number >= 0 else -1

        # The checkbox overrides whatever is typed into the lyrics box.
        lyrics_text = "[Instrumental]" if instrumental or not lyrics else lyrics

        result = handler.generate_music(
            captions=caption or "upbeat electronic dance music",
            lyrics=lyrics_text,
            bpm=bpm_val,
            audio_duration=duration,
            inference_steps=inference_steps,
            guidance_scale=1.0,  # turbo model, no CFG needed
            use_random_seed=(seed_val < 0),
            seed=str(seed_val) if seed_val >= 0 else "",
            batch_size=1,
            task_type="text2music",
            vocal_language="en",
            shift=1.0,
            infer_method="ode",
            progress=None,
        )
        elapsed = time.time() - t0

        if not result.get("success", False):
            error = result.get("error", result.get("status_message", "Unknown error"))
            return None, f"Generation failed: {error}"

        audios = result.get("audios", [])
        if not audios:
            return None, "No audio generated"

        audio_tensor = audios[0].get("tensor")
        sample_rate = audios[0].get("sample_rate", 48000)
        if audio_tensor is None:
            return None, "Audio tensor is None"

        # Convert to numpy
        if isinstance(audio_tensor, torch.Tensor):
            audio_np = audio_tensor.cpu().float().numpy()
        else:
            audio_np = np.array(audio_tensor, dtype=np.float32)

        # soundfile expects (samples, channels)
        if audio_np.ndim == 2:
            audio_np = audio_np.T  # (channels, samples) -> (samples, channels)

        # Reserve a temp path and close the handle before soundfile opens the
        # same path, so no file descriptor is leaked.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        try:
            sf.write(wav_path, audio_np, sample_rate)
        except Exception:
            # Do not leave an empty temp file behind when writing fails.
            if os.path.exists(wav_path):
                os.unlink(wav_path)
            raise

        status_msg = (
            f"Generated in {elapsed:.1f}s | "
            f"Duration: {duration}s | Steps: {inference_steps} | "
            f"Seed: {seed_val}"
        )
        if notes:
            status_msg = " | ".join([status_msg, *notes])
        return wav_path, status_msg
    except Exception as e:
        import traceback

        return None, f"Error: {e}\n{traceback.format_exc()}"
    finally:
        gc.collect()
# ---------------------------------------------------------------------------
# Train LoRA Tab
# ---------------------------------------------------------------------------
def _safe_lora_name(raw_name, default="my_lora"):
    """Reduce user-supplied text to a single directory-name component."""
    text = (raw_name or "").strip()
    # Anything other than letters, digits, "-" and "_" becomes "_", so path
    # separators and ".." cannot escape the LoRA output directory.
    cleaned = "".join(ch if (ch.isalnum() or ch in "-_") else "_" for ch in text)
    return cleaned if cleaned.strip("_") else default


def train_lora(
    audio_files,
    lora_name,
    epochs,
    learning_rate,
    lora_rank,
    progress=gr.Progress(track_tqdm=True),
):
    """Train a LoRA adapter from uploaded audio files on CPU.

    Args:
        audio_files: Uploaded files; each item is either a path or an object
            with a ``name`` attribute holding the path.
        lora_name: Requested adapter name. It is sanitized to one directory
            name; "my_lora" is used when nothing usable remains.
        epochs: Number of epochs, clamped to 1..10.
        learning_rate: Positive, finite learning rate.
        lora_rank: LoRA rank, clamped to 1..64 (also used as alpha).
        progress: Gradio progress tracker.

    Returns:
        The training log as one newline-joined string. Errors are reported in
        the log text instead of being raised.
    """
    if not audio_files:
        return "No audio files uploaded."

    # The handler itself is not used below; this call is the readiness check.
    handler, status = get_handler()
    if handler is None:
        return f"Model not ready: {status}"

    lora_name = _safe_lora_name(lora_name)
    try:
        epochs = max(1, min(int(epochs), 10))
        lr = float(learning_rate)
        rank = max(1, min(int(lora_rank), 64))
    except (TypeError, ValueError) as exc:
        return f"Invalid training parameter: {exc}"
    if not (0 < lr < float("inf")):  # also rejects NaN
        return "Learning rate must be a positive, finite number."

    # Same root that get_trained_loras() scans.
    output_dir = os.path.join(LORA_OUTPUT_DIR, lora_name)
    audio_dir = os.path.join(output_dir, "audio_input")

    log_lines = [
        f"LoRA Training: '{lora_name}'",
        f"Audio files: {len(audio_files)}",
        f"Epochs: {epochs}, LR: {lr}, Rank: {rank}",
        f"Output: {output_dir}",
        "",
    ]

    # Bound up front so the cleanup in `finally` works on every exit path.
    model = None
    trainer = None
    try:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(audio_dir, exist_ok=True)

        # Copy uploaded files
        for f in audio_files:
            src = f.name if hasattr(f, "name") else str(f)
            dst = os.path.join(audio_dir, os.path.basename(src))
            shutil.copy2(src, dst)

        # Preprocessing step: encode audio files to tensors
        log_lines.append("[Step 1/2] Preprocessing audio files...")
        tensor_dir = os.path.join(output_dir, "preprocessed_tensors")
        os.makedirs(tensor_dir, exist_ok=True)

        from acestep.training_v2.preprocess import preprocess_audio_files

        preprocess_result = preprocess_audio_files(
            audio_dir=audio_dir,
            output_dir=tensor_dir,
            checkpoint_dir=CHECKPOINT_DIR,
            variant="turbo",
            max_duration=60.0,
            device="cpu",
            precision="float32",
        )
        processed = preprocess_result.get("processed", 0)
        total = preprocess_result.get("total", 0)
        failed = preprocess_result.get("failed", 0)
        log_lines.append(f"  Preprocessed: {processed}/{total} (failed: {failed})")
        if processed == 0:
            log_lines.append("ERROR: No files were preprocessed successfully.")
            return "\n".join(log_lines)

        # Training step
        log_lines.append("[Step 2/2] Training LoRA adapter...")

        from acestep.training_v2.model_loader import load_decoder_for_training
        from acestep.training_v2.trainer_fixed import FixedLoRATrainer
        from acestep.training_v2.fixed_lora_module import AdapterConfig
        from acestep.training_v2.configs import TrainingConfigV2

        # Load model for training
        model = load_decoder_for_training(
            checkpoint_dir=CHECKPOINT_DIR,
            variant="turbo",
            device="cpu",
            precision="float32",
        )
        adapter_cfg = AdapterConfig(
            rank=rank,
            alpha=rank,
            dropout=0.0,
            adapter_type="lora",
        )
        train_cfg = TrainingConfigV2(
            checkpoint_dir=CHECKPOINT_DIR,
            model_variant="turbo",
            dataset_dir=tensor_dir,
            output_dir=output_dir,
            max_epochs=epochs,
            batch_size=1,
            learning_rate=lr,
            device="cpu",
            precision="float32",
            seed=42,
            num_workers=0,
            pin_memory=False,
        )
        trainer = FixedLoRATrainer(model, adapter_cfg, train_cfg)

        step_count = 0
        last_loss = 0.0
        for update in trainer.train():
            # Updates are accepted either as objects with step/loss attributes
            # or as (step, loss, ...) tuples; anything else is skipped.
            if hasattr(update, "step"):
                step_count, last_loss = update.step, update.loss
            elif isinstance(update, tuple) and len(update) >= 2:
                step_count, last_loss = update[0], update[1]
            else:
                continue
            if step_count % 5 == 0:
                log_lines.append(f"  Step {step_count}: loss={last_loss:.4f}")

        log_lines.append(f"Training complete! Final step: {step_count}, loss: {last_loss:.4f}")
        log_lines.append(f"LoRA saved to: {output_dir}")
    except Exception as e:
        import traceback

        log_lines.append(f"ERROR: {e}")
        log_lines.append(traceback.format_exc())
    finally:
        # Cleanup: drop the training model on failure paths as well.
        del model, trainer
        gc.collect()
    return "\n".join(log_lines)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_ui():
    """Build the Gradio Blocks app with "Generate Music" and "Train LoRA" tabs.

    The generate button calls ``generate_music``; the train button calls
    ``train_lora`` and then refreshes the LoRA dropdown on the generate tab.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    theme = gr.themes.Default()
    try:
        theme = gr.Theme.from_hub("NoCrypt/miku")
    except Exception as exc:
        # Best effort: keep the default theme, but say why the custom one is missing.
        print(f"[UI] Could not load theme 'NoCrypt/miku', using default: {exc}")

    css = (
        "\n"
        ".main-title { text-align: center; margin-bottom: 0.5em; }\n"
        ".status-box { font-family: monospace; font-size: 0.85em; }\n"
    )

    with gr.Blocks(
        theme=theme,
        title="ACE-Step 1.5 CPU",
        css=css,
    ) as demo:
        gr.HTML("<h1 class='main-title'>ACE-Step 1.5 Music Generation (CPU)</h1>")
        gr.HTML(
            "<p style='text-align:center;'>Text-to-music generation and LoRA training, "
            "running entirely on CPU. Based on "
            "<a href='https://github.com/ace-step/ACE-Step-1.5'>ACE-Step 1.5</a>.</p>"
        )
        with gr.Tabs():
            # ---- Generate Tab ----
            with gr.Tab("Generate Music"):
                with gr.Row():
                    with gr.Column(scale=2):
                        caption_input = gr.Textbox(
                            label="Music Description",
                            placeholder="e.g. upbeat electronic dance music, 120 BPM",
                            lines=3,
                            value="upbeat electronic dance music, energetic synth leads, driving bassline",
                        )
                        lyrics_input = gr.Textbox(
                            label="Lyrics (use [Instrumental] for no vocals)",
                            placeholder="[Instrumental]",
                            lines=3,
                            value="[Instrumental]",
                        )
                        instrumental_cb = gr.Checkbox(
                            label="Instrumental (no vocals)",
                            value=True,
                        )
                    with gr.Column(scale=1):
                        bpm_input = gr.Number(
                            label="BPM (0 = auto)",
                            value=120,
                            minimum=0,
                            maximum=300,
                        )
                        duration_input = gr.Slider(
                            label="Duration (seconds)",
                            minimum=10,
                            maximum=120,
                            value=10,
                            step=5,
                        )
                        seed_input = gr.Number(
                            label="Seed (-1 = random)",
                            value=-1,
                        )
                        steps_input = gr.Slider(
                            label="Inference Steps (fewer = faster)",
                            minimum=1,
                            maximum=32,
                            value=8,
                            step=1,
                        )
                        lm_size_input = gr.Dropdown(
                            label="LM Model Size",
                            choices=["0.6B (fast)", "1.7B (balanced)", "4B (best quality)"],
                            value="1.7B (balanced)",
                            info="Language model for music understanding",
                        )
                        lora_select = gr.Dropdown(
                            label="Use Trained LoRA",
                            choices=get_trained_loras(),
                            value="None (no LoRA)",
                            info="Select a LoRA you trained to apply it",
                        )
                generate_btn = gr.Button("Generate Music", variant="primary")
                with gr.Row():
                    audio_output = gr.Audio(
                        label="Generated Audio",
                        type="filepath",
                    )
                    gen_status = gr.Textbox(
                        label="Status",
                        interactive=False,
                        elem_classes="status-box",
                    )
                generate_btn.click(
                    fn=generate_music,
                    inputs=[
                        caption_input,
                        lyrics_input,
                        instrumental_cb,
                        bpm_input,
                        duration_input,
                        seed_input,
                        steps_input,
                        lm_size_input,
                        lora_select,
                    ],
                    outputs=[audio_output, gen_status],
                )
            # ---- Train LoRA Tab ----
            with gr.Tab("Train LoRA"):
                gr.Markdown(
                    "### Train a LoRA adapter on your audio files\n"
                    "Upload WAV/MP3/FLAC files to fine-tune the model. "
                    "Training runs on CPU so keep epochs low and files short."
                )
                with gr.Row():
                    with gr.Column():
                        audio_upload = gr.File(
                            label="Upload Audio Files",
                            file_count="multiple",
                            file_types=["audio"],
                        )
                        lora_name_input = gr.Textbox(
                            label="LoRA Name",
                            value="my_lora",
                        )
                    with gr.Column():
                        epochs_input = gr.Slider(
                            label="Epochs",
                            minimum=1,
                            maximum=10,
                            value=1,
                            step=1,
                        )
                        lr_input = gr.Number(
                            label="Learning Rate",
                            value=1e-4,
                        )
                        rank_input = gr.Slider(
                            label="LoRA Rank",
                            minimum=1,
                            maximum=64,
                            value=8,
                            step=1,
                        )
                train_btn = gr.Button("Start Training", variant="primary")
                train_log = gr.Textbox(
                    label="Training Log",
                    interactive=False,
                    lines=15,
                    elem_classes="status-box",
                )

                def train_and_refresh(*args):
                    """Run training, then refresh the LoRA dropdown choices."""
                    log = train_lora(*args)
                    new_loras = get_trained_loras()
                    # args[1] is the "LoRA Name" textbox; an empty name falls
                    # back to "my_lora", the same default train_lora applies.
                    typed = args[1] if len(args) > 1 and isinstance(args[1], str) else ""
                    requested = typed.strip() or "my_lora"
                    if requested in new_loras:
                        # Select the adapter that was just trained, not merely
                        # the alphabetically last one.
                        selected = requested
                    elif len(new_loras) > 1:
                        selected = new_loras[-1]
                    else:
                        selected = "None (no LoRA)"
                    return log, gr.update(choices=new_loras, value=selected)

                train_btn.click(
                    fn=train_and_refresh,
                    inputs=[
                        audio_upload,
                        lora_name_input,
                        epochs_input,
                        lr_input,
                        rank_input,
                    ],
                    outputs=[train_log, lora_select],
                )
    return demo
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "ssr_mode": False,
    }
    demo = build_ui()
    demo.launch(**launch_options)