Spaces:
Running
Running
import logging
import os
from typing import Any, Dict, List, Optional, Union

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from services.retrieval_service import RetrievalServiceVisual
from utils.image_utils import image_to_base64
from utils.utils import load_yaml
# ASGI application object; it is the app served by ``uvicorn.run`` in the
# ``__main__`` block at the bottom of this file.
app = FastAPI(title="Retrieval Server")
class SearchRequest(BaseModel):
    """Request body for a text query search over the image collection."""

    # Free-text query passed to ``retrieval_service.search_images``.
    query: str
    # Number of results requested; defaults to 5.
    top_k: int = 5
class SearchResponse(BaseModel):
    """Response body returned by the search handler."""

    # Result images, each converted with ``image_to_base64``.
    images: List[str]
    # Path of each returned image (presumably aligned index-by-index with ``images``).
    image_paths: List[str]
    # Score for each returned image (presumably aligned with ``images``).
    scores: List[float]
    # True when the search completed.
    success: bool
    # Human-readable status text.
    message: str
class ProcessApplyFeedbackRequest(BaseModel):
    """Request body for re-running retrieval with user feedback applied."""

    # Query text the feedback refers to.
    query: str
    # Number of results requested (required, no default).
    top_k: int
    # Paths of the images the user selected as relevant.
    relevant_image_paths: List[str]
    # Annotator output forwarded verbatim to the service; the schema is not
    # visible here (presumably box annotations per image — confirm).
    annotator_json_boxes_list: List[Any]
    # Forwarded as ``fuse_initial_query`` to ``process_and_apply_feedback``.
    fuse_initial_query: bool = False
class ProcessApplyFeedbackResponse(BaseModel):
    """Response body returned by the feedback handler."""

    # Result images, each converted with ``image_to_base64``.
    images: List[str]
    # Path of each returned image (presumably aligned index-by-index with ``images``).
    image_paths: List[str]
    # Score for each returned image (presumably aligned with ``images``).
    scores: List[float]
    # True when the feedback was applied.
    success: bool
    # Human-readable status text.
    message: str
# Process-wide retrieval service. It is assigned by ``startup_event`` and read by
# the request handlers, which cannot serve results while it is still None.
retrieval_service: Optional[RetrievalServiceVisual] = None
async def startup_event():
    """Create the module-level retrieval service from a YAML config file.

    The config path comes from the ``CONFIG_PATH`` environment variable and
    falls back to ``configs/demo/coco_clip_large.yaml``. The service is put
    on ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise on
    ``"cpu"``. The instance is stored in the global ``retrieval_service``.
    """
    # NOTE(review): no startup registration (e.g. ``@app.on_event("startup")``
    # or a lifespan hook) precedes this function here. Without one it never
    # runs and ``retrieval_service`` stays None — confirm.
    global retrieval_service
    cfg = load_yaml(os.getenv("CONFIG_PATH", "configs/demo/coco_clip_large.yaml"))
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    retrieval_service = RetrievalServiceVisual(config=cfg, device=target_device)
async def search_images(request: SearchRequest):
    """Search the image collection with a text query.

    Args:
        request: The query text and the number of results (``top_k``).

    Returns:
        SearchResponse with the result images (each passed through
        ``image_to_base64``), their paths, their scores and a success message.

    Raises:
        HTTPException: 503 if ``retrieval_service`` has not been initialized;
            500 if searching, encoding images or building the response fails.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) precedes this
    # function here, so FastAPI will not expose it unless it is registered
    # elsewhere — confirm.
    if retrieval_service is None:
        # Raised outside the ``try`` so the generic handler below does not
        # re-wrap it into an opaque 500 about ``NoneType``.
        raise HTTPException(status_code=503, detail="Retrieval service is not initialized")
    try:
        images, scores, image_paths = retrieval_service.search_images(request.query, request.top_k)
        encoded_images = [image_to_base64(img) for img in images]
        return SearchResponse(
            images=encoded_images,
            image_paths=image_paths,
            scores=scores,
            success=True,
            message="Search completed successfully",
        )
    except Exception as e:
        # HTTPException responses are not logged by the server, so record the
        # traceback here before converting the error to a 500.
        logging.getLogger(__name__).exception("search_images failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def apply_feedback(request: ProcessApplyFeedbackRequest):
    """Re-run retrieval with the user's feedback applied.

    Every request field is forwarded by keyword to
    ``retrieval_service.process_and_apply_feedback``. Each returned image is
    passed through ``image_to_base64``.

    Args:
        request: The original query, the result count, the image paths marked
            as relevant, the annotator data for them, and whether to fuse the
            initial query.

    Returns:
        ProcessApplyFeedbackResponse with images, paths, scores and a
        success message.

    Raises:
        HTTPException: 503 if ``retrieval_service`` has not been initialized;
            500 if applying feedback, encoding images or building the response
            fails.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) precedes this
    # function here, so FastAPI will not expose it unless it is registered
    # elsewhere — confirm.
    if retrieval_service is None:
        # Raised outside the ``try`` so the generic handler below does not
        # re-wrap it into an opaque 500 about ``NoneType``.
        raise HTTPException(status_code=503, detail="Retrieval service is not initialized")
    try:
        images, scores, image_paths = retrieval_service.process_and_apply_feedback(
            query=request.query,
            top_k=request.top_k,
            relevant_image_paths=request.relevant_image_paths,
            annotator_json_boxes_list=request.annotator_json_boxes_list,
            fuse_initial_query=request.fuse_initial_query,
        )
        encoded_images = [image_to_base64(img) for img in images]
        return ProcessApplyFeedbackResponse(
            images=encoded_images,
            image_paths=image_paths,
            scores=scores,
            success=True,
            message="Feedback applied successfully",
        )
    except Exception as e:
        # HTTPException responses are not logged by the server, so record the
        # traceback here before converting the error to a 500.
        logging.getLogger(__name__).exception("apply_feedback failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def health_check():
    """Return a static liveness status plus whether torch can see a CUDA device."""
    cuda_visible = torch.cuda.is_available()
    return dict(status="healthy", gpu_available=cuda_visible)
if __name__ == "__main__":
    # Script entry point: serve ``app`` with uvicorn, bound to 0.0.0.0 on the
    # port given by --port (default 8000).
    import argparse

    import uvicorn

    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=8000)
    uvicorn.run(app, host="0.0.0.0", port=cli.parse_args().port)