| import yaml |
|
|
| from llama_cpp import Llama |
| from fastapi import APIRouter, status |
| from fastapi.responses import JSONResponse |
|
|
| from src.modules.dialog_system import ConversationHandler, MessageRole |
| from src.modules.data_models import UserMessage, AnswerMessage |
|
|
# Shared API router; configuration, the LLM instance, and conversation
# state are attached as router attributes so the handlers below can
# reach them without module-level globals.
router = APIRouter()

# Load service configuration once at import time.
# Explicit encoding avoids locale-dependent decoding of the YAML file.
with open('config.yml', 'r', encoding='utf-8') as file:
    router.config = yaml.safe_load(file)

# Instantiate the llama.cpp model from the configured model path.
# NOTE(review): `max_answer_len` is not a documented `Llama()` constructor
# parameter — confirm against the llama-cpp-python API; the answer-length
# cap is usually passed per generation call as `max_tokens`.
router.llm = Llama(
    model_path=router.config['model_path'],
    n_ctx=int(router.config['context_tokens']),
    max_answer_len=int(router.config['max_answer_tokens'])
)

# Conversation manager that keeps the chat history for the endpoint below.
router.conversation = ConversationHandler(
    model=router.llm,
    message_role=MessageRole
)
|
|
|
|
# Fix: FastAPI route paths must start with "/"; the original
# "v1/service/status" fails Starlette's path assertion at import time.
@router.get("/v1/service/status", status_code=status.HTTP_200_OK)
async def health() -> AnswerMessage:
    """Liveness probe: always responds 200 with the message "OK"."""
    return AnswerMessage(message="OK")
|
|
|
|
# Fixes: (1) route path must start with "/" (Starlette asserts this at
# decoration time); (2) the endpoint reads a request body (UserMessage),
# and GET requests with bodies are unsupported by many clients/proxies —
# a chat-completions endpoint is conventionally POST.
@router.post("/v1/chat/completions", response_model=AnswerMessage)
async def chat_completions(user_message: UserMessage) -> AnswerMessage:
    """Send the user's prompt to the conversation and return the reply.

    On any generation failure, responds with HTTP 500 and a JSON body of
    the form {"message": "<error>"} (shape preserved from the original).
    """
    try:
        router.conversation.send_message(user_message.prompt)
        response = router.conversation.generate_reply()
        return AnswerMessage(message=response)
    except Exception as e:
        # Boundary handler: surface the failure reason to the client
        # instead of FastAPI's opaque "Internal Server Error".
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"message": str(e)},
        )
|
|