Claude_Code / api /dependencies.py
Jainish1808's picture
Allow public root/health while keeping API auth protected
d86dd21
"""Dependency injection for FastAPI."""
import secrets
from urllib.parse import unquote_plus

from fastapi import Depends, HTTPException, Request
from loguru import logger

from config.settings import Settings
from config.settings import get_settings as _get_settings
from providers.base import BaseProvider, ProviderConfig
from providers.common import get_user_facing_error_message
from providers.exceptions import AuthenticationError
from providers.llamacpp import LlamaCppProvider
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
# Registry of live provider instances, keyed by provider type string
# (e.g. "nvidia_nim"). Entries are added lazily on first use and reused
# across requests; cleanup_provider() empties it.
_providers: dict[str, BaseProvider] = dict()
def get_settings() -> Settings:
    """Provide the application Settings to FastAPI dependencies.

    Delegates to ``config.settings.get_settings`` (imported as ``_get_settings``).
    """
    settings = _get_settings()
    return settings
# Provider types understood by _create_provider_for_type, in the order they
# are listed in the "Unknown provider_type" error message.
_SUPPORTED_PROVIDER_TYPES = ("nvidia_nim", "open_router", "lmstudio", "llamacpp")


def _required_api_key(value: str | None, missing_message: str) -> str:
    """Return *value* with surrounding whitespace removed.

    Raises:
        AuthenticationError: if *value* is None, empty, or whitespace-only.
    """
    key = (value or "").strip()
    if not key:
        raise AuthenticationError(missing_message)
    # Return the stripped key: stray whitespace/newlines from a .env file
    # would otherwise be sent verbatim to the upstream API.
    return key


def _build_provider_config(
    settings: Settings, api_key: str, base_url: str
) -> ProviderConfig:
    """Build a ProviderConfig with the shared rate-limit, concurrency and
    HTTP timeout values taken from *settings*."""
    return ProviderConfig(
        api_key=api_key,
        base_url=base_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
    )


def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
    """Construct and return a new provider instance for the given provider type.

    Args:
        provider_type: One of ``_SUPPORTED_PROVIDER_TYPES``.
        settings: Application settings supplying keys, base URLs, rate limits
            and HTTP timeouts.

    Returns:
        A freshly constructed provider; callers are responsible for caching.

    Raises:
        AuthenticationError: if the API key required by ``nvidia_nim`` or
            ``open_router`` is missing or blank.
        ValueError: if *provider_type* is not supported (also logged).
    """
    if provider_type == "nvidia_nim":
        api_key = _required_api_key(
            settings.nvidia_nim_api_key,
            "NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
            "Get a key at https://build.nvidia.com/settings/api-keys",
        )
        config = _build_provider_config(settings, api_key, NVIDIA_NIM_BASE_URL)
        return NvidiaNimProvider(config, nim_settings=settings.nim)
    if provider_type == "open_router":
        api_key = _required_api_key(
            settings.open_router_api_key,
            "OPENROUTER_API_KEY is not set. Add it to your .env file. "
            "Get a key at https://openrouter.ai/keys",
        )
        config = _build_provider_config(settings, api_key, OPENROUTER_BASE_URL)
        return OpenRouterProvider(config)
    if provider_type == "lmstudio":
        # No key is read from settings for this provider; a fixed placeholder is used.
        config = _build_provider_config(
            settings, "lm-studio", settings.lm_studio_base_url
        )
        return LMStudioProvider(config)
    if provider_type == "llamacpp":
        # No key is read from settings for this provider; a fixed placeholder is used.
        config = _build_provider_config(
            settings, "llamacpp", settings.llamacpp_base_url
        )
        return LlamaCppProvider(config)

    supported = ", ".join(f"'{name}'" for name in _SUPPORTED_PROVIDER_TYPES)
    logger.error(
        "Unknown provider_type: '{}'. Supported: {}", provider_type, supported
    )
    raise ValueError(
        f"Unknown provider_type: '{provider_type}'. Supported: {supported}"
    )
def get_provider_for_type(provider_type: str) -> BaseProvider:
    """Return the cached provider for *provider_type*, creating it on first use.

    Instances are stored in the module-level ``_providers`` registry and
    shared by later requests. If construction raises ``AuthenticationError``
    (missing credentials), it is converted into an HTTP 503 carrying a
    user-facing message, and nothing is cached.
    """
    cached = _providers.get(provider_type)
    if cached is not None:
        return cached

    try:
        created = _create_provider_for_type(provider_type, get_settings())
    except AuthenticationError as exc:
        raise HTTPException(
            status_code=503, detail=get_user_facing_error_message(exc)
        ) from exc

    _providers[provider_type] = created
    logger.info("Provider initialized: {}", provider_type)
    return created
def _api_key_matches(candidate: str, expected: str) -> bool:
    """Compare two API keys in constant time.

    Both sides are encoded to bytes first because ``secrets.compare_digest``
    rejects ``str`` arguments containing non-ASCII characters, and the
    candidate is client-supplied. ``surrogatepass`` keeps the encoding
    injective even for odd environment-derived strings.
    """
    return secrets.compare_digest(
        candidate.encode("utf-8", errors="surrogatepass"),
        expected.encode("utf-8", errors="surrogatepass"),
    )


def validate_request_api_key(request: Request, settings: Settings) -> None:
    """Validate a request against the configured server API key.

    The key is compared against ``Settings.anthropic_auth_token``. If that
    setting is empty, this is a no-op.

    Public probe requests and Hugging Face signed page requests bypass the
    check (see ``_is_public_probe_request`` / ``_is_hf_signed_page_request``).

    The candidate key is read from the first non-empty of the
    ``x-api-key``, ``authorization`` (``Bearer `` prefix removed) and
    ``anthropic-auth-token`` headers. Only when none is present does it fall
    back to the ``psw`` query parameter, which supports Hugging Face Spaces
    private deployments:

    - Append ``?psw=your-token`` to the base URL
    - Or ``?psw:your-token`` (URL-encoded colon ``%3A`` also accepted)

    Raises:
        HTTPException: 401 when no key is supplied or the key does not match.
    """
    anthropic_auth_token = getattr(settings, "anthropic_auth_token", None)
    if not anthropic_auth_token:
        # No API key configured -> allow
        return
    # Keep Space health/app shell reachable even when API auth is enabled.
    if _is_public_probe_request(request):
        return
    # Allow Hugging Face private Space signed browser requests for UI pages.
    # This keeps API routes protected while avoiding 401 on Space shell probes.
    if _is_hf_signed_page_request(request):
        return

    # Check headers first (preferred)
    header = (
        request.headers.get("x-api-key")
        or request.headers.get("authorization")
        or request.headers.get("anthropic-auth-token")
    )
    if header:
        # Support both raw key in X-API-Key and Bearer token in Authorization
        token = header
        if header.lower().startswith("bearer "):
            token = header.split(" ", 1)[1]
        # Try the full value first so a configured key that itself contains
        # a colon can still match (it never could once truncated below).
        if token and _api_key_matches(token, anthropic_auth_token):
            return
        # Strip anything after the first colon to handle tokens with appended model names
        if token and ":" in token:
            token = token.split(":", 1)[0]
    else:
        token = _extract_query_token(request)

    if not token:
        raise HTTPException(status_code=401, detail="Missing API key")
    # Constant-time comparison avoids leaking the secret through timing.
    if not _api_key_matches(token, anthropic_auth_token):
        raise HTTPException(status_code=401, detail="Invalid API key")
def _token_before_colon(value: str) -> str | None:
    """Return the part of *value* before its first colon (possibly empty),
    or None when *value* itself is empty."""
    if value and ":" in value:
        return value.split(":", 1)[0]
    return value or None


def _extract_query_token(request: Request) -> str | None:
    """Pull the auth token out of the query string for private proxy deployments.

    The parsed ``psw=<token>`` parameter is checked first. Otherwise the raw
    query string is scanned for a segment beginning with ``psw:`` or its
    percent-encoded forms ``psw%3A`` / ``psw%3a``. These forms are not exposed
    under the ``psw`` key. Anything after a further colon in the token is
    discarded.
    """
    params = request.query_params
    if "psw" in params:
        return _token_before_colon(params["psw"])

    query_text = request.scope.get("query_string", b"").decode(
        "utf-8", errors="ignore"
    )
    if not query_text:
        return None

    prefixes = ("psw:", "psw%3A", "psw%3a")
    for segment in query_text.split("&"):
        matched = next((p for p in prefixes if segment.startswith(p)), None)
        if matched is not None:
            return _token_before_colon(unquote_plus(segment[len(matched):]))
    return None
def _is_hf_signed_page_request(request: Request) -> bool:
    """Return True for Hugging Face signed browser requests to non-API pages.

    A request qualifies when all of the following hold:

    - its method is GET or HEAD;
    - its path is outside the ``/v1`` API namespace;
    - it carries a ``__sign`` query parameter;
    - its Accept header mentions ``text/html`` or ``*/*``.

    NOTE(review): only the *presence* of ``__sign`` is checked; its value is
    not verified here, and ``*/*`` is a common default Accept header for
    non-browser clients. This bypass is only safe if no sensitive GET routes
    exist outside ``/v1`` — confirm.
    """
    if request.method not in {"GET", "HEAD"}:
        return False
    path = request.url.path
    # Treat the bare "/v1" path as API too, not only "/v1/..." sub-paths.
    if path == "/v1" or path.startswith("/v1/"):
        return False
    if "__sign" not in request.query_params:
        return False
    accept = request.headers.get("accept", "").lower()
    return "text/html" in accept or "*/*" in accept
def _is_public_probe_request(request: Request) -> bool:
    """Tell whether *request* is a read-only hit on a public endpoint.

    Only GET/HEAD requests to ``/`` or ``/health`` qualify; these are used by
    app shells and health checks and stay reachable without an API key.
    """
    is_read_only = request.method in {"GET", "HEAD"}
    return is_read_only and request.url.path in {"/", "/health"}
def require_api_key(
    request: Request, settings: Settings = Depends(get_settings)
) -> None:
    """Enforce API-key authentication as a FastAPI dependency.

    Delegates to ``validate_request_api_key``, which raises ``HTTPException``
    (401) when the key is missing or wrong.
    """
    validate_request_api_key(request=request, settings=settings)
def get_provider() -> BaseProvider:
    """Return the provider for the default ``settings.provider_type``.

    The original docs say this derives from the MODEL env var; confirm that
    in Settings. Kept as a backward-compatible convenience for health/root
    endpoints and tests.
    """
    default_type = get_settings().provider_type
    return get_provider_for_type(default_type)
async def cleanup_provider():
    """Clean up every cached provider and empty the registry.

    The registry is snapshotted and reset *before* the first ``await``, for
    two reasons:

    - A provider added by another task while cleanup is suspended cannot
      trigger "dictionary changed size during iteration".
    - A failing ``cleanup()`` no longer skips the remaining providers or
      leaves already-cleaned providers cached for reuse.

    Raises:
        Exception: the first exception raised by any provider's ``cleanup()``,
            re-raised only after every provider has been attempted.
    """
    global _providers
    pending = list(_providers.items())
    _providers = {}

    first_error: Exception | None = None
    for provider_type, provider in pending:
        try:
            await provider.cleanup()
        except Exception as exc:
            logger.warning("Provider cleanup failed for {}: {}", provider_type, exc)
            if first_error is None:
                first_error = exc

    if first_error is not None:
        raise first_error
    logger.debug("Provider cleanup completed")