| import os |
| import wave |
| import torch |
| import string |
| import random |
| import uvicorn |
| import numpy as np |
| from io import BytesIO |
| from TTS.api import TTS |
| from fastapi import FastAPI, UploadFile |
| from scipy.io.wavfile import write |
| from fastapi.responses import Response, JSONResponse |
|
|
# Scratch directory for uploaded speaker samples; created at import time so
# the request handler can write into it without checking first.
os.makedirs("temp/", exist_ok=True)


# Run on the GPU when torch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}", flush=True)
|
|
# Locations of the model checkpoint and its configuration file, relative to
# the working directory the server is started from.
MODEL_PATH = "models/best_model.pth"
CONFIG_PATH = "models/config.json"


print(f"Loading model", flush=True)
# Build the synthesizer from the files above, then move it onto the selected
# device. The result is shared by every request.
tts = TTS(model_path=MODEL_PATH, config_path=CONFIG_PATH, progress_bar=False)
tts = tts.to(device)
|
|
# Accepted language names (lower-case) mapped to the language code that is
# passed to the model.
languageCODE = {
    "bhojpuri": "bho",
    "bengali": "bn",
    "english": "en",
    "gujarati": "gu",
    "hindi": "hi",
    "chhattisgarhi": "hne",
    "kannada": "kn",
    "magahi": "mag",
    "maithili": "mai",
    "marathi": "mr",
    "telugu": "te",
}


# Sample rate (Hz) written into the WAV header of every synthesized response.
# NOTE(review): assumed to match the model's output rate -- confirm against
# the model config.
sample_rate = 22050
|
|
app = FastAPI()


@app.get("/")
def Is_alive():
    """Liveness probe: return a fixed JSON message showing the server is up."""
    status = {"message" : "Server is Live"}
    return status
|
|
def _resolve_language_code(lang):
    """Return the model language code for a language name or code, else None."""
    key = lang.lower()
    if key in languageCODE:
        return languageCODE[key]
    # Callers may also pass the code itself (e.g. "hi"); accept it unchanged.
    if key in languageCODE.values():
        return key
    return None


# NOTE(review): this route is registered for GET but takes a file upload,
# which clients send as a multipart request body. Confirm callers can send a
# body with GET; otherwise this should additionally be exposed as POST.
@app.get("/Get_Inference")
async def Inference(text : str, lang : str, speaker_wav : UploadFile):
    """Synthesize ``text`` using the uploaded speaker sample and return WAV bytes.

    Args:
        text: Text to synthesize.
        lang: Language name (e.g. ``"hindi"``) or language code (e.g.
            ``"hi"``) listed in ``languageCODE``; matched case-insensitively.
        speaker_wav: Uploaded reference recording. It must be a file the
            standard-library ``wave`` module can open.

    Returns:
        A ``Response`` whose body is a 16-bit PCM WAV file written at
        ``sample_rate``, or a ``JSONResponse`` with status 422 and a
        ``"comment"`` field when a field is missing, the language is unknown,
        or the upload is not a readable WAV file.
    """
    if not text or not lang or not speaker_wav:
        return JSONResponse({"comment" : "Missing Field."}, status_code = 422)

    language_code = _resolve_language_code(lang)
    if language_code is None:
        return JSONResponse({"comment" : "Language not present in the system."}, status_code = 422)

    random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=5))

    # The filename comes from the client: keep only its final path component
    # so a name such as "../../x" cannot escape the temp directory, and fall
    # back to a fixed name when none was sent.
    client_name = os.path.basename(speaker_wav.filename or "") or "speaker.wav"
    speaker_wav_filename = os.path.join("temp", random_string + "_" + client_name)

    try:
        # Write mode is required here; UploadFile.read() is the awaitable
        # API (the underlying .file.read() returns plain bytes).
        with open(speaker_wav_filename, "wb") as wav_file:
            wav_file.write(await speaker_wav.read())

        # Reject anything the wave module cannot parse before handing it to
        # the model. EOFError covers empty or truncated uploads.
        try:
            with wave.open(speaker_wav_filename, "rb"):
                pass
        except (wave.Error, EOFError):
            return JSONResponse({"comment" : "Audio file format not supported."}, status_code = 422)

        # NOTE(review): tts.tts() is a synchronous call inside an async
        # handler; presumably it blocks the event loop for the duration of
        # synthesis -- confirm that is acceptable for the expected load.
        wav = np.array(tts.tts(text=text, speaker_wav=speaker_wav_filename, language=language_code))
    finally:
        # Always remove the temporary speaker file, including on the early
        # return above and when synthesis raises.
        if os.path.exists(speaker_wav_filename):
            os.remove(speaker_wav_filename)

    # Scale the peak to the int16 range; the 0.01 floor avoids dividing by
    # (nearly) zero for silent output.
    wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
    wav_norm = wav_norm.astype(np.int16)

    wav_buffer = BytesIO()
    write(wav_buffer, sample_rate, wav_norm)

    return Response(wav_buffer.getvalue(), media_type="audio/wav")
|
|
|
|
def start_server():
    """Serve ``app`` with uvicorn on 0.0.0.0:8080 at debug log level."""
    print('Starting Server...')

    # Pass the app object rather than the "API_Main:app" import string. When
    # this file is run as a script it executes as ``__main__``, so the import
    # string would make uvicorn import the module a second time and load the
    # model twice. It would also fail if the file is not named API_Main.py.
    # An app instance is supported because reload is off and no workers are
    # requested.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
        log_level="debug",
        reload=False,
    )


if __name__ == "__main__":
    start_server()