from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import aiofiles
import json
from typing import List, Optional
from datetime import datetime

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
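# Note: allow_origins=["*"] accepts requests from any origin; in production
# this is typically restricted to the frontend's actual origin.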

app.mount("/static", StaticFiles(directory="static"), name="static")

# "mistralai/Mistral-8x7B" is not a published checkpoint; the mixture-of-experts
# model on the Hugging Face Hub is "mistralai/Mixtral-8x7B-v0.1".
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
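# Even in float16, an 8x7B model needs substantial GPU memory. A common option
# (not required here) is from_pretrained(..., device_map="auto"), which depends
# on the `accelerate` package and places weights across available devices.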

# In-memory search history (per-process, unbounded, cleared on restart).
search_history = []

class InferenceRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100

class InferenceResponse(BaseModel):
    generated_text: str
    timestamp: str

class SearchHistoryResponse(BaseModel):
    history: List[InferenceResponse]

@app.post("/inference", response_model=InferenceResponse)
async def run_inference(request: InferenceRequest):
    """Run inference using the AI model."""
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt")
        # Pass the attention mask along with the input ids; note that
        # max_length counts prompt tokens plus generated tokens.
        outputs = model.generate(**inputs, max_length=request.max_length)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        search_entry = InferenceResponse(
            generated_text=generated_text,
            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        search_history.append(search_entry)

        logger.info(f"Inference completed for prompt: {request.prompt}")
        return search_entry
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise HTTPException(status_code=500, detail="Failed to run inference.")
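
# Example request against this endpoint (assuming the server listens on
# localhost:8000):
#   curl -X POST http://localhost:8000/inference \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, world", "max_length": 50}'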

@app.get("/search-history")
async def get_search_history():
    """Get the history of all searches."""
    return SearchHistoryResponse(history=search_history)

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Handle inference requests sent as JSON over a WebSocket."""
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            request = json.loads(data)
            prompt = request.get("prompt")
            max_length = request.get("max_length", 100)

            if not prompt:
                await websocket.send_text(json.dumps({"error": "Prompt is required."}))
                continue

            inputs = tokenizer(prompt, return_tensors="pt")
            # Pass the attention mask along with the input ids, as above.
            outputs = model.generate(**inputs, max_length=max_length)
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            response = {
                "generated_text": generated_text,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            await websocket.send_text(json.dumps(response))

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected.")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        # The socket may already be closed, so guard the error reply.
        try:
            await websocket.send_text(json.dumps({"error": str(e)}))
        except RuntimeError:
            logger.info("Could not deliver error to client; connection closed.")
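
# A minimal client sketch (illustrative only; the `websockets` package is an
# assumption, not something this app depends on):
#   import asyncio, json, websockets
#   async def main():
#       async with websockets.connect("ws://localhost:8000/ws") as ws:
#           await ws.send(json.dumps({"prompt": "Hello", "max_length": 50}))
#           print(await ws.recv())
#   asyncio.run(main())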

@app.get("/")
async def serve_frontend():
    """Serve the frontend HTML file."""
    async with aiofiles.open("static/index.html", mode="r") as file:
        # Wrap in HTMLResponse so the browser renders the page rather than
        # receiving a JSON-encoded string.
        return HTMLResponse(content=await file.read())
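
# To run locally (assuming this file is named main.py):
#   uvicorn main:app --host 0.0.0.0 --port 8000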