| """ |
| ACE-Step Engine - Core generation module |
| Handles interaction with ACE-Step model for music generation |
| """ |
|
|
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any

import numpy as np
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModel
|
|
| logger = logging.getLogger(__name__) |
|
|
| class ACEStepEngine: |
| """Core engine for ACE-Step music generation.""" |
| |
| def __init__(self, config: Dict[str, Any]): |
| """ |
| Initialize ACE-Step engine. |
| |
| Args: |
| config: Configuration dictionary |
| """ |
| self.config = config |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| logger.info(f"ACE-Step Engine initialized on {self.device}") |
| |
| self.model = None |
| self.text_tokenizer = None |
| self.text_encoder = None |
| self.llm_tokenizer = None |
| self.llm = None |
| self.vae = None |
| |
    def load_models(self, model_path: str):
        """Load the ACE-Step model and tokenizer from a local path or hub ID."""
        try:
            self.text_tokenizer = AutoTokenizer.from_pretrained(model_path)

            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                low_cpu_mem_usage=True
            )
            self.model.eval()

            logger.info("✅ ACE-Step model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise
| |
| def generate( |
| self, |
| prompt: str, |
| lyrics: Optional[str] = None, |
| duration: int = 30, |
| temperature: float = 0.7, |
| top_p: float = 0.9, |
| seed: int = -1, |
| style: str = "auto", |
| lora_path: Optional[str] = None |
| ) -> str: |
| """ |
| Generate music using ACE-Step. |
| |
| Args: |
| prompt: Text description of desired music |
| lyrics: Optional lyrics |
| duration: Duration in seconds |
| temperature: Sampling temperature |
| top_p: Nucleus sampling parameter |
| seed: Random seed (-1 for random) |
| style: Music style |
| lora_path: Path to LoRA model if using |
| |
| Returns: |
| Path to generated audio file |
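
        Example (illustrative; assumes load_models has been called and the
        checkpoint path is a placeholder):
            >>> engine = ACEStepEngine({"output_dir": "outputs"})
            >>> engine.load_models("checkpoints/ace-step")  # doctest: +SKIP
            >>> engine.generate("upbeat synthwave with heavy bass")  # doctest: +SKIP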
| """ |
| try: |
| |
            # Seed RNGs for reproducible sampling when a fixed seed is given
            if seed >= 0:
                torch.manual_seed(seed)
                np.random.seed(seed)
| |
            # Optionally apply LoRA adapter weights for this generation
            if lora_path:
                self._load_lora(lora_path)
| |
            # Build the conditioning text from prompt, lyrics, style, and duration
            input_text = self._prepare_input(prompt, lyrics, style, duration)
| |
            # Tokenize the conditioning text and move it to the target device
            inputs = self.text_tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.device)
| |
| |
| logger.info(f"Generating {duration}s audio...") |
            # Autoregressive sampling; assumes roughly 50 tokens per second of audio
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=duration * 50,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    num_return_sequences=1
                )
| |
            # Decode model outputs back to a waveform
            audio_tensor = self._decode_to_audio(outputs)

            # Persist the waveform and return its path
            output_path = self._save_audio(audio_tensor, duration)
| |
| logger.info(f"✅ Generated audio: {output_path}") |
| return str(output_path) |
| |
| except Exception as e: |
| logger.error(f"Generation failed: {e}") |
| raise |
        finally:
            # Always detach LoRA weights so later calls start from the base model
            if lora_path:
                self._unload_lora()
| |
| def generate_clip( |
| self, |
| prompt: str, |
| lyrics: str, |
| duration: int, |
| context_audio: Optional[np.ndarray] = None, |
| style: str = "auto", |
| temperature: float = 0.7, |
| seed: int = -1 |
| ) -> str: |
| """ |
| Generate audio clip with context conditioning. |
| Used for timeline-based generation. |
| |
| Args: |
| prompt: Text prompt |
| lyrics: Lyrics for this clip |
| duration: Duration in seconds (typically 32) |
| context_audio: Previous audio for style conditioning |
| style: Music style |
| temperature: Sampling temperature |
| seed: Random seed |
| |
| Returns: |
| Path to generated clip |
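
        Example (illustrative; prev_audio is an np.ndarray from an earlier clip):
            >>> engine.generate_clip(
            ...     prompt="energetic chorus",
            ...     lyrics="We rise tonight",
            ...     duration=32,
            ...     context_audio=prev_audio,
            ... )  # doctest: +SKIP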
| """ |
| try: |
            if seed >= 0:
                torch.manual_seed(seed)
                np.random.seed(seed)
| |
            input_text = self._prepare_input(prompt, lyrics, style, duration)

            # Encode the preceding audio (if any) so the new clip matches its style
            context_embedding = None
            if context_audio is not None:
                context_embedding = self._encode_audio_context(context_audio)

            inputs = self.text_tokenizer(input_text, return_tensors="pt").to(self.device)
| |
            # Build sampling kwargs once; include the context embedding only when
            # conditioning audio was provided (assumes the model accepts it)
            gen_kwargs = dict(
                max_length=duration * 50,
                temperature=temperature,
                do_sample=True,
            )
            if context_embedding is not None:
                gen_kwargs["context_embedding"] = context_embedding

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **gen_kwargs)
| |
| audio_tensor = self._decode_to_audio(outputs) |
| output_path = self._save_audio(audio_tensor, duration, prefix="clip") |
| |
| return str(output_path) |
| |
| except Exception as e: |
| logger.error(f"Clip generation failed: {e}") |
| raise |
| |
| def generate_variation(self, audio_path: str, strength: float = 0.5) -> str: |
| """Generate variation of existing audio.""" |
        try:
            # Load the source audio and encode it into latent space
            audio, sr = torchaudio.load(audio_path)
            latent = self._encode_audio(audio)

            # Perturb the latent with scaled Gaussian noise; strength controls
            # how far the variation drifts from the original
            noise = torch.randn_like(latent) * strength
            varied_latent = latent + noise

            # Decode the perturbed latent back to audio and save it
            varied_audio = self._decode_from_latent(varied_latent)
            output_path = self._save_audio(varied_audio, audio.shape[-1] / sr, prefix="variation")
            return str(output_path)
| |
| except Exception as e: |
| logger.error(f"Variation generation failed: {e}") |
| raise |
| |
| def repaint( |
| self, |
| audio_path: str, |
| start_time: float, |
| end_time: float, |
| new_prompt: str |
| ) -> str: |
| """Repaint specific section of audio.""" |
        try:
            audio, sr = torchaudio.load(audio_path)

            # Convert the repaint window from seconds to sample frames
            start_frame = int(start_time * sr)
            end_frame = int(end_time * sr)

            # Generate a replacement section from the new prompt
            section_duration = end_time - start_time
            new_section = self.generate(
                prompt=new_prompt,
                duration=int(section_duration),
                temperature=0.8
            )
            new_audio, new_sr = torchaudio.load(new_section)
            if new_sr != sr:
                new_audio = torchaudio.functional.resample(new_audio, new_sr, sr)

            # Splice the new section in, clamped to the shorter of the two lengths
            result = audio.clone()
            seg_len = min(end_frame - start_frame, new_audio.shape[-1])
            result[:, start_frame:start_frame + seg_len] = new_audio[:, :seg_len]

            # Crossfade 0.5s at each boundary to avoid audible seams
            blend_length = int(0.5 * sr)
            if start_frame > blend_length:
                fade_in = torch.linspace(0, 1, blend_length).unsqueeze(0)
                result[:, start_frame:start_frame + blend_length] = (
                    result[:, start_frame:start_frame + blend_length] * fade_in +
                    audio[:, start_frame:start_frame + blend_length] * (1 - fade_in)
                )

            if end_frame < audio.shape[-1] - blend_length:
                fade_out = torch.linspace(1, 0, blend_length).unsqueeze(0)
                result[:, end_frame - blend_length:end_frame] = (
                    result[:, end_frame - blend_length:end_frame] * fade_out +
                    audio[:, end_frame - blend_length:end_frame] * (1 - fade_out)
                )

            output_path = self._save_audio(result, audio.shape[-1] / sr, prefix="repainted")
            return str(output_path)
| |
| except Exception as e: |
| logger.error(f"Repainting failed: {e}") |
| raise |
| |
| def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str: |
| """Edit lyrics while maintaining music.""" |
        try:
            # Approximation: re-generate the clip with the new lyrics, using the
            # original audio as a style-conditioning context so the music stays
            # close to the source recording
            audio, sr = torchaudio.load(audio_path)
            duration = audio.shape[-1] / sr

            result = self.generate_clip(
                prompt="Match the style of the reference",
                lyrics=new_lyrics,
                duration=int(duration),
                context_audio=audio.numpy(),
                temperature=0.6
            )

            return result
| |
| except Exception as e: |
| logger.error(f"Lyric editing failed: {e}") |
| raise |
| |
| def _prepare_input( |
| self, |
| prompt: str, |
| lyrics: Optional[str], |
| style: str, |
| duration: int |
| ) -> str: |
| """Prepare input text for model.""" |
| parts = [] |
| |
| if style and style != "auto": |
| parts.append(f"[STYLE: {style}]") |
| |
| parts.append(f"[DURATION: {duration}s]") |
| parts.append(prompt) |
| |
| if lyrics: |
| parts.append(f"[LYRICS]\n{lyrics}") |
| |
| return " ".join(parts) |
| |
    def _encode_audio(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode audio to latent space using DCAE."""
        # Placeholder: identity mapping until the DCAE encoder is integrated
        return audio

    def _decode_from_latent(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode latent to audio using DCAE."""
        # Placeholder: identity mapping until the DCAE decoder is integrated
        return latent
| |
    def _encode_audio_context(self, audio: np.ndarray) -> torch.Tensor:
        """Encode audio context for conditioning."""
        # Placeholder: pass the raw waveform through as the conditioning tensor
        audio_tensor = torch.from_numpy(audio).float().to(self.device)
        return audio_tensor
| |
    def _decode_to_audio(self, outputs: torch.Tensor) -> torch.Tensor:
        """Decode model outputs to an audio tensor."""
        # Placeholder decoder: derive the expected length from the token count
        # (~50 tokens per second) and return low-amplitude stereo noise so the
        # rest of the pipeline can run end to end
        sample_rate = 44100
        duration = outputs.shape[1] / 50
        samples = int(duration * sample_rate)

        audio = torch.randn(2, samples) * 0.1
        return audio
| |
    def _save_audio(
        self,
        audio: torch.Tensor,
        duration: float,
        prefix: str = "generated"
    ) -> Path:
        """Save an audio tensor to a timestamped 16-bit WAV file."""
        output_dir = Path(self.config.get("output_dir", "outputs"))
        output_dir.mkdir(parents=True, exist_ok=True)

        # Timestamped filename avoids collisions across repeated generations
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{prefix}_{timestamp}.wav"
        output_path = output_dir / filename

        torchaudio.save(
            str(output_path),
            audio,
            sample_rate=44100,
            encoding="PCM_S",
            bits_per_sample=16
        )

        return output_path
| |
| def _load_lora(self, lora_path: str): |
| """Load LoRA weights into model.""" |
| try: |
| from peft import PeftModel |
| self.model = PeftModel.from_pretrained(self.model, lora_path) |
| logger.info(f"✅ Loaded LoRA from {lora_path}") |
| except Exception as e: |
| logger.warning(f"Failed to load LoRA: {e}") |
| |
    def _unload_lora(self):
        """Unload LoRA weights and restore the base model."""
        try:
            # PeftModel.unload() returns the base model with adapters removed
            if hasattr(self.model, "unload"):
                self.model = self.model.unload()
        except Exception as e:
            logger.warning(f"Failed to unload LoRA: {e}")
|
|
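# Minimal usage sketch (illustrative only): the checkpoint path and config keys
# below are placeholders, and _decode_to_audio is still a stub, so the output
# is synthetic noise until real decoding is wired in.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    engine = ACEStepEngine({"output_dir": "outputs"})
    engine.load_models("checkpoints/ace-step")  # placeholder path

    path = engine.generate(
        prompt="dreamy lo-fi hip hop with vinyl crackle",
        duration=30,
        temperature=0.7,
        seed=42,
    )
    print(f"Wrote {path}")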