"""Middleware for authentication, authorization, and rate limiting.""" import logging import time from functools import wraps from typing import Dict, Any, Optional from flask import request, jsonify, current_app, g from flask_limiter import Limiter from flask_limiter.util import get_remote_address import redis import jwt from datetime import datetime, timedelta logger = logging.getLogger(__name__) class AuthenticationError(Exception): """Authentication related errors.""" pass class AuthorizationError(Exception): """Authorization related errors.""" pass class RateLimitError(Exception): """Rate limiting related errors.""" pass def create_limiter(app=None): """Create and configure rate limiter.""" limiter = Limiter( key_func=get_remote_address, default_limits=["200 per day", "50 per hour"], storage_uri=None # Will be set from app config ) if app: limiter.init_app(app) return limiter class SimpleAuthManager: """ Simple authentication manager for development/testing. In production, this would be replaced with proper JWT/OAuth implementation. """ def __init__(self, redis_client: Optional[redis.Redis] = None): """Initialize auth manager.""" self.redis_client = redis_client self.session_prefix = "auth_session:" self.user_prefix = "user:" def create_session_token(self, user_id: str, expires_in: int = 3600) -> str: """ Create a simple session token for a user. Args: user_id: User identifier expires_in: Token expiration in seconds Returns: str: Session token """ try: # Create a simple token (in production, use proper JWT) token_data = { 'user_id': user_id, 'created_at': time.time(), 'expires_at': time.time() + expires_in } # For simplicity, use user_id as token (in production, use secure random token) token = f"session_{user_id}_{int(time.time())}" if self.redis_client: # Store token in Redis self.redis_client.setex( f"{self.session_prefix}{token}", expires_in, user_id ) return token except Exception as e: logger.error(f"Error creating session token: {e}") raise AuthenticationError(f"Failed to create session token: {e}") def validate_session_token(self, token: str) -> Optional[str]: """ Validate a session token and return user_id if valid. Args: token: Session token to validate Returns: str: User ID if token is valid, None otherwise """ try: if not token: return None if self.redis_client: # Check Redis for token user_id = self.redis_client.get(f"{self.session_prefix}{token}") if user_id: return user_id.decode('utf-8') # Fallback: simple token validation (for development) if token.startswith('session_'): parts = token.split('_') if len(parts) >= 2: return parts[1] # Return user_id part return None except Exception as e: logger.error(f"Error validating session token: {e}") return None def revoke_session_token(self, token: str) -> bool: """ Revoke a session token. Args: token: Session token to revoke Returns: bool: True if token was revoked, False otherwise """ try: if self.redis_client: result = self.redis_client.delete(f"{self.session_prefix}{token}") return result > 0 return True # For development, always return True except Exception as e: logger.error(f"Error revoking session token: {e}") return False def require_auth(f): """ Authentication decorator for API endpoints. Supports both header-based authentication and session tokens. """ @wraps(f) def decorated_function(*args, **kwargs): try: # Check for session token first auth_header = request.headers.get('Authorization') if auth_header and auth_header.startswith('Bearer '): token = auth_header.split(' ')[1] # Get auth manager from app context redis_client = redis.from_url(current_app.config['REDIS_URL']) auth_manager = SimpleAuthManager(redis_client) user_id = auth_manager.validate_session_token(token) if user_id: g.user_id = user_id request.user_id = user_id return f(*args, **kwargs) # Fallback to simple header-based auth (for development) user_id = request.headers.get('X-User-ID') if user_id: g.user_id = user_id request.user_id = user_id return f(*args, **kwargs) # No valid authentication found return jsonify({ 'error': 'Authentication required', 'message': 'Please provide a valid Authorization header or X-User-ID header' }), 401 except Exception as e: logger.error(f"Authentication error: {e}") return jsonify({ 'error': 'Authentication failed', 'message': 'Invalid authentication credentials' }), 401 return decorated_function def require_session_ownership(f): """ Authorization decorator to ensure user owns the session. Must be used after require_auth. """ @wraps(f) def decorated_function(*args, **kwargs): try: session_id = kwargs.get('session_id') or request.view_args.get('session_id') if not session_id: return jsonify({ 'error': 'Session ID required', 'message': 'Session ID must be provided in the URL' }), 400 user_id = getattr(g, 'user_id', None) or getattr(request, 'user_id', None) if not user_id: return jsonify({ 'error': 'User not authenticated', 'message': 'User authentication required' }), 401 # Import here to avoid circular imports from ..services.session_manager import SessionManager, SessionNotFoundError redis_client = redis.from_url(current_app.config['REDIS_URL']) session_manager = SessionManager(redis_client) try: session = session_manager.get_session(session_id) if session.user_id != user_id: return jsonify({ 'error': 'Access denied', 'message': 'You do not have permission to access this session' }), 403 # Store session in request context for use in endpoint g.session = session request.session = session except SessionNotFoundError: return jsonify({ 'error': 'Session not found', 'message': f'Session {session_id} does not exist' }), 404 return f(*args, **kwargs) except Exception as e: logger.error(f"Authorization error: {e}") return jsonify({ 'error': 'Authorization failed', 'message': 'Failed to verify session ownership' }), 500 return decorated_function def validate_json_request(required_fields: list = None, optional_fields: list = None): """ Decorator to validate JSON request data. Args: required_fields: List of required field names optional_fields: List of optional field names (for documentation) """ def decorator(f): @wraps(f) def decorated_function(*args, **kwargs): if not request.is_json: return jsonify({ 'error': 'Invalid content type', 'message': 'Request must be JSON' }), 400 try: data = request.get_json() except Exception as e: return jsonify({ 'error': 'Invalid JSON', 'message': f'Failed to parse JSON: {str(e)}' }), 400 if not data: return jsonify({ 'error': 'Empty request body', 'message': 'Request body cannot be empty' }), 400 if required_fields: missing_fields = [field for field in required_fields if field not in data] if missing_fields: return jsonify({ 'error': 'Missing required fields', 'message': f'Required fields: {", ".join(missing_fields)}', 'missing_fields': missing_fields }), 400 # Store validated data in request context g.json_data = data request.json_data = data return f(*args, **kwargs) return decorated_function return decorator def handle_rate_limit_exceeded(e): """Handle rate limit exceeded errors.""" return jsonify({ 'error': 'Rate limit exceeded', 'message': 'Too many requests. Please try again later.', 'retry_after': getattr(e, 'retry_after', None) }), 429 def setup_error_handlers(app): """Setup error handlers for the application.""" @app.errorhandler(AuthenticationError) def handle_auth_error(error): return jsonify({ 'error': 'Authentication failed', 'message': str(error) }), 401 @app.errorhandler(AuthorizationError) def handle_authz_error(error): return jsonify({ 'error': 'Authorization failed', 'message': str(error) }), 403 @app.errorhandler(429) def handle_rate_limit(error): return handle_rate_limit_exceeded(error) @app.errorhandler(400) def handle_bad_request(error): return jsonify({ 'error': 'Bad request', 'message': 'The request could not be understood by the server' }), 400 @app.errorhandler(404) def handle_not_found(error): return jsonify({ 'error': 'Not found', 'message': 'The requested resource was not found' }), 404 @app.errorhandler(500) def handle_internal_error(error): logger.error(f"Internal server error: {error}") return jsonify({ 'error': 'Internal server error', 'message': 'An unexpected error occurred' }), 500 class RequestLoggingMiddleware: """Middleware for logging API requests.""" def __init__(self, app=None): self.app = app if app: self.init_app(app) def init_app(self, app): """Initialize the middleware with the Flask app.""" app.before_request(self.log_request) app.after_request(self.log_response) def log_request(self): """Log incoming requests.""" if request.endpoint and not request.endpoint.startswith('static'): logger.info(f"API Request: {request.method} {request.path} from {request.remote_addr}") # Log request data for debugging (be careful with sensitive data) if request.is_json and current_app.debug: try: data = request.get_json() # Remove sensitive fields before logging safe_data = {k: v for k, v in data.items() if k not in ['password', 'token', 'secret']} logger.debug(f"Request data: {safe_data}") except: pass def log_response(self, response): """Log outgoing responses.""" if request.endpoint and not request.endpoint.startswith('static'): logger.info(f"API Response: {response.status_code} for {request.method} {request.path}") return response def create_auth_manager(redis_client: redis.Redis) -> SimpleAuthManager: """ Factory function to create an authentication manager. Args: redis_client: Redis client instance Returns: SimpleAuthManager: Configured auth manager instance """ return SimpleAuthManager(redis_client)