Spaces:
Sleeping
Sleeping
File size: 9,189 Bytes
00a2010 a540b92 00a2010 d86dd21 00a2010 d86dd21 00a2010 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 | """Dependency injection for FastAPI."""
import hmac
from urllib.parse import unquote_plus

from fastapi import Depends, HTTPException, Request
from loguru import logger

from config.settings import Settings
from config.settings import get_settings as _get_settings
from providers.base import BaseProvider, ProviderConfig
from providers.common import get_user_facing_error_message
from providers.exceptions import AuthenticationError
from providers.llamacpp import LlamaCppProvider
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
# Cache of constructed providers, keyed by provider type string (e.g. "nvidia_nim").
# Starts empty; entries are added on first use by get_provider_for_type and the
# whole mapping is replaced with an empty one by cleanup_provider.
_providers: dict[str, BaseProvider] = dict()
def get_settings() -> Settings:
    """Return the application ``Settings`` object.

    Thin local wrapper around ``config.settings.get_settings`` so that route
    dependencies in this module can declare ``Depends(get_settings)``.
    """
    settings = _get_settings()
    return settings
# Provider type strings accepted by _create_provider_for_type, in the order they
# are listed in "unknown provider" error messages.
_SUPPORTED_PROVIDER_TYPES = ("nvidia_nim", "open_router", "lmstudio", "llamacpp")


def _require_provider_key(api_key: str | None, missing_message: str) -> str:
    """Return *api_key* with surrounding whitespace removed.

    Raises:
        AuthenticationError: with *missing_message* if the key is ``None``,
            empty, or whitespace-only.
    """
    stripped = (api_key or "").strip()
    if not stripped:
        raise AuthenticationError(missing_message)
    # Use the stripped value: a key that only passed validation after stripping
    # must not be sent upstream with the stray whitespace still attached.
    return stripped


def _build_provider_config(
    settings: Settings, *, api_key: str, base_url: str
) -> ProviderConfig:
    """Build a ``ProviderConfig`` using the shared rate-limit and HTTP timeout settings."""
    return ProviderConfig(
        api_key=api_key,
        base_url=base_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
    )


def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
    """Construct and return a new provider instance for the given provider type.

    Args:
        provider_type: one of ``_SUPPORTED_PROVIDER_TYPES``.
        settings: application settings supplying credentials, base URLs,
            rate limits and HTTP timeouts.

    Returns:
        A freshly constructed provider; this function does no caching.

    Raises:
        AuthenticationError: ``nvidia_nim`` or ``open_router`` is selected but
            its API key is unset or blank.
        ValueError: *provider_type* is not one of the supported types.
    """
    if provider_type == "nvidia_nim":
        api_key = _require_provider_key(
            settings.nvidia_nim_api_key,
            "NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
            "Get a key at https://build.nvidia.com/settings/api-keys",
        )
        config = _build_provider_config(
            settings, api_key=api_key, base_url=NVIDIA_NIM_BASE_URL
        )
        return NvidiaNimProvider(config, nim_settings=settings.nim)

    if provider_type == "open_router":
        api_key = _require_provider_key(
            settings.open_router_api_key,
            "OPENROUTER_API_KEY is not set. Add it to your .env file. "
            "Get a key at https://openrouter.ai/keys",
        )
        config = _build_provider_config(
            settings, api_key=api_key, base_url=OPENROUTER_BASE_URL
        )
        return OpenRouterProvider(config)

    if provider_type == "lmstudio":
        # No credential is read from settings; a fixed placeholder key is used.
        config = _build_provider_config(
            settings, api_key="lm-studio", base_url=settings.lm_studio_base_url
        )
        return LMStudioProvider(config)

    if provider_type == "llamacpp":
        # No credential is read from settings; a fixed placeholder key is used.
        config = _build_provider_config(
            settings, api_key="llamacpp", base_url=settings.llamacpp_base_url
        )
        return LlamaCppProvider(config)

    # Renders as: 'nvidia_nim', 'open_router', 'lmstudio', 'llamacpp'
    supported = ", ".join(f"'{name}'" for name in _SUPPORTED_PROVIDER_TYPES)
    logger.error(
        "Unknown provider_type: '{}'. Supported: {}", provider_type, supported
    )
    raise ValueError(f"Unknown provider_type: '{provider_type}'. Supported: {supported}")
def get_provider_for_type(provider_type: str) -> BaseProvider:
    """Return the cached provider for *provider_type*, creating it on first use.

    Instances are stored in the module-level ``_providers`` registry, so one
    provider object per type is shared across requests.

    Raises:
        HTTPException: status 503 when construction fails with
            ``AuthenticationError`` (e.g. a missing API key); the detail is the
            user-facing error message. Other exceptions (such as ``ValueError``
            for an unknown type) propagate unchanged.
    """
    if provider_type in _providers:
        return _providers[provider_type]

    try:
        provider = _create_provider_for_type(provider_type, get_settings())
    except AuthenticationError as exc:
        detail = get_user_facing_error_message(exc)
        raise HTTPException(status_code=503, detail=detail) from exc

    _providers[provider_type] = provider
    logger.info("Provider initialized: {}", provider_type)
    return provider
def _extract_header_token(request: Request) -> str | None:
    """Return the API token sent in request headers, or ``None`` if none was sent.

    The first non-empty header among ``x-api-key``, ``authorization`` and
    ``anthropic-auth-token`` is used. A leading ``Bearer `` scheme
    (case-insensitive) is removed, and anything from the first ``:`` onward is
    dropped so tokens with an appended model name still match. An empty string
    may be returned (e.g. ``Authorization: Bearer ``), which callers treat as a
    missing key rather than falling back to the query string.
    """
    header = (
        request.headers.get("x-api-key")
        or request.headers.get("authorization")
        or request.headers.get("anthropic-auth-token")
    )
    if not header:
        return None
    token = header
    if header.lower().startswith("bearer "):
        # strip() tolerates extra spaces between the scheme and the token.
        token = header.split(" ", 1)[1].strip()
    if ":" in token:
        token = token.split(":", 1)[0]
    return token


def validate_request_api_key(request: Request, settings: Settings) -> None:
    """Validate a request against the configured server API key.

    The token is taken from the ``x-api-key``, ``authorization``
    (``Bearer ...``) or ``anthropic-auth-token`` header. If none of those
    headers is present, it falls back to the ``psw`` query parameter. The token
    is compared with ``Settings.anthropic_auth_token``. If that setting is
    empty, this is a no-op.

    Supports Hugging Face Spaces private deployments via query parameter authentication:
    - Append `?psw=your-token` to the base URL
    - Or `?psw:your-token` (URL-encoded colon becomes %3A)

    Public probe endpoints and Hugging Face signed page requests are let
    through without a token.

    Raises:
        HTTPException: 401 "Missing API key" when no token was supplied, or
            401 "Invalid API key" when it does not match.
    """
    expected = getattr(settings, "anthropic_auth_token", None)
    if not expected:
        # No API key configured -> allow
        return
    # Keep Space health/app shell reachable even when API auth is enabled.
    if _is_public_probe_request(request):
        return
    # Allow Hugging Face private Space signed browser requests for UI pages.
    # This keeps API routes protected while avoiding 401 on Space shell probes.
    if _is_hf_signed_page_request(request):
        return

    token = _extract_header_token(request)
    if token is None:
        token = _extract_query_token(request)

    if not token:
        raise HTTPException(status_code=401, detail="Missing API key")
    # Constant-time comparison so response timing does not reveal how much of a
    # guessed key was correct. Compare UTF-8 bytes because compare_digest
    # raises TypeError for str arguments containing non-ASCII characters.
    if not hmac.compare_digest(token.encode("utf-8"), str(expected).encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")
def _extract_query_token(request: Request) -> str | None:
    """Pull the auth token out of the query string, or return ``None``.

    Accepted forms, checked in this order:
      * ``?psw=<token>`` — read from the parsed ``request.query_params``;
      * ``?psw:<token>`` — found by scanning the raw query string, value URL-decoded;
      * ``?psw%3A<token>`` / ``?psw%3a<token>`` — same, with a percent-encoded colon.

    In every form anything from the first ``:`` in the token onward is dropped.
    A token that is empty before trimming is reported as ``None``.
    """

    def _trim(candidate: str) -> str | None:
        # Drop an appended ":<suffix>"; normalise an empty value to None.
        if candidate and ":" in candidate:
            return candidate.partition(":")[0]
        return candidate or None

    params = request.query_params
    if "psw" in params:
        return _trim(params["psw"])

    raw_query = request.scope.get("query_string", b"").decode("utf-8", errors="ignore")
    if not raw_query:
        return None

    for segment in raw_query.split("&"):
        if segment.startswith("psw:"):
            return _trim(unquote_plus(segment[len("psw:") :]))
        if segment.startswith(("psw%3A", "psw%3a")):
            return _trim(unquote_plus(segment[len("psw%3A") :]))
    return None
def _is_hf_signed_page_request(request: Request) -> bool:
    """Tell whether *request* looks like a Hugging Face signed browser page load.

    True only for GET/HEAD requests outside ``/v1/`` that carry a ``__sign``
    query parameter and whose ``Accept`` header mentions ``text/html`` or
    ``*/*``.
    """
    # NOTE(review): only the presence of ``__sign`` is checked here; its value
    # is not verified, so any client can satisfy this for non-/v1/ GET/HEAD
    # routes. Confirm no sensitive routes live outside /v1/.
    if request.method not in {"GET", "HEAD"} or request.url.path.startswith("/v1/"):
        return False
    if "__sign" not in request.query_params:
        return False
    accept_header = request.headers.get("accept", "").lower()
    return any(media in accept_header for media in ("text/html", "*/*"))
def _is_public_probe_request(request: Request) -> bool:
    """Tell whether *request* is a GET/HEAD to ``/`` or ``/health``.

    Such requests are exempt from API-key checks so app shells and health
    checks keep working.
    """
    return request.method in {"GET", "HEAD"} and request.url.path in {"/", "/health"}
def require_api_key(
    request: Request, settings: Settings = Depends(get_settings)
) -> None:
    """Route dependency that enforces the server API key.

    Settings are resolved via ``get_settings``. All checking is delegated to
    ``validate_request_api_key``, and any ``HTTPException`` it raises
    propagates to FastAPI.
    """
    return validate_request_api_key(request, settings)
def get_provider() -> BaseProvider:
    """Return the provider for the configured default ``settings.provider_type``.

    Backward-compatible convenience for health/root endpoints and tests. The
    instance comes from the shared cache in ``get_provider_for_type``.
    """
    default_type = get_settings().provider_type
    return get_provider_for_type(default_type)
async def cleanup_provider():
    """Release resources held by every cached provider and empty the registry.

    The registry is detached (replaced by a new empty dict) *before* any
    ``cleanup()`` is awaited. This has two effects:

    * A concurrent ``get_provider_for_type`` call during one of those awaits
      cannot mutate the dict being iterated, which would raise ``RuntimeError``.
    * Providers that are being torn down are never handed out again.

    Every provider's ``cleanup()`` is attempted even if an earlier one fails.
    Each failure is logged, and the first one is re-raised after the loop so
    callers still see cleanup errors.
    """
    global _providers
    to_close, _providers = _providers, {}
    first_error: Exception | None = None
    for provider_type, provider in to_close.items():
        try:
            await provider.cleanup()
        except Exception as exc:
            logger.exception("Provider cleanup failed: {}", provider_type)
            if first_error is None:
                first_error = exc
    logger.debug("Provider cleanup completed")
    if first_error is not None:
        raise first_error
|