| | import logging |
| | from typing import Any, Dict |
| |
|
| | import numpy as np |
| | import torch |
| | from audiocraft.models import MusicGen |
| |
|
# Configure root logging once at import time so the handler's progress
# messages show up in the endpoint logs; module-level logger per convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
class EndpointHandler:
    """Inference endpoint wrapper around ``facebook/musicgen-large``.

    Generates audio from a text prompt and lifts the model's 30-second
    generation limit by generating in overlapping chunks, conditioning each
    new chunk on the tail of the audio generated so far.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; the model is loaded onto this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # musicgen-large produces mono audio (single channel).
        self.channels = 1
        self.model = MusicGen.get_pretrained(
            "facebook/musicgen-large", device=self.device
        )
        self.sample_rate = self.model.sample_rate

    def __call__(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """
        Endpoint entry point: generate audio for a text prompt.

        For durations over 30 seconds the audio is produced in chunks: each
        chunk after the first is generated with ``generate_continuation``,
        conditioned on the last ``audio_window`` seconds of audio so far.

        The payload should be a dictionary with the following keys:
            inputs: The text prompt to generate audio for.
            generation_params: Optional dict passed to the model's
                ``set_generation_params`` (e.g. duration, temperature, top_p,
                top_k, cfg_coef). ``duration`` defaults to 30 seconds here;
                the other defaults are whatever MusicGen uses — confirm
                against the audiocraft version in use.
            audio_window: Seconds of already-generated audio used as the
                prompt for each continuation chunk. Default: 20.
            chunk_size: Seconds generated per model call, prompt included.
                Default: 30. Must be strictly greater than audio_window.

        Args:
            data (Dict[str, Any]): The payload to generate audio for.

        Raises:
            ValueError: If chunk_size <= audio_window, or if
                (duration - chunk_size) is not a multiple of
                (chunk_size - audio_window).

        Returns:
            Dict[str, np.ndarray]: ``{"generated_audio": audio}`` where
            ``audio`` has shape (samples, channels).
        """
        prompt = data["inputs"]
        generation_params = data.get("generation_params", {})
        duration = generation_params.get("duration", 30)

        if duration <= 30:
            logger.info("Generating audio with duration %s in one go.", duration)
            self.model.set_generation_params(**generation_params)
            # generate() returns (batch, channels, samples); drop the batch
            # dim so both paths yield a (channels, samples) tensor and the
            # endpoint's output shape does not depend on the duration.
            final_audio = self.model.generate([prompt], progress=True)[0]
        else:
            logger.info("Generating audio with duration %s in chunks.", duration)

            audio_window = data.get("audio_window", 20)
            chunk_size = data.get("chunk_size", 30)
            continuation = chunk_size - audio_window
            final_duration = duration

            # "<=" (not "<") also guards the modulo below against
            # continuation == 0, which would raise ZeroDivisionError.
            if chunk_size <= audio_window:
                raise ValueError(
                    f"Chunk size {chunk_size} must be greater than audio window {audio_window}"
                )

            if (final_duration - chunk_size) % continuation != 0:
                raise ValueError(
                    f"Duration ({duration} secs) - chunksize ({chunk_size} secs)"
                    f" must be a multiple of continuation ({continuation} secs)"
                )

            # Every model call generates chunk_size seconds in total
            # (prompt audio included for the continuation calls).
            generation_params["duration"] = chunk_size
            self.model.set_generation_params(**generation_params)

            logger.info(
                "Generating total audio %s secs with chunks of %s secs "
                "and continuation of %s secs.",
                final_duration,
                chunk_size,
                continuation,
            )

            logger.info("Initializing final audio with %s secs of audio.", chunk_size)
            # Pre-allocate the full output buffer and fill it chunk by chunk.
            final_audio = torch.zeros(
                (self.channels, self.sample_rate * final_duration),
                dtype=torch.float,
                device=self.device,
            )

            # First chunk: plain generation from the text prompt; [0] drops
            # the batch dimension of the (1, channels, samples) output.
            final_audio[:, : chunk_size * self.sample_rate] = self.model.generate(
                [prompt], progress=True
            )[0]

            n_hops = (final_duration - chunk_size) // continuation
            for i_hop in range(n_hops):
                logger.info("Generating audio for hop %s", i_hop)

                # Condition on the trailing audio_window seconds of the
                # audio generated so far.
                prompt_stop = chunk_size + i_hop * continuation
                prompt_start = prompt_stop - audio_window

                audio_prompt = final_audio[
                    :, prompt_start * self.sample_rate : prompt_stop * self.sample_rate
                ].reshape(1, self.channels, -1)

                output = self.model.generate_continuation(
                    audio_prompt,
                    self.sample_rate,
                    [prompt],
                    progress=True,
                )

                # Keep only the newly generated tail: the first audio_window
                # seconds of the output reproduce the prompt audio.
                final_audio[
                    :,
                    prompt_stop
                    * self.sample_rate : (prompt_stop + continuation)
                    * self.sample_rate,
                ] = output[0, :, audio_window * self.sample_rate :]
                logger.info(
                    "finished generating audio till %s secs.",
                    prompt_stop + continuation,
                )

        # (channels, samples) -> (samples, channels) for the response payload.
        return {"generated_audio": final_audio.cpu().numpy().transpose()}
| |
|