"""
ACE-Step Engine - wrapper around the official ACE-Step 1.5 architecture.

Integrates AceStepHandler (DiT) and LLMHandler (5Hz LM) behind one interface.
"""
|
|
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
|
|
logger = logging.getLogger(__name__)

# The ACE-Step 1.5 package is treated as optional at import time: when it is
# missing, this module still imports and ACE_STEP_AVAILABLE records the fact so
# that callers can fail with a clear error instead of an ImportError here.
try:
    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler
    from acestep.inference import GenerationParams, GenerationConfig, generate_music
    from acestep.model_downloader import (
        ensure_main_model,
        get_checkpoints_dir,
        check_main_model_exists,
    )
    ACE_STEP_AVAILABLE = True
except ImportError as e:
    logger.warning(f"ACE-Step 1.5 modules not available: {e}")
    # Bind every optional name to None so the names always exist at module
    # level; code must check ACE_STEP_AVAILABLE before using any of them.
    AceStepHandler = LLMHandler = None
    GenerationParams = GenerationConfig = generate_music = None
    ensure_main_model = get_checkpoints_dir = check_main_model_exists = None
    ACE_STEP_AVAILABLE = False
|
|
|
|
class ACEStepEngine:
    """Wrapper engine for ACE-Step 1.5 with custom interface.

    Construction is cheap: the DiT and LLM handlers are created, checkpoints
    are downloaded and models are loaded only on the first call that needs
    them (see ``_ensure_models_loaded``).

    Configuration keys read by this class:
        checkpoint_dir:  passed to ``get_checkpoints_dir`` (may be absent).
        dit_model_path:  DiT model name, default ``"acestep-v15-turbo"``.
        lm_model_path:   LM model name, default ``"acestep-5Hz-lm-1.7B"``.
        output_dir:      where generated audio is saved, default ``"outputs"``.
        inference_steps: diffusion steps, default ``DEFAULT_INFERENCE_STEPS``.
    """

    # Used when the configuration does not provide "inference_steps".
    DEFAULT_INFERENCE_STEPS = 8

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize ACE-Step engine.

        No model is loaded here; only bookkeeping attributes are set.

        Args:
            config: Configuration dictionary (``None`` is treated as empty)
        """
        self.config = config if config is not None else {}
        self._initialized = False
        self.dit_handler = None
        self.llm_handler = None
        # Filled in by _ensure_models_loaded(); defined here so that reading
        # it before the first generation does not raise AttributeError.
        self.device = None

        logger.info("ACE-Step Engine created (GPU will be detected on first use)")

        if not ACE_STEP_AVAILABLE:
            logger.error("ACE-Step 1.5 modules not available")
            logger.error("Please ensure acestep package is installed in your environment")
            return

        logger.info("✓ ACE-Step Engine created (models will load on first use)")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _checkpoints_dir(self):
        """Resolve the checkpoint directory the same way for download and load.

        Both steps must agree on the location, otherwise models could be
        downloaded to one directory and looked up in another.
        """
        return get_checkpoints_dir(self.config.get("checkpoint_dir"))

    def _inference_steps(self) -> int:
        """Return the configured number of inference steps (default 8)."""
        steps = self.config.get("inference_steps")
        return self.DEFAULT_INFERENCE_STEPS if steps is None else int(steps)

    def _create_handlers(self):
        """Instantiate the DiT and LLM handlers if they do not exist yet."""
        if self.dit_handler is None:
            self.dit_handler = AceStepHandler()
        if self.llm_handler is None:
            self.llm_handler = LLMHandler()

    def _run_generation(self, params, config):
        """Run ``generate_music`` and save results under the output directory.

        The output directory is created when missing so that every task type
        (not only text2music) can write its result.
        """
        output_dir = self.config.get("output_dir", "outputs")
        os.makedirs(output_dir, exist_ok=True)
        return generate_music(
            dit_handler=self.dit_handler,
            llm_handler=self.llm_handler,
            params=params,
            config=config,
            save_dir=output_dir,
        )

    def _download_checkpoints(self):
        """Download model checkpoints from HuggingFace if not present.

        Raises:
            RuntimeError: if ``ensure_main_model`` reports failure.
        """
        checkpoints_dir = self._checkpoints_dir()

        if check_main_model_exists(checkpoints_dir):
            logger.info(f"✓ ACE-Step 1.5 models already exist at {checkpoints_dir}")
            return

        logger.info("Downloading ACE-Step 1.5 models from HuggingFace...")
        logger.info("This may take several minutes (models are ~7GB total)...")

        try:
            success, message = ensure_main_model(
                checkpoints_dir=checkpoints_dir,
                prefer_source="huggingface",
            )

            if not success:
                raise RuntimeError(f"Failed to download models: {message}")

            logger.info(f"✓ {message}")
            logger.info("✓ All ACE-Step 1.5 models downloaded successfully")

        except Exception as e:
            logger.error(f"Failed to download checkpoints: {e}")
            raise

    def _load_models(self):
        """Initialize and load ACE-Step models.

        A DiT initialization failure is fatal. An LLM initialization failure
        is only logged and the engine is still marked as initialized.

        Raises:
            RuntimeError: if ACE-Step is unavailable or the DiT handler fails
                to initialize.
        """
        try:
            if not ACE_STEP_AVAILABLE:
                raise RuntimeError("ACE-Step 1.5 not available")

            # Safe to call again: handlers that already exist are kept.
            self._create_handlers()

            dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo")
            lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B")
            checkpoints_dir = self._checkpoints_dir()

            # Two directory levels above this file.
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

            logger.info(f"Initializing DiT handler with model: {dit_model_path}")

            status_dit, success_dit = self.dit_handler.initialize_service(
                project_root=project_root,
                config_path=dit_model_path,
                device="auto",
                use_flash_attention=False,
                compile_model=False,
                offload_to_cpu=False,
            )

            if not success_dit:
                raise RuntimeError(f"Failed to initialize DiT: {status_dit}")

            logger.info(f"✓ DiT initialized: {status_dit}")

            logger.info(f"Initializing LLM handler with model: {lm_model_path}")

            status_llm, success_llm = self.llm_handler.initialize(
                checkpoint_dir=str(checkpoints_dir),
                lm_model_path=lm_model_path,
                backend="pt",
                device="auto",
                offload_to_cpu=False,
            )

            if not success_llm:
                logger.warning(f"LLM initialization failed: {status_llm}")
                logger.warning("Continuing without LLM (DiT-only mode)")
            else:
                logger.info(f" LLM initialized: {status_llm}")

            self._initialized = True
            logger.info(" ACE-Step engine fully initialized")

        except Exception as e:
            logger.error(f"Failed to initialize models: {e}")
            raise

    def _ensure_models_loaded(self):
        """Ensure models are loaded (lazy loading for ZeroGPU compatibility).

        Does nothing once the engine is initialized.

        Raises:
            RuntimeError: if the ACE-Step package could not be imported, or
                if downloading/loading the models fails.
        """
        if self._initialized:
            return

        # Fail with a clear message before touching any ACE-Step name; those
        # names are not usable when the package import failed.
        if not ACE_STEP_AVAILABLE:
            raise RuntimeError("ACE-Step 1.5 not available")

        logger.info("Lazy loading models on first use...")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self._create_handlers()

        try:
            self._download_checkpoints()
            self._load_models()
            logger.info("✓ Models loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def generate(
        self,
        prompt: str,
        lyrics: Optional[str] = None,
        duration: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.9,
        seed: int = -1,
        style: str = "auto",
        lora_path: Optional[str] = None
    ) -> str:
        """
        Generate music using ACE-Step.

        Args:
            prompt: Text description of desired music
            lyrics: Optional lyrics (``None`` is sent as an empty string)
            duration: Duration in seconds
            temperature: Sampling temperature (for LLM)
            top_p: Nucleus sampling parameter (for LLM)
            seed: Random seed (negative values request a random seed)
            style: Music style; accepted for interface compatibility but not
                forwarded to the model by this method
            lora_path: Path to LoRA model; currently not applied

        Returns:
            Path to generated audio file

        Raises:
            RuntimeError: if models cannot be loaded or no audio is produced.
        """
        self._ensure_models_loaded()

        if lora_path:
            # Make the ignored argument visible instead of silently dropping it.
            logger.warning(f"lora_path is not applied by this engine and will be ignored: {lora_path}")

        try:
            params = GenerationParams(
                task_type="text2music",
                caption=prompt,
                lyrics=lyrics or "",
                duration=duration,
                inference_steps=self._inference_steps(),
                seed=seed if seed >= 0 else -1,
                thinking=True,
                lm_temperature=temperature,
                lm_top_p=top_p,
            )

            config = GenerationConfig(
                batch_size=1,
                use_random_seed=(seed < 0),
                audio_format="wav",
            )

            logger.info(f"Generating {duration}s audio: {prompt[:50]}...")

            result = self._run_generation(params, config)

            if not result.audio_paths:
                raise RuntimeError("No audio generated")

            output_path = result.audio_paths[0]
            logger.info(f" Generated: {output_path}")
            return output_path

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_clip(
        self,
        prompt: str,
        lyrics: str,
        duration: int,
        context_audio: Optional[str] = None,
        style: str = "auto",
        temperature: float = 0.7,
        seed: int = -1
    ) -> str:
        """
        Generate audio clip for timeline.

        This delegates to ``generate``; ``context_audio`` is accepted but is
        not used for conditioning yet.

        Args:
            prompt: Text prompt
            lyrics: Lyrics for this clip
            duration: Duration in seconds (typically 32)
            context_audio: Path to previous audio (currently ignored)
            style: Music style
            temperature: Sampling temperature
            seed: Random seed

        Returns:
            Path to generated clip
        """
        if context_audio:
            logger.warning(f"context_audio conditioning is not implemented and will be ignored: {context_audio}")

        return self.generate(
            prompt=prompt,
            lyrics=lyrics,
            duration=duration,
            temperature=temperature,
            seed=seed,
            style=style,
        )

    def generate_variation(self, audio_path: str, strength: float = 0.5) -> str:
        """Generate variation of existing audio.

        Args:
            audio_path: Path to the source audio
            strength: Value passed as ``audio_cover_strength``

        Returns:
            Path to the first generated file, or ``audio_path`` unchanged when
            the generation returned no audio.
        """
        self._ensure_models_loaded()

        try:
            # NOTE(review): the task_type value and the `audio_path` keyword
            # are kept as found — confirm them against the ACE-Step API.
            params = GenerationParams(
                task_type="audio_variation",
                audio_path=audio_path,
                audio_cover_strength=strength,
                inference_steps=self._inference_steps(),
            )

            config = GenerationConfig(
                batch_size=1,
                audio_format="wav",
            )

            result = self._run_generation(params, config)

            if result.audio_paths:
                return result.audio_paths[0]

            # Best-effort fallback kept from the original behavior, but made
            # visible so callers can tell nothing new was produced.
            logger.warning("Variation produced no audio; returning the input path unchanged")
            return audio_path

        except Exception as e:
            logger.error(f"Variation generation failed: {e}")
            raise

    def repaint(
        self,
        audio_path: str,
        start_time: float,
        end_time: float,
        new_prompt: str
    ) -> str:
        """Repaint specific section of audio.

        Models are loaded on demand, like the other generation methods.

        Args:
            audio_path: Path to the source audio
            start_time: Value passed as ``repainting_start``
            end_time: Value passed as ``repainting_end``
            new_prompt: Caption for the repainted section

        Returns:
            Path to the first generated file, or ``audio_path`` unchanged when
            the generation returned no audio.

        Raises:
            RuntimeError: if models cannot be loaded.
        """
        self._ensure_models_loaded()

        try:
            # NOTE(review): the task_type value and the `audio_path` keyword
            # are kept as found — confirm them against the ACE-Step API.
            params = GenerationParams(
                task_type="repainting",
                audio_path=audio_path,
                caption=new_prompt,
                repainting_start=start_time,
                repainting_end=end_time,
                inference_steps=self._inference_steps(),
            )

            config = GenerationConfig(
                batch_size=1,
                audio_format="wav",
            )

            result = self._run_generation(params, config)

            if result.audio_paths:
                return result.audio_paths[0]

            logger.warning("Repainting produced no audio; returning the input path unchanged")
            return audio_path

        except Exception as e:
            logger.error(f"Repainting failed: {e}")
            raise

    def edit_lyrics(self, audio_path: str, new_lyrics: str, duration: int = 30) -> str:
        """Edit lyrics while maintaining music.

        Not a true edit yet: a new piece is generated from ``new_lyrics`` with
        a fixed prompt, and ``audio_path`` is not read.

        Args:
            audio_path: Path to the reference audio (currently unused)
            new_lyrics: Lyrics for the regenerated audio
            duration: Duration in seconds of the regenerated audio

        Returns:
            Path to the newly generated audio file
        """
        logger.warning("Lyric editing not fully implemented - regenerating with new lyrics")

        return self.generate(
            prompt="Match the style of the reference",
            lyrics=new_lyrics,
            duration=duration,
        )

    def is_initialized(self) -> bool:
        """Check if engine is initialized."""
        return self._initialized
|
|