Spaces:
Runtime error
Runtime error
| """Middleware for authentication, authorization, and rate limiting.""" | |
| import logging | |
| import time | |
| from functools import wraps | |
| from typing import Dict, Any, Optional | |
| from flask import request, jsonify, current_app, g | |
| from flask_limiter import Limiter | |
| from flask_limiter.util import get_remote_address | |
| import redis | |
| import jwt | |
| from datetime import datetime, timedelta | |
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)


class AuthenticationError(Exception):
    """Error type for authentication failures (e.g. a session token cannot be created)."""


class AuthorizationError(Exception):
    """Error type for authorization (permission) failures."""


class RateLimitError(Exception):
    """Error type for rate-limiting failures."""
def create_limiter(app=None):
    """
    Build the application's rate limiter.

    Requests are keyed on the client's remote address, with default limits of
    200 requests per day and 50 per hour. ``storage_uri`` is passed as
    ``None``; the intent (per the original note) is that the storage backend
    comes from the app configuration.

    Args:
        app: Optional Flask application. When truthy, the limiter is bound to
            it immediately through ``init_app``.

    Returns:
        The configured ``Limiter`` instance.
    """
    limits = ["200 per day", "50 per hour"]
    rate_limiter = Limiter(
        key_func=get_remote_address,
        default_limits=limits,
        storage_uri=None,
    )
    if app:
        rate_limiter.init_app(app)
    return rate_limiter
class SimpleAuthManager:
    """
    Simple session-token manager for development/testing.

    With a Redis client, tokens are random, stored in Redis with a TTL, and
    revocable; Redis is the only source of truth for validity. Without Redis
    the manager runs in an insecure development mode: tokens embed the user id
    in plain text and are validated by parsing alone, so anyone can forge one.
    In production this would be replaced with a proper JWT/OAuth implementation.
    """

    # Prefix shared by every token this manager issues.
    TOKEN_PREFIX = "session_"

    def __init__(self, redis_client: Optional["redis.Redis"] = None):
        """
        Initialize auth manager.

        Args:
            redis_client: Redis connection used to store tokens. ``None``
                enables the insecure, stateless development mode.
        """
        self.redis_client = redis_client
        self.session_prefix = "auth_session:"
        self.user_prefix = "user:"

    def create_session_token(self, user_id: str, expires_in: int = 3600) -> str:
        """
        Create a session token for a user.

        Args:
            user_id: User identifier.
            expires_in: Token lifetime in seconds. Only enforced when Redis is
                configured; development-mode tokens never expire.

        Returns:
            str: Session token.

        Raises:
            AuthenticationError: If the token could not be stored in Redis.
        """
        if self.redis_client is None:
            # Development mode: the token itself carries the user id so it can
            # be validated without storage. NOT secure - trivially forgeable.
            return f"{self.TOKEN_PREFIX}{user_id}_{int(time.time())}"

        import secrets  # only needed for Redis-backed tokens

        # Unguessable and unique per call; the user id lives only in Redis.
        token = f"{self.TOKEN_PREFIX}{secrets.token_urlsafe(32)}"
        try:
            self.redis_client.setex(
                f"{self.session_prefix}{token}",
                expires_in,
                user_id,
            )
        except Exception as e:
            logger.error("Error creating session token: %s", e)
            raise AuthenticationError(f"Failed to create session token: {e}") from e
        return token

    def validate_session_token(self, token: str) -> Optional[str]:
        """
        Validate a session token and return its user id.

        Args:
            token: Session token to validate.

        Returns:
            str: User ID if the token is valid, None otherwise (including on
            any storage error, which is logged).
        """
        try:
            if not token:
                return None
            if self.redis_client is not None:
                # Redis is authoritative: unknown, expired or revoked tokens are
                # rejected instead of falling through to development parsing.
                stored = self.redis_client.get(f"{self.session_prefix}{token}")
                if stored is None:
                    return None
                # Clients created with decode_responses=True already return str.
                if isinstance(stored, bytes):
                    stored = stored.decode('utf-8')
                return stored or None
            return self._parse_dev_token(token)
        except Exception as e:
            logger.error("Error validating session token: %s", e)
            return None

    def _parse_dev_token(self, token: str) -> Optional[str]:
        """Extract the user id from a development token ``session_<user_id>_<timestamp>``."""
        if not token.startswith(self.TOKEN_PREFIX):
            return None
        body = token[len(self.TOKEN_PREFIX):]
        # Split on the LAST underscore so user ids containing '_' survive.
        user_id, sep, stamp = body.rpartition('_')
        if sep and stamp.isdigit():
            return user_id or None
        # No trailing timestamp: treat the whole remainder as the user id.
        return body or None

    def revoke_session_token(self, token: str) -> bool:
        """
        Revoke a session token.

        Args:
            token: Session token to revoke.

        Returns:
            bool: True if the token was deleted from Redis. In development mode
            (no Redis) there is nothing to delete and True is returned, but the
            token itself stays parseable and therefore still valid.
        """
        try:
            if self.redis_client:
                result = self.redis_client.delete(f"{self.session_prefix}{token}")
                return result > 0
            return True  # For development, always return True
        except Exception as e:
            logger.error("Error revoking session token: %s", e)
            return False
def require_auth(f):
    """
    Authentication decorator for API endpoints.

    Accepted credentials, checked in order:

    1. ``Authorization: Bearer <token>`` - validated with ``SimpleAuthManager``
       against the Redis instance at ``current_app.config['REDIS_URL']``.
    2. ``X-User-ID: <user_id>`` - development fallback. The header value is
       trusted as-is, so any client can claim any identity.

    On success the user id is stored on ``g.user_id`` and ``request.user_id``
    before the view runs. Otherwise a 401 JSON response is returned.
    """
    # wraps keeps the view's __name__, so each decorated view gets its own
    # Flask endpoint name instead of all colliding on "decorated_function".
    @wraps(f)
    def decorated_function(*args, **kwargs):
        try:
            user_id = None
            auth_header = request.headers.get('Authorization')
            if auth_header and auth_header.startswith('Bearer '):
                token = auth_header[len('Bearer '):].strip()
                # NOTE(review): builds a new Redis connection pool on every
                # request; consider sharing a single client per app.
                redis_client = redis.from_url(current_app.config['REDIS_URL'])
                auth_manager = SimpleAuthManager(redis_client)
                user_id = auth_manager.validate_session_token(token)
            if not user_id:
                # Fallback to simple header-based auth (for development)
                user_id = request.headers.get('X-User-ID')
        except Exception as e:
            logger.error("Authentication error: %s", e)
            return jsonify({
                'error': 'Authentication failed',
                'message': 'Invalid authentication credentials'
            }), 401

        if not user_id:
            # No valid authentication found
            return jsonify({
                'error': 'Authentication required',
                'message': 'Please provide a valid Authorization header or X-User-ID header'
            }), 401

        g.user_id = user_id
        request.user_id = user_id
        # The view runs outside the try block so its own exceptions (including
        # HTTP errors such as abort(404)) reach Flask's normal error handling
        # instead of being reported as a 401.
        return f(*args, **kwargs)
    return decorated_function
def require_session_ownership(f):
    """
    Authorization decorator ensuring the authenticated user owns the session.

    Must run after ``require_auth`` (i.e. be listed below it in the decorator
    stack) so that ``g.user_id`` is already set. The session id comes from the
    view's ``session_id`` keyword argument or URL variable.

    Responses:
        400 if no session id is present, 401 if no user is authenticated,
        404 if the session does not exist, 403 if another user owns it, and
        500 if the check itself fails. On success the session object is
        exposed as ``g.session`` / ``request.session`` and the view is called.
    """
    @wraps(f)  # keep a unique Flask endpoint name per decorated view
    def decorated_function(*args, **kwargs):
        try:
            # view_args can be None when no URL rule matched.
            session_id = kwargs.get('session_id') or (request.view_args or {}).get('session_id')
            if not session_id:
                return jsonify({
                    'error': 'Session ID required',
                    'message': 'Session ID must be provided in the URL'
                }), 400

            user_id = getattr(g, 'user_id', None) or getattr(request, 'user_id', None)
            if not user_id:
                return jsonify({
                    'error': 'User not authenticated',
                    'message': 'User authentication required'
                }), 401

            # Import here to avoid circular imports
            from ..services.session_manager import SessionManager, SessionNotFoundError

            # NOTE(review): new Redis connection pool per request.
            redis_client = redis.from_url(current_app.config['REDIS_URL'])
            session_manager = SessionManager(redis_client)
            try:
                session = session_manager.get_session(session_id)
            except SessionNotFoundError:
                return jsonify({
                    'error': 'Session not found',
                    'message': f'Session {session_id} does not exist'
                }), 404

            if session.user_id != user_id:
                return jsonify({
                    'error': 'Access denied',
                    'message': 'You do not have permission to access this session'
                }), 403

            # Store session in request context for use in endpoint
            g.session = session
            request.session = session
        except Exception as e:
            logger.error("Authorization error: %s", e)
            return jsonify({
                'error': 'Authorization failed',
                'message': 'Failed to verify session ownership'
            }), 500

        # Outside the try block so errors raised by the view itself are not
        # misreported as an authorization failure.
        return f(*args, **kwargs)
    return decorated_function
def validate_json_request(required_fields: Optional[list] = None, optional_fields: Optional[list] = None):
    """
    Decorator to validate JSON request data.

    Responds with 400 when the request is not JSON, the body cannot be parsed,
    the body is empty/falsy, or (when ``required_fields`` is given) the body is
    not a JSON object or lacks a required top-level key. On success the parsed
    body is exposed as ``g.json_data`` and ``request.json_data``.

    Args:
        required_fields: Keys that must be present in the top-level JSON object.
        optional_fields: Optional field names; accepted for documentation only
            and not checked.
    """
    def decorator(f):
        @wraps(f)  # keep a unique Flask endpoint name per decorated view
        def decorated_function(*args, **kwargs):
            if not request.is_json:
                return jsonify({
                    'error': 'Invalid content type',
                    'message': 'Request must be JSON'
                }), 400
            try:
                data = request.get_json()
            except Exception as e:
                return jsonify({
                    'error': 'Invalid JSON',
                    'message': f'Failed to parse JSON: {str(e)}'
                }), 400
            if not data:
                return jsonify({
                    'error': 'Empty request body',
                    'message': 'Request body cannot be empty'
                }), 400
            if required_fields:
                # Field checks only make sense on an object; a bare number
                # would otherwise raise TypeError on `in` and yield a 500.
                if not isinstance(data, dict):
                    return jsonify({
                        'error': 'Invalid JSON',
                        'message': 'Request body must be a JSON object'
                    }), 400
                missing_fields = [field for field in required_fields if field not in data]
                if missing_fields:
                    return jsonify({
                        'error': 'Missing required fields',
                        'message': f'Required fields: {", ".join(missing_fields)}',
                        'missing_fields': missing_fields
                    }), 400
            # Store validated data in request context
            g.json_data = data
            request.json_data = data
            return f(*args, **kwargs)
        return decorated_function
    return decorator
def handle_rate_limit_exceeded(e):
    """
    Convert a rate-limit error into a JSON 429 response.

    The error's ``retry_after`` attribute is echoed back when present, and is
    null otherwise.
    """
    retry_after = getattr(e, 'retry_after', None)
    payload = {
        'error': 'Rate limit exceeded',
        'message': 'Too many requests. Please try again later.',
        'retry_after': retry_after,
    }
    return jsonify(payload), 429
def setup_error_handlers(app):
    """
    Register JSON error handlers on the application.

    Mappings:
        AuthenticationError            -> 401
        AuthorizationError             -> 403
        RateLimitError and HTTP 429    -> 429 (via ``handle_rate_limit_exceeded``)
        HTTP 400 / 404 / 500           -> generic JSON bodies (500 is also logged)

    Args:
        app: Flask application to register the handlers on.
    """
    def handle_auth_error(error):
        return jsonify({
            'error': 'Authentication failed',
            'message': str(error)
        }), 401

    def handle_authz_error(error):
        return jsonify({
            'error': 'Authorization failed',
            'message': str(error)
        }), 403

    def handle_rate_limit(error):
        return handle_rate_limit_exceeded(error)

    def handle_bad_request(error):
        return jsonify({
            'error': 'Bad request',
            'message': 'The request could not be understood by the server'
        }), 400

    def handle_not_found(error):
        return jsonify({
            'error': 'Not found',
            'message': 'The requested resource was not found'
        }), 404

    def handle_internal_error(error):
        logger.error("Internal server error: %s", error)
        return jsonify({
            'error': 'Internal server error',
            'message': 'An unexpected error occurred'
        }), 500

    # Defining the handlers alone has no effect; Flask only uses them once
    # they are registered on the app.
    app.register_error_handler(AuthenticationError, handle_auth_error)
    app.register_error_handler(AuthorizationError, handle_authz_error)
    app.register_error_handler(RateLimitError, handle_rate_limit)
    app.register_error_handler(429, handle_rate_limit)
    app.register_error_handler(400, handle_bad_request)
    app.register_error_handler(404, handle_not_found)
    app.register_error_handler(500, handle_internal_error)
class RequestLoggingMiddleware:
    """
    Middleware for logging API requests.

    Registers ``before_request`` / ``after_request`` hooks that log the method,
    path and status of every request whose endpoint is not a static-file
    endpoint. When the app is in debug mode, JSON request bodies are logged as
    well, with sensitive keys removed.
    """

    # Keys whose values are never written to the log (matched
    # case-insensitively, at any nesting depth).
    SENSITIVE_KEYS = frozenset({'password', 'token', 'secret'})

    def __init__(self, app=None):
        """Store *app* and, when truthy, register the hooks on it immediately."""
        self.app = app
        if app:
            self.init_app(app)

    def init_app(self, app):
        """Initialize the middleware with the Flask app."""
        app.before_request(self.log_request)
        app.after_request(self.log_response)

    @staticmethod
    def _is_static(endpoint):
        """Return True for Flask's static endpoint, app-level or per-blueprint."""
        # Exact match instead of startswith('static'), which would also skip
        # endpoints such as 'statistics' and miss blueprint 'bp.static'.
        return endpoint == 'static' or endpoint.endswith('.static')

    @classmethod
    def _redact(cls, data):
        """
        Return a copy of JSON *data* with sensitive keys dropped.

        Dicts and lists are walked recursively; any other value is returned
        unchanged.
        """
        if isinstance(data, dict):
            return {
                key: cls._redact(value)
                for key, value in data.items()
                if not (isinstance(key, str) and key.lower() in cls.SENSITIVE_KEYS)
            }
        if isinstance(data, list):
            return [cls._redact(item) for item in data]
        return data

    def log_request(self):
        """Log incoming requests."""
        if request.endpoint and not self._is_static(request.endpoint):
            logger.info("API Request: %s %s from %s",
                        request.method, request.path, request.remote_addr)
            # Log request data for debugging (sensitive fields removed)
            if request.is_json and current_app.debug:
                # silent=True yields None for malformed JSON instead of raising;
                # the view's own parsing still reports the error to the client.
                data = request.get_json(silent=True)
                if data is not None:
                    logger.debug("Request data: %s", self._redact(data))

    def log_response(self, response):
        """Log outgoing responses and return *response* unchanged."""
        if request.endpoint and not self._is_static(request.endpoint):
            logger.info("API Response: %s for %s %s",
                        response.status_code, request.method, request.path)
        return response
def create_auth_manager(redis_client: redis.Redis) -> SimpleAuthManager:
    """
    Build a ``SimpleAuthManager`` bound to the given Redis client.

    Args:
        redis_client: Redis connection the manager stores session tokens in.

    Returns:
        SimpleAuthManager: A new manager instance wrapping ``redis_client``.
    """
    manager = SimpleAuthManager(redis_client)
    return manager