# visualref_docker/server/retrieval_server_visual.py
# Last commit: 5ae5072 (unverified) by bulatkh:
#   "Recsys demo based on VLMs + visual embeddings (#4)"
# File size: 3.15 kB
import logging
import os
from typing import Any, Dict, List, Optional, Union

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from services.retrieval_service import RetrievalServiceVisual
from utils.image_utils import image_to_base64
from utils.utils import load_yaml
# FastAPI application object; the startup hook and the /search, /apply_feedback
# and /health routes in this module are registered on it via decorators.
app = FastAPI(title="Retrieval Server")
class SearchRequest(BaseModel):
    """Request body for ``POST /search``.

    Both fields are forwarded unchanged to ``RetrievalServiceVisual.search_images``.
    """

    # Free-text query to search with.
    query: str
    # Passed as the second argument to search_images (presumably the number of
    # hits to return — confirm in the service); defaults to 5 when omitted.
    top_k: int = 5
class SearchResponse(BaseModel):
    """Response body for ``POST /search``."""

    # Result images, each converted to a string with utils.image_utils.image_to_base64.
    images: List[str]
    # Image paths as returned by the retrieval service (presumably aligned
    # index-wise with ``images`` — confirm in the service).
    image_paths: List[str]
    # Scores as returned by the retrieval service.
    scores: List[float]
    # The endpoint always sets this to True; failures are raised as HTTPException instead.
    success: bool
    # Human-readable status message.
    message: str
class ProcessApplyFeedbackRequest(BaseModel):
    """Request body for ``POST /apply_feedback``.

    Every field is forwarded by keyword to
    ``RetrievalServiceVisual.process_and_apply_feedback``.
    """

    # Text query forwarded as ``query``.
    query: str
    # Required here (no default), unlike SearchRequest.top_k.
    top_k: int
    # Paths of images the user marked as relevant (presumably taken from a
    # previous SearchResponse.image_paths — verify against the client).
    relevant_image_paths: List[str]
    # Annotator box data; elements are not validated here (List[Any]), so their
    # structure is defined by the client and the service.
    annotator_json_boxes_list: List[Any]
    # Forwarded as-is; presumably controls whether the initial query is fused
    # with the feedback signal — confirm in the service. Defaults to False.
    fuse_initial_query: bool = False
class ProcessApplyFeedbackResponse(SearchResponse):
    """Response body for ``POST /apply_feedback``.

    The payload has exactly the same fields as ``SearchResponse``: ``images``
    (base64 strings), ``image_paths``, ``scores``, ``success`` and ``message``.
    Inheriting them, instead of redeclaring each one, keeps the two response
    schemas from drifting apart.
    """
# Process-wide service instance. startup_event assigns it, and the /search and
# /apply_feedback handlers read it. It stays None until the startup hook has run.
retrieval_service: Optional[RetrievalServiceVisual] = None
@app.on_event("startup")
async def startup_event():
    """Create the module-wide ``RetrievalServiceVisual`` when the app starts.

    The YAML config path comes from the ``CONFIG_PATH`` environment variable and
    falls back to ``configs/demo/coco_clip_large.yaml``. The service is placed on
    ``"cuda"`` when torch reports a CUDA device, otherwise on ``"cpu"``.
    """
    global retrieval_service
    yaml_path = os.getenv("CONFIG_PATH", "configs/demo/coco_clip_large.yaml")
    service_config = load_yaml(yaml_path)
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    retrieval_service = RetrievalServiceVisual(config=service_config, device=target_device)
@app.post("/search", response_model=SearchResponse)
async def search_images(request: SearchRequest):
    """Run a text query through the retrieval service and return its hits.

    Args:
        request: ``request.query`` and ``request.top_k`` are forwarded unchanged
            to ``RetrievalServiceVisual.search_images``.

    Returns:
        SearchResponse whose ``images`` are the service's images converted with
        ``image_to_base64``, together with the service's ``image_paths`` and
        ``scores``.

    Raises:
        HTTPException: 503 if the retrieval service has not been initialized
            (the startup hook has not run). 500 for any error raised while
            searching, encoding images or building the response; the traceback
            is logged server-side and only ``str(error)`` is sent to the client.
    """
    if retrieval_service is None:
        # Checked outside the try so this 503 is not re-wrapped as a 500 below.
        raise HTTPException(status_code=503, detail="Retrieval service is not initialized")
    # NOTE(review): search_images is a synchronous call inside an async handler,
    # so it runs on the event-loop thread and blocks other requests (including
    # /health) until it returns. Consider a threadpool if the service is thread-safe.
    try:
        images, scores, image_paths = retrieval_service.search_images(request.query, request.top_k)
        encoded_images = [image_to_base64(img) for img in images]
        return SearchResponse(
            images=encoded_images,
            image_paths=image_paths,
            scores=scores,
            success=True,
            message="Search completed successfully",
        )
    except Exception as e:
        # Boundary handler: keep the full traceback in the server logs, because
        # the HTTP error below carries only the exception message.
        logging.getLogger(__name__).exception("Search request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/apply_feedback", response_model=ProcessApplyFeedbackResponse)
async def apply_feedback(request: ProcessApplyFeedbackRequest):
    """Apply user relevance feedback through the retrieval service and return its hits.

    Args:
        request: ``query``, ``top_k``, ``relevant_image_paths``,
            ``annotator_json_boxes_list`` and ``fuse_initial_query`` are forwarded
            by keyword to ``RetrievalServiceVisual.process_and_apply_feedback``.

    Returns:
        ProcessApplyFeedbackResponse whose ``images`` are the service's images
        converted with ``image_to_base64``, together with the service's
        ``image_paths`` and ``scores``.

    Raises:
        HTTPException: 503 if the retrieval service has not been initialized
            (the startup hook has not run). 500 for any error raised while
            applying feedback, encoding images or building the response; the
            traceback is logged server-side and only ``str(error)`` is sent to
            the client.
    """
    if retrieval_service is None:
        # Checked outside the try so this 503 is not re-wrapped as a 500 below.
        raise HTTPException(status_code=503, detail="Retrieval service is not initialized")
    # NOTE(review): process_and_apply_feedback is a synchronous call inside an
    # async handler, so it blocks the event loop until it returns. Consider a
    # threadpool if the service is thread-safe.
    try:
        images, scores, image_paths = retrieval_service.process_and_apply_feedback(
            query=request.query,
            top_k=request.top_k,
            relevant_image_paths=request.relevant_image_paths,
            annotator_json_boxes_list=request.annotator_json_boxes_list,
            fuse_initial_query=request.fuse_initial_query,
        )
        encoded_images = [image_to_base64(img) for img in images]
        return ProcessApplyFeedbackResponse(
            images=encoded_images,
            image_paths=image_paths,
            scores=scores,
            success=True,
            message="Feedback applied successfully",
        )
    except Exception as e:
        # Boundary handler: keep the full traceback in the server logs, because
        # the HTTP error below carries only the exception message.
        logging.getLogger(__name__).exception("Apply-feedback request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Liveness probe: always reports "healthy" plus whether torch sees a CUDA device."""
    payload = {"status": "healthy"}
    payload["gpu_available"] = torch.cuda.is_available()
    return payload
if __name__ == "__main__":
    # Script entry point: serve this module's FastAPI app with uvicorn.
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(description="Serve the visual retrieval API with uvicorn.")
    # The bind address was previously hard-coded. It is now configurable, and the
    # default keeps the old behavior of listening on all interfaces.
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Interface to bind (default: 0.0.0.0, all interfaces).",
    )
    parser.add_argument("--port", type=int, default=8000, help="Port to listen on (default: 8000).")
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)