Spaces:
Running
Running
| import asyncio | |
| import logging | |
| from typing import Dict, Any | |
| from fastapi import HTTPException, UploadFile, status, Depends | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from config import Config | |
| from .rag_pipeline import route_and_process_query, add_document_to_rag, check_system_health | |
| from .document_handler import extract_text_from_file | |
# Configure logging.
# NOTE(review): basicConfig() runs at import time, so importing this module
# configures the root logger as a side effect -- confirm this is intended
# rather than being done once in the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Bearer-credential extractor; used as the FastAPI dependency of verify_token.
security = HTTPBearer()
# Supported file types: upload content types are checked for membership in
# this collection, and it is joined into the 415 error message.
SUPPORTED_CONTENT_TYPES = Config.RAG_SUPPORTED_CONTENT_TYPES
# Upload size limit in bytes (compared against len() of the raw upload).
MAX_FILE_SIZE = Config.RAG_MAX_FILE_SIZE
# Query length limit in characters (compared against len() of the query str).
MAX_QUERY_LENGTH = Config.RAG_MAX_QUERY_LENGTH
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify the Bearer token from the Authorization header.

    Args:
        credentials: Bearer credentials extracted by the ``security``
            (``HTTPBearer``) dependency.

    Returns:
        The presented token string when it matches ``Config.SECRET_TOKEN``.

    Raises:
        HTTPException: 500 when no expected token is configured on the
            server; 403 when the presented token does not match.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    token = credentials.credentials
    expected_token = Config.SECRET_TOKEN

    if not expected_token:
        logger.error("MY_SECRET_TOKEN not configured")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error",
        )

    # Compare in constant time so response latency does not reveal how many
    # leading characters of the secret were guessed correctly. Both sides are
    # encoded to bytes because compare_digest() rejects non-ASCII str input.
    # Assumes Config.SECRET_TOKEN is a str -- TODO confirm.
    tokens_match = hmac.compare_digest(
        token.encode("utf-8"),
        expected_token.encode("utf-8"),
    )
    if not tokens_match:
        # Deliberately log no part of the presented credential: a near-miss
        # or mistyped real token would otherwise leak into the logs.
        logger.warning("Invalid token attempt")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or expired token",
        )

    return token
async def handle_rag_query(query: str) -> Dict[str, Any]:
    """Handle an incoming query by routing it and getting the appropriate answer.

    Args:
        query: The user's query text.

    Returns:
        The response produced by ``route_and_process_query``.

    Raises:
        HTTPException: 400 when the query is empty/whitespace-only or longer
            than ``MAX_QUERY_LENGTH``; 500 when processing fails
            unexpectedly. An ``HTTPException`` raised during processing is
            propagated unchanged.
    """
    # Input validation
    if not query or not query.strip():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Query cannot be empty",
        )

    if len(query) > MAX_QUERY_LENGTH:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Query too long. Please limit to {MAX_QUERY_LENGTH} characters.",
        )

    try:
        logger.info(f"Processing query: {query[:50]}...")

        # The pipeline call is synchronous; run it in a worker thread so the
        # event loop is not blocked while it executes.
        response = await asyncio.to_thread(route_and_process_query, query)

        logger.info(f"Query processed successfully. Route: {response.get('route', 'Unknown')}")
        return response
    except HTTPException:
        # Same passthrough as the upload and health-check handlers: keep a
        # deliberate status code instead of rewriting it to a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback; the client still receives
        # only the generic message below.
        logger.exception(f"Error processing query: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing your query. Please try again.",
        ) from e
async def handle_document_upload(file: UploadFile) -> Dict[str, Any]:
    """Handle uploading a document to the RAG's vector store.

    Validates the upload (filename present, supported content type, size
    within ``MAX_FILE_SIZE``), extracts its text, and adds the text to the
    knowledge base via ``add_document_to_rag``.

    Args:
        file: The uploaded file.

    Returns:
        A dict with the keys ``message``, ``filename``, ``text_length``
        (an int) and ``content_type``.

    Raises:
        HTTPException: 400 for a missing filename, empty/unreadable content,
            or extracted text shorter than 50 characters; 415 for an
            unsupported content type; 413 for an oversized file; 500 when
            the document cannot be stored or processing fails unexpectedly.
    """
    # File validation
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No file provided",
        )

    if file.content_type not in SUPPORTED_CONTENT_TYPES:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"Unsupported file type: {file.content_type}. "
                   f"Supported types: {', '.join(SUPPORTED_CONTENT_TYPES)}",
        )

    # Check file size. Read at most one byte past the limit: that is enough
    # to detect an oversized upload without first pulling an arbitrarily
    # large body into memory. For accepted files the whole content has been
    # read, so len(contents) is still the true file size.
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024*1024):.1f}MB",
        )

    # Reset file pointer so text extraction starts from the beginning.
    await file.seek(0)

    try:
        logger.info(f"Processing file upload: {file.filename}")

        # Extract text from file
        text = await extract_text_from_file(file)

        if not text or not text.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The file appears to be empty or could not be read.",
            )

        if len(text) < 50:  # Too short to be meaningful
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The extracted text is too short to be meaningful.",
            )

        # Add to RAG system; the call is synchronous, so run it in a worker
        # thread to keep the event loop free.
        success = await asyncio.to_thread(
            add_document_to_rag,
            text,
            {
                "source": file.filename,
                "content_type": file.content_type,
                "size": len(contents),
            },
        )

        if not success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to add document to the knowledge base",
            )

        logger.info(f"Successfully processed file: {file.filename}")
        return {
            "message": f"Successfully uploaded and processed '{file.filename}'. "
                       f"It is now available for querying.",
            "filename": file.filename,
            "text_length": len(text),
            "content_type": file.content_type,
        }
    except HTTPException:
        # Validation/storage errors raised above already carry the right
        # status code; let them through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback; the client still receives
        # only the generic message below.
        logger.exception(f"Error processing file {file.filename}: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing the file. Please try again.",
        ) from e
async def handle_health_check() -> Dict[str, Any]:
    """Report system health, answering 503 when the system is not healthy.

    Returns:
        The status mapping returned by ``check_system_health``.

    Raises:
        HTTPException: 503 when the reported status is ``"unhealthy"`` or
            when the health probe itself fails. An ``HTTPException`` raised
            by the probe is propagated unchanged.
    """
    try:
        # The probe is synchronous; run it off the event loop.
        report = await asyncio.to_thread(check_system_health)
        is_unhealthy = report["status"] == "unhealthy"
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Health check failed",
        )

    if is_unhealthy:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service is currently unhealthy",
        )

    return report