|
|
import shutil |
|
|
from fastapi.responses import FileResponse |
|
|
import asyncio |
|
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks |
|
|
from fastapi.responses import StreamingResponse |
|
|
from pydantic import BaseModel |
|
|
from utils import STT, TTS |
|
|
from data_ingestion import Ingest_Data |
|
|
from RAG import app as rag_app, Ragbot_State, reload_vector_store |
|
|
import os |
|
|
|
|
|
|
|
|
app = FastAPI(title="LangGraph RAG Chatbot", version="1.0") |
|
|
|
|
|
|
|
|
class ChatRequest(BaseModel): |
|
|
query: str |
|
|
thread_id: str = "default_user" |
|
|
use_rag: bool = False |
|
|
use_web: bool = False |
|
|
model_name: str = "gpt" |
|
|
|
|
|
class TTSRequest(BaseModel): |
|
|
text: str |
|
|
voice: str = "en-US-AriaNeural" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
def health_check(): |
|
|
return {"status": "running", "message": "Bot is ready"} |
|
|
|
|
|
@app.post("/upload") |
|
|
async def upload_document( |
|
|
file: UploadFile = File(...), |
|
|
background_tasks: BackgroundTasks = BackgroundTasks() |
|
|
): |
|
|
try: |
|
|
temp_filename = f"temp_{file.filename}" |
|
|
|
|
|
with open(temp_filename, "wb") as buffer: |
|
|
shutil.copyfileobj(file.file, buffer) |
|
|
|
|
|
def process_and_reload(path): |
|
|
try: |
|
|
result = Ingest_Data(path) |
|
|
print(f"Ingestion Result: {result}") |
|
|
reload_vector_store() |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error processing background task: {e}") |
|
|
finally: |
|
|
if os.path.exists(path): |
|
|
os.remove(path) |
|
|
|
|
|
background_tasks.add_task(process_and_reload, temp_filename) |
|
|
|
|
|
return { |
|
|
"message": "File received. Processing started in background.", |
|
|
"filename": file.filename |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
|
|
|
@app.post("/chat") |
|
|
async def chat_endpoint(request: ChatRequest): |
|
|
""" |
|
|
Standard Chat Endpoint (Non-Streaming). |
|
|
Waits for the LLM to finish and returns the full JSON response. |
|
|
""" |
|
|
try: |
|
|
|
|
|
config = {"configurable": {"thread_id": request.thread_id}} |
|
|
|
|
|
inputs = { |
|
|
"query": request.query, |
|
|
"RAG": request.use_rag, |
|
|
"web_search": request.use_web, |
|
|
"model_name": request.model_name, |
|
|
"context": [], |
|
|
"metadata": [], |
|
|
"web_context": "", |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
result = await rag_app.ainvoke(inputs, config=config) |
|
|
|
|
|
|
|
|
last_message = result['response'][-1] |
|
|
|
|
|
|
|
|
return { |
|
|
"response": last_message.content, |
|
|
"thread_id": request.thread_id |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error generation response: {e}") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
|
|
|
|
|
|
@app.post("/stt") |
|
|
async def transcribe_audio(file: UploadFile = File(...)): |
|
|
try: |
|
|
return await STT(file) |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
|
|
|
@app.post("/tts") |
|
|
async def text_to_speech(req: TTSRequest): |
|
|
try: |
|
|
audio_path = await TTS(req.text, req.voice) |
|
|
return FileResponse(audio_path, media_type="audio/mpeg", filename="output.mp3") |
|
|
|
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|