| | import argparse |
| | import gc |
| | import logging |
| | import queue |
| | import socket |
| | import struct |
| | import threading |
| | import traceback |
| | import wave |
| | from importlib.resources import files |
| |
|
| | import numpy as np |
| | import torch |
| | import torchaudio |
| | from huggingface_hub import hf_hub_download |
| | from hydra.utils import get_class |
| | from omegaconf import OmegaConf |
| |
|
| | from f5_tts.infer.utils_infer import ( |
| | chunk_text, |
| | infer_batch_process, |
| | load_model, |
| | load_vocoder, |
| | preprocess_ref_audio_text, |
| | ) |
| |
|
| |
|
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class AudioFileWriterThread(threading.Thread):
    """Threaded WAV writer so disk I/O never blocks the TTS streaming loop.

    Chunks of float audio (expected in [-1.0, 1.0]) are queued via
    ``add_chunk`` and written as 16-bit mono PCM by this worker thread.
    """

    def __init__(self, output_file, sampling_rate):
        """
        Args:
            output_file: Path of the WAV file to create/overwrite.
            sampling_rate: Sample rate (Hz) written to the WAV header.
        """
        super().__init__()
        self.output_file = output_file
        self.sampling_rate = sampling_rate
        self.queue = queue.Queue()
        self.stop_event = threading.Event()
        # Int16 chunks retained after writing; note this grows without bound
        # for very long streams.
        self.audio_data = []

    def run(self):
        """Drain the queue and append each chunk to the WAV file."""
        logging.getLogger(__name__).info("AudioFileWriterThread started.")
        with wave.open(self.output_file, "wb") as wf:
            wf.setnchannels(1)  # mono
            wf.setsampwidth(2)  # 16-bit PCM
            wf.setframerate(self.sampling_rate)

            # Keep draining after stop() is requested until the queue is
            # empty, so no queued audio is lost.
            while not self.stop_event.is_set() or not self.queue.empty():
                try:
                    chunk = self.queue.get(timeout=0.1)
                    if chunk is not None:
                        # Clip BEFORE the int16 cast: samples outside
                        # [-1, 1] would otherwise wrap around (integer
                        # overflow) and produce loud clicks in the file.
                        chunk = np.clip(np.asarray(chunk, dtype=np.float32), -1.0, 1.0)
                        chunk = np.int16(chunk * 32767)
                        self.audio_data.append(chunk)
                        wf.writeframes(chunk.tobytes())
                except queue.Empty:
                    continue

    def add_chunk(self, chunk):
        """Queue a chunk of float samples for asynchronous writing."""
        self.queue.put(chunk)

    def stop(self):
        """Signal the writer to finish, then block until all data is flushed."""
        self.stop_event.set()
        self.join()
        logging.getLogger(__name__).info("Audio writing completed.")
| |
|
| |
|
class TTSStreamingProcessor:
    """Loads an F5-TTS model once and streams synthesized speech over a socket.

    One instance is shared across client connections. ``generate_stream``
    writes raw native-endian float32 samples to the connection (terminated by
    a literal ``b"END"``) and mirrors the audio to a local ``output.wav`` via
    AudioFileWriterThread.
    """

    def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
        # Pick the best available accelerator unless the caller forces one.
        self.device = device or (
            "cuda"
            if torch.cuda.is_available()
            else "xpu"
            if torch.xpu.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        # Model architecture/mel-spec settings from the config shipped inside
        # the f5_tts package (configs/<model>.yaml).
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        self.model_arc = model_cfg.model.arch
        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate

        self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
        self.vocoder = self.load_vocoder_model()

        self.update_reference(ref_audio, ref_text)
        self._warm_up()
        # WAV mirror thread; created per stream in generate_stream().
        self.file_writer_thread = None
        # True until the first request is served; the first package is split
        # into smaller chunks to reduce time-to-first-audio.
        self.first_package = True

    def load_ema_model(self, ckpt_file, vocab_file, dtype):
        """Load the EMA checkpoint and move it to the target device/dtype."""
        return load_model(
            self.model_cls,
            self.model_arc,
            ckpt_path=ckpt_file,
            mel_spec_type=self.mel_spec_type,
            vocab_file=vocab_file,
            ode_method="euler",
            use_ema=True,
            device=self.device,
        ).to(self.device, dtype=dtype)

    def load_vocoder_model(self):
        """Load the vocoder matching the configured mel-spec type (remote weights)."""
        return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device)

    def update_reference(self, ref_audio, ref_text):
        """Set the reference voice/text and derive per-request text-chunk budgets."""
        self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
        self.audio, self.sr = torchaudio.load(self.ref_audio)

        ref_audio_duration = self.audio.shape[-1] / self.sr
        ref_text_byte_len = len(self.ref_text.encode("utf-8"))
        # Scale the text budget by the reference speech rate (bytes/second),
        # budgeting so reference + generated audio stays within ~25 seconds.
        # NOTE(review): assumes ref_audio_duration < 25 s; a longer reference
        # makes these budgets non-positive — confirm upstream clipping.
        self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration))
        self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2)
        self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4)

    def _warm_up(self):
        """Run one dummy inference so the first client request avoids lazy-init cost."""
        logger.info("Warming up the model...")
        gen_text = "Warm-up text for the model."
        # Streaming mode yields chunks; drain them without using the output.
        for _ in infer_batch_process(
            (self.audio, self.sr),
            self.ref_text,
            [gen_text],
            self.model,
            self.vocoder,
            progress=None,
            device=self.device,
            streaming=True,
        ):
            pass
        logger.info("Warm-up completed.")

    def generate_stream(self, text, conn):
        """Synthesize ``text`` and stream the audio samples to ``conn``.

        Each chunk is sent as packed native-endian 32-bit floats, followed by
        a literal ``b"END"`` marker once the stream finishes; the same chunks
        are mirrored to ``output.wav``.
        """
        text_batches = chunk_text(text, max_chars=self.max_chars)
        if self.first_package:
            # Re-split the first batch twice into progressively smaller pieces
            # so the very first audio reaches the client sooner.
            text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:]
            text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:]
            self.first_package = False

        audio_stream = infer_batch_process(
            (self.audio, self.sr),
            self.ref_text,
            text_batches,
            self.model,
            self.vocoder,
            progress=None,
            device=self.device,
            streaming=True,
            chunk_size=2048,
        )

        # Restart the file writer for each request so output.wav holds only
        # the most recent stream.
        if self.file_writer_thread is not None:
            self.file_writer_thread.stop()
        self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate)
        self.file_writer_thread.start()

        for audio_chunk, _ in audio_stream:
            if len(audio_chunk) > 0:
                logger.info(f"Generated audio chunk of size: {len(audio_chunk)}")

                # Send the chunk to the client as raw packed floats.
                conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk))

                # Mirror to disk asynchronously so sending is never blocked.
                self.file_writer_thread.add_chunk(audio_chunk)

        logger.info("Finished sending audio stream.")
        conn.sendall(b"END")

        # Ensure all queued audio is flushed to disk before returning.
        self.file_writer_thread.stop()
| |
|
| |
|
def handle_client(conn, processor):
    """Serve one client connection: read UTF-8 text requests, stream TTS audio back."""
    log = logging.getLogger(__name__)
    try:
        with conn:
            # Disable Nagle's algorithm so small audio packets go out immediately.
            conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            while True:
                payload = conn.recv(1024)
                if not payload:
                    # Client hung up: the next connection starts a fresh
                    # utterance, so re-enable first-package chunking.
                    processor.first_package = True
                    break
                data_str = payload.decode("utf-8").strip()
                log.info(f"Received text: {data_str}")

                try:
                    processor.generate_stream(data_str, conn)
                except Exception as inner_e:
                    log.error(f"Error during processing: {inner_e}")
                    traceback.print_exc()
                    break
    except Exception as e:
        log.error(f"Error handling client: {e}")
        traceback.print_exc()
| |
|
| |
|
def start_server(host, port, processor):
    """Accept TCP clients forever, serving them one at a time (blocking)."""
    log = logging.getLogger(__name__)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
        server_sock.bind((host, port))
        server_sock.listen()
        log.info(f"Server started on {host}:{port}")
        # Sequential accept loop: each client is fully served before the next
        # connection is accepted.
        while True:
            conn, addr = server_sock.accept()
            log.info(f"Connected by {addr}")
            handle_client(conn, processor)
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--host", default="0.0.0.0")
    # type=int: without it a CLI-supplied --port arrives as a string and
    # socket.bind() rejects it; only the int default previously worked.
    parser.add_argument("--port", default=9998, type=int)

    parser.add_argument(
        "--model",
        default="F5TTS_v1_Base",
        help="The model name, e.g. F5TTS_v1_Base",
    )
    parser.add_argument(
        "--ckpt_file",
        default=None,
        help="Path to the model checkpoint file (default checkpoint is fetched from the HF hub if omitted)",
    )
    parser.add_argument(
        "--vocab_file",
        default="",
        help="Path to the vocab file if customized",
    )

    parser.add_argument(
        "--ref_audio",
        default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        help="Reference audio to provide model with speaker characteristics",
    )
    parser.add_argument(
        "--ref_text",
        default="",
        help="Reference audio subtitle, leave empty to auto-transcribe",
    )

    parser.add_argument("--device", default=None, help="Device to run the model on")
    parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference")

    args = parser.parse_args()

    # Resolve the default checkpoint lazily: the old eager default= triggered
    # a hub download at parser-construction time even when --ckpt_file was
    # supplied (and even for --help).
    if args.ckpt_file is None:
        args.ckpt_file = str(
            hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")
        )

    # A CLI-supplied --dtype arrives as a string (e.g. "float16"); map it to
    # the corresponding torch dtype. The programmatic default is already one.
    dtype = getattr(torch, args.dtype) if isinstance(args.dtype, str) else args.dtype

    try:
        # Load models and reference audio once; reuse for every client.
        processor = TTSStreamingProcessor(
            model=args.model,
            ckpt_file=args.ckpt_file,
            vocab_file=args.vocab_file,
            ref_audio=args.ref_audio,
            ref_text=args.ref_text,
            device=args.device,
            dtype=dtype,
        )

        # Serve clients until interrupted (Ctrl+C).
        start_server(args.host, args.port, processor)

    except KeyboardInterrupt:
        gc.collect()
|