| |
|
|
import asyncio
import base64
import io
import json
import threading
import time
from typing import AsyncGenerator

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from omnivoice import OmniVoice
|
|
| |
| |
| |
|
|
# The ASGI application object; `title` is the API title FastAPI reports in
# its generated OpenAPI schema.
app = FastAPI(
    title="OmniVoice OpenAI-Compatible TTS",
)
|
|
| |
| |
| |
|
|
# Output audio format: mono, 16-bit samples at 24 kHz.
SAMPLE_RATE = 24000
NUM_CHANNELS = 1
BYTES_PER_SAMPLE = 2

# Duration of audio carried by each streamed chunk, in milliseconds.
FRAME_MS = 20

# Bytes per streamed chunk (960 for the defaults above).
# Integer arithmetic is used on purpose. Computing FRAME_MS / 1000 in floating
# point and truncating with int() can land one byte short for some FRAME_MS
# values. That yields an odd-sized chunk which splits a 16-bit sample across
# two chunks. Flooring the per-frame sample count first guarantees every chunk
# holds only whole sample frames.
CHUNK_SIZE = (SAMPLE_RATE * FRAME_MS // 1000) * BYTES_PER_SAMPLE * NUM_CHANNELS
|
|
| |
| |
| |
|
|
# Voice-cloning inputs shared by every request. All three values are passed
# to model.generate in generate_audio and are not configurable per request.

# Reference recording. This is a relative path, so it is resolved against the
# process's working directory.
FIXED_REF_AUDIO = "ref_audio/women_ref_1.mp3"
# Text passed as ref_text alongside FIXED_REF_AUDIO (presumably its
# transcript; confirm it matches the recording).
FIXED_REF_TEXT = "شوفي يا حلوة هالكريم الجديد للبشرة، يخلي وجهك مثل القمر!"
# Style instruction passed as `instruct`.
FIXED_INSTRUCT = "female, young adult, high pitch"
|
|
| |
| |
| |
|
|
# Load the fine-tuned OmniVoice checkpoint once, at import time, in float16
# on the first CUDA device. The checkpoint path is machine-specific.
# TODO(review): consider making the checkpoint path configurable.
model = OmniVoice.from_pretrained(
    "/home/riftuser/OmniVoice/exp_v1/omnivoice_finetune/checkpoint-5000",
    dtype=torch.float16,
    device_map="cuda:0",
)

# Async lock used by generate_audio so that requests take turns using the
# single shared model.
generation_lock = asyncio.Lock()
|
|
| |
| |
| |
|
|
class SpeechRequest(BaseModel):
    """Request body for POST /v1/audio/speech (OpenAI-style field names)."""

    # Accepted for OpenAI client compatibility; this server does not read it.
    model: str = "omnivoice"
    # Text to synthesize (required).
    input: str
    # Speaking-rate factor forwarded to model.generate as `speed`.
    speed: float = 1.1
    # Encoding of the raw audio stream: "pcm" or "wav".
    response_format: str = "pcm"
    # "sse" selects Server-Sent Events; any other value streams audio bytes.
    stream_format: str = "audio"
|
|
|
|
| |
| |
| |
|
|
def float32_to_pcm16(audio: np.ndarray) -> bytes:
    """Convert floating-point audio to raw signed 16-bit PCM bytes.

    Processing steps:
    - Samples are converted to float32.
    - NaN becomes silence; +inf/-inf become full scale.
    - Samples are clipped to [-1, 1], scaled by 32767 and rounded to the
      nearest integer.

    Args:
        audio: Float samples (any float dtype or array-like), nominally in
            [-1, 1].

    Returns:
        Little-endian int16 bytes, two bytes per element, in C order.
        Empty input gives ``b""``.
    """
    # Scale in float32 even if the model produced float16 output.
    samples = np.asarray(audio, dtype=np.float32)
    # Casting NaN to int16 is undefined behaviour; map it to silence instead.
    samples = np.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)
    samples = np.clip(samples, -1.0, 1.0)
    # Round rather than truncate: truncation biases every sample toward zero.
    # "<i2" pins little-endian order, which raw-PCM clients expect, regardless
    # of the host's native byte order.
    pcm16 = np.rint(samples * 32767.0).astype("<i2")
    return pcm16.tobytes()
|
|
|
|
| |
| |
| |
|
|
# Guards the blocking model.generate() call from inside the worker thread.
# generation_lock alone is not sufficient. If the coroutine awaiting
# asyncio.to_thread is cancelled (e.g. the client disconnects and the
# streaming response is torn down), `async with generation_lock` releases the
# asyncio lock immediately. The worker thread, however, cannot be interrupted
# and keeps running. Without this lock, the next request could start a second
# generate() on the same model/GPU concurrently.
_model_thread_lock = threading.Lock()


async def generate_audio(req: SpeechRequest) -> np.ndarray:
    """Synthesize ``req.input`` with the fixed reference voice.

    Requests queue on ``generation_lock``, so at most one generation is
    normally in flight. The blocking model call runs in a worker thread so
    the event loop stays responsive. ``_model_thread_lock`` keeps generations
    serialized even when an awaiting request is cancelled mid-generation.

    Args:
        req: Request supplying the text (``input``) and ``speed``.

    Returns:
        The first element of ``model.generate``'s output. Callers treat it as
        a float waveform at ``SAMPLE_RATE``; confirm against the OmniVoice
        API.
    """

    def _generate():
        with _model_thread_lock, torch.inference_mode():
            print("*" * 50)
            print("user text : ", req.input)
            print("*" * 50)

            audio = model.generate(
                text=req.input,
                ref_audio=FIXED_REF_AUDIO,
                ref_text=FIXED_REF_TEXT,
                instruct=FIXED_INSTRUCT,
                speed=req.speed,
                num_step=30,
                guidance_scale=2.0,
                t_shift=0.1,
                position_temperature=3,
                layer_penalty_factor=5.0,
            )
            return audio[0]

    async with generation_lock:
        return await asyncio.to_thread(_generate)
|
|
|
|
| |
| |
| |
|
|
async def audio_stream_generator(
    req: SpeechRequest,
) -> AsyncGenerator[bytes, None]:
    """Yield the synthesized audio for ``req`` as raw bytes.

    ``req.response_format`` selects the encoding:
      * ``"pcm"``: headerless 16-bit PCM from ``float32_to_pcm16``, yielded in
        ``CHUNK_SIZE``-byte pieces.
      * ``"wav"``: a complete 16-bit PCM WAV file, yielded in 4096-byte pieces.

    Raises:
        HTTPException: 400 for any other ``response_format``. The check runs
            before generation, so no model time is spent on a request that
            cannot be served.
    """
    if req.response_format not in ("pcm", "wav"):
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported response_format: {req.response_format}"
        )

    audio = await generate_audio(req)

    if req.response_format == "pcm":
        payload = float32_to_pcm16(audio)
        chunk_size = CHUNK_SIZE
    else:
        # Two adjustments before writing:
        # - Cast to float32: soundfile accepts only float32/float64/int16/int32
        #   arrays, and the model may emit float16.
        # - Clip to [-1, 1] to match the PCM path, since out-of-range floats
        #   are not guaranteed to saturate when converted to 16-bit samples.
        samples = np.clip(np.asarray(audio, dtype=np.float32), -1.0, 1.0)
        buffer = io.BytesIO()
        sf.write(
            buffer,
            samples,
            SAMPLE_RATE,
            format="WAV",
            subtype="PCM_16",
        )
        payload = buffer.getvalue()
        chunk_size = 4096

    for start in range(0, len(payload), chunk_size):
        yield payload[start:start + chunk_size]
        # Yield control to the event loop between chunks.
        await asyncio.sleep(0)
|
|
|
|
| |
| |
| |
|
|
async def sse_stream_generator(
    req: SpeechRequest,
) -> AsyncGenerator[str, None]:
    """Stream synthesized speech as OpenAI-style Server-Sent Events.

    Emits, in order:
    - one ``speech.audio.delta`` event per ``CHUNK_SIZE`` bytes of 16-bit PCM
      (base64-encoded);
    - one ``speech.audio.done`` event carrying usage and timing metrics;
    - the ``data: [DONE]`` sentinel.

    Metrics in the done event:
        generation_time_sec: elapsed time of ``generate_audio``, including any
            wait for ``generation_lock``.
        audio_duration_sec: duration of the PCM actually streamed.
        rtf: generation_time_sec / audio_duration_sec, rounded to 4 places,
            or ``None`` when no audio was produced (this previously raised
            ZeroDivisionError).
    """
    # perf_counter is monotonic, so wall-clock adjustments cannot skew timing.
    start = time.perf_counter()
    audio = await generate_audio(req)
    generation_time = time.perf_counter() - start

    pcm_bytes = float32_to_pcm16(audio)

    for offset in range(0, len(pcm_bytes), CHUNK_SIZE):
        chunk = pcm_bytes[offset:offset + CHUNK_SIZE]
        event = {
            "type": "speech.audio.delta",
            "delta": base64.b64encode(chunk).decode("utf-8"),
        }
        yield f"data: {json.dumps(event)}\n\n"
        # Yield control to the event loop between events.
        await asyncio.sleep(0)

    # Derive the duration from the bytes actually streamed, so the metrics
    # always describe what the client received.
    num_frames = len(pcm_bytes) // (BYTES_PER_SAMPLE * NUM_CHANNELS)
    audio_duration = num_frames / SAMPLE_RATE

    usage = {
        # Whitespace-delimited word count used as an approximate token count.
        "input_tokens": len(req.input.split()),
        # Approximation: 50 "tokens" per second of audio.
        "output_tokens": int(audio_duration * 50),
    }

    rtf = round(generation_time / audio_duration, 4) if audio_duration > 0 else None

    done_event = {
        "type": "speech.audio.done",
        "usage": usage,
        "metrics": {
            "generation_time_sec": generation_time,
            "audio_duration_sec": audio_duration,
            "rtf": rtf,
        },
    }

    yield f"data: {json.dumps(done_event)}\n\n"

    yield "data: [DONE]\n\n"
|
|
|
|
| |
| |
| |
|
|
@app.post("/v1/audio/speech")
async def create_speech(req: SpeechRequest):
    """OpenAI-compatible speech endpoint.

    ``stream_format == "sse"`` returns a ``text/event-stream`` of base64 PCM
    deltas. Any other value streams the audio bytes directly: ``audio/pcm``
    for ``response_format="pcm"``, ``audio/wav`` for ``"wav"``.

    Raises:
        HTTPException: 400 when ``input`` is empty or blank, or when a
            non-SSE request asks for an unsupported ``response_format``.
    """
    # Validate before creating the StreamingResponse. Once streaming starts,
    # the 200 status and headers have already been sent, so an error raised
    # inside the body generator can no longer reach the client as a proper 4xx.
    if not req.input.strip():
        raise HTTPException(status_code=400, detail="input must not be empty")

    if req.stream_format == "sse":
        return StreamingResponse(
            sse_stream_generator(req),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

    if req.response_format not in ("pcm", "wav"):
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported response_format: {req.response_format}",
        )

    media_type = "audio/pcm" if req.response_format == "pcm" else "audio/wav"

    return StreamingResponse(
        audio_stream_generator(req),
        media_type=media_type,
    )
|
|
|
|
| |
| |
| |
|
|
@app.get("/health")
async def health():
    """Report liveness, the output sample rate and the fixed voice settings."""
    voice_info = {
        "ref_audio": FIXED_REF_AUDIO,
        "instruct": FIXED_INSTRUCT,
    }
    return {
        "status": "ok",
        "sample_rate": SAMPLE_RATE,
        "voice": voice_info,
    }