| import asyncio |
| import uuid |
|
|
| import uvloop |
| from dotenv import load_dotenv |
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| from sse_starlette.sse import EventSourceResponse, ServerSentEvent |
| from starlette.responses import JSONResponse, Response |
|
|
| from .activity_helpers import DONE |
| from .complete_chat import complete_chat_router |
| from .predict import predict_router |
|
|
# Prefer the installed ``infiagent`` package. If importing it fails, tell the
# operator how to install it and fall back to the ``schemas`` / ``utils``
# modules of the parent package via relative imports.
try:
    import infiagent
    from infiagent.schemas import FailedResponseBaseData
    from infiagent.utils import get_logger, init_logging, log_id_var
except ImportError:
    print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent")
    from ..schemas import FailedResponseBaseData
    from ..utils import get_logger, init_logging, log_id_var
|
|
# Make asyncio create its event loops through uvloop's policy.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Request paths whose error replies are sent as an SSE stream instead of JSON.
SSE_API_PATHS = ["/complete_sse"]

# Header the log id is read from on requests and written to on responses.
LOG_ID_HEADER_NAME = "X-Tt-Logid"

# Order matters: dotenv runs first, then logging is initialised, and only then
# is the module-level logger obtained.
load_dotenv()
init_logging()
logger = get_logger()
|
|
# The ASGI application: CORS middleware first, then the two API routers.
app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive CORS policy -- confirm this is intended for deployment.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(complete_chat_router)
app.include_router(predict_router)
|
|
|
|
@app.middleware("http")
async def log_id_middleware(request: Request, call_next):
    """Bind a log id to the request and echo it back on the response.

    The id comes from the ``LOG_ID_HEADER_NAME`` request header. When that
    header is missing or empty, a random UUID4 string is generated instead.
    The id is stored in ``log_id_var`` before the rest of the stack runs, and
    whatever ``log_id_var`` holds afterwards is written to the same header on
    the outgoing response.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request down the stack and
            returns the response.

    Returns:
        The downstream response with the log id header set.
    """
    incoming_id = request.headers.get(LOG_ID_HEADER_NAME)
    # A missing or empty header falls back to a freshly generated id.
    log_id_var.set(incoming_id or str(uuid.uuid4()))

    response: Response = await call_next(request)
    response.headers[LOG_ID_HEADER_NAME] = log_id_var.get()
    return response
|
|
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Turn any otherwise unhandled exception into a client-facing error reply.

    Args:
        request: The request being processed; only ``request.url.path`` is
            read.
        exc: The exception that escaped the application.

    Returns:
        For paths listed in ``SSE_API_PATHS``, an ``EventSourceResponse``
        whose stream carries a single ``DONE`` event. For every other path, a
        500 ``JSONResponse`` containing the error message and a serialized
        ``FailedResponseBaseData``.
    """
    error_msg = "Failed to handle request. Internal Server error: {}".format(str(exc))
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # EventSourceResponse consumes an iterator of events; a bare
        # ServerSentEvent is not iterable, so wrap the single DONE event.
        return EventSourceResponse(iter([ServerSentEvent(data=DONE)]))

    return JSONResponse(
        status_code=500,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict(),
        },
    )
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Turn an ``HTTPException`` into a client-facing error reply.

    Args:
        request: The request being processed; only ``request.url.path`` is
            read.
        exc: The raised ``HTTPException``; its ``detail``, ``status_code``
            and (when present) ``headers`` are used.

    Returns:
        For paths listed in ``SSE_API_PATHS``, an ``EventSourceResponse``
        whose stream carries a single ``DONE`` event. For every other path, a
        ``JSONResponse`` with the exception's status code and headers,
        containing the error message and a serialized
        ``FailedResponseBaseData``.
    """
    error_msg = "Failed to handle request. Error: {}".format(exc.detail)
    logger.error(error_msg, exc_info=True)

    if request.url.path in SSE_API_PATHS:
        # EventSourceResponse consumes an iterator of events; a bare
        # ServerSentEvent is not iterable, so wrap the single DONE event.
        return EventSourceResponse(iter([ServerSentEvent(data=DONE)]))

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "response": error_msg,
            "ResponseBase": FailedResponseBaseData().dict(),
        },
        # Forward headers attached to the exception (e.g. WWW-Authenticate)
        # instead of silently dropping them.
        headers=getattr(exc, "headers", None),
    )
|
|