| | from typing import Any, Dict |
| |
|
| | import torch |
| | from diffusers import AudioLDM2Pipeline, DPMSolverMultistepScheduler |
| |
|
| |
|
class EndpointHandler:
    """Inference handler that generates music from a text prompt with AudioLDM2.

    The checkpoint is fixed to ``cvssp/audioldm2-music`` and loaded in fp16.
    """

    def __init__(self, path=""):
        # NOTE(review): ``path`` is accepted for the serving runtime's calling
        # convention but is not used; the hub checkpoint below is always loaded.
        self.pipeline = AudioLDM2Pipeline.from_pretrained(
            "cvssp/audioldm2-music", torch_dtype=torch.float16
        )
        # Replace the default scheduler with DPM-Solver multistep, reusing the
        # original scheduler's configuration.
        self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )
        # Model CPU offload handles device placement itself: sub-models are kept
        # on the CPU and moved to the GPU only while they run. An explicit
        # ``.to("cuda")`` beforehand just loads the whole model onto the GPU to
        # have it moved straight back, so it is intentionally omitted.
        self.pipeline.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a single music clip for a text prompt.

        Args:
            data (:dict:):
                The request payload. ``"inputs"`` holds the text prompt.
                Optional settings ``"duration"`` (clip length in seconds,
                default 30) and ``"negative_prompt"`` are read from the top
                level of the payload or from a nested ``"parameters"`` dict;
                top-level keys take precedence.

        Returns:
            ``{"generated_audio": [...]}`` -- the generated waveform as a list.

        Raises:
            ValueError: if ``"inputs"`` is missing, or ``"duration"`` is not a
                positive, finite number.
        """
        # Read with ``get`` (not ``pop``) so the caller's payload is not mutated.
        song_description = data.get("inputs")
        if song_description is None:
            # Falling back to the whole payload dict as the prompt cannot work;
            # fail early with a clear message instead.
            raise ValueError("Request payload must contain an 'inputs' text prompt.")

        # Settings may be given at the top level (original behaviour) or under
        # the conventional "parameters" key; top-level values win.
        parameters = data.get("parameters")
        if not isinstance(parameters, dict):
            parameters = {}
        duration = data.get("duration", parameters.get("duration", 30))
        negative_prompt = data.get(
            "negative_prompt",
            parameters.get("negative_prompt", "Low quality, average quality."),
        )

        # JSON clients may send the duration as a string; normalise it and reject
        # values that cannot describe a clip length (also catches NaN/inf).
        try:
            duration = float(duration)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"'duration' must be a number of seconds, got {duration!r}"
            ) from err
        if not 0 < duration < float("inf"):
            raise ValueError(
                f"'duration' must be a positive, finite number of seconds, got {duration!r}"
            )

        # Several candidates are generated per prompt; per the diffusers
        # AudioLDM2 docs the pipeline then ranks them by text/audio similarity,
        # so index 0 is the best-scoring waveform.
        audio = self.pipeline(
            song_description,
            negative_prompt=negative_prompt,
            num_waveforms_per_prompt=4,
            audio_length_in_s=duration,
            num_inference_steps=20,
        ).audios[0]

        # Convert the array to a plain list so the response is JSON-serialisable.
        prediction = audio.tolist()
        return {"generated_audio": prediction}
| |
|