| | from __future__ import annotations |
| |
|
| | import json |
| | import os |
| | import re |
| | from pathlib import Path |
| | from typing import Any |
| | from urllib.error import HTTPError, URLError |
| | from urllib.parse import urlencode |
| | from urllib.request import Request, urlopen |
| |
|
DEFAULT_MAX_RESULTS = 20
DEFAULT_TIMEOUT_SEC = 30

# Allowlist of Hub API endpoints (paths relative to /api) that hf_api_request
# may call. Each entry is a regex that must match the whole normalized path.
ALLOWED_ENDPOINT_PATTERNS: list[str] = [
    # Users
    r"^/whoami-v2$",
    r"^/users/[^/]+/overview$",
    r"^/users/[^/]+/likes$",
    r"^/users/[^/]+/followers$",
    r"^/users/[^/]+/following$",
    # Organizations
    r"^/organizations/[^/]+/overview$",
    r"^/organizations/[^/]+/members$",
    r"^/organizations/[^/]+/followers$",
    # Repo discussions (list, details, comment, edit/hide comment, status)
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+/comment$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+/comment/[^/]+/edit$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+/comment/[^/]+/hide$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+/status$",
    # Gated-repo access requests
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/pending$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/accepted$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/rejected$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/handle$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/grant$",
    # Collections. Hub collection slugs have the form "<namespace>/<slug>-<id>"
    # and so span two path segments; the single-segment forms are kept for
    # backward compatibility.
    r"^/collections$",
    r"^/collections/[^/]+$",
    r"^/collections/[^/]+/items$",
    r"^/collections/[^/]+/[^/]+$",
    r"^/collections/[^/]+/[^/]+/items$",
    # Repo access check
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/auth-check$",
    # Activity feed
    r"^/recent-activity$",
]

# Compiled once at import time; consulted by _is_endpoint_allowed.
_COMPILED_PATTERNS: list[re.Pattern[str]] = [
    re.compile(p) for p in ALLOWED_ENDPOINT_PATTERNS
]
| |
|
| |
|
def _is_endpoint_allowed(endpoint: str) -> bool:
    """Return True if *endpoint* matches one of the allowlisted patterns.

    ``fullmatch`` is used instead of ``match``: a ``$`` anchor also matches
    just before a trailing newline, so ``match`` would accept e.g. the
    whoami path followed by a newline character.
    """
    return any(pattern.fullmatch(endpoint) for pattern in _COMPILED_PATTERNS)
| |
|
| |
|
def _load_token() -> str | None:
    """Resolve the Hugging Face token to send as a bearer credential.

    Lookup order (the first non-empty value wins):

    1. The per-request bearer token exposed by ``fast_agent``, when that
       optional package is importable and currently holds a value.
    2. The ``HF_TOKEN`` environment variable (surrounding whitespace stripped).
    3. A token file: ``HF_TOKEN_PATH`` if set, otherwise ``$HF_HOME/token``,
       otherwise ``~/.cache/huggingface/token``.

    Returns:
        The token string, or None when no source yields one. A missing,
        unreadable or undecodable token file is treated as "no token".
    """
    # 1. Per-request token (optional dependency).
    try:
        from fast_agent.mcp.auth.context import request_bearer_token
    except ImportError:
        request_bearer_token = None
    if request_bearer_token is not None:
        try:
            ctx_token = request_bearer_token.get()
        except LookupError:
            # NOTE(review): assumes request_bearer_token is a ContextVar, whose
            # .get() raises LookupError when unset and created without default.
            ctx_token = None
        if ctx_token:
            return ctx_token

    # 2. Environment variable. Strip it so a trailing newline (common when the
    #    value is exported from a file) never reaches the Authorization header.
    env_token = (os.getenv("HF_TOKEN") or "").strip()
    if env_token:
        return env_token

    # 3. Token file, located similarly to huggingface_hub
    #    (XDG_CACHE_HOME is not consulted here).
    explicit_path = os.getenv("HF_TOKEN_PATH")
    if explicit_path:
        token_path = Path(explicit_path).expanduser()
    else:
        hf_home = os.getenv("HF_HOME")
        if hf_home:
            home_dir = Path(hf_home).expanduser()
        else:
            home_dir = Path.home() / ".cache" / "huggingface"
        token_path = home_dir / "token"
    try:
        token_value = token_path.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        # Best-effort lookup: absent/unreadable/garbled file means no token.
        return None
    return token_value or None
| |
|
| |
|
def _max_results_from_env() -> int:
    """Read the client-side list cap from the ``HF_MAX_RESULTS`` variable.

    Falls back to ``DEFAULT_MAX_RESULTS`` when the variable is unset, empty,
    not parseable as an integer, or not strictly positive.
    """
    configured = os.getenv("HF_MAX_RESULTS")
    if configured:
        try:
            parsed = int(configured)
        except ValueError:
            parsed = 0
        if parsed > 0:
            return parsed
    return DEFAULT_MAX_RESULTS
| |
|
| |
|
def _normalize_endpoint(endpoint: str) -> str:
    """Normalize and validate an endpoint path relative to ``/api``.

    Checks, in order:
    - Must be a non-empty string (surrounding whitespace is stripped first)
    - Must be a relative path, not a full URL (case-insensitive scheme check)
    - No path traversal sequences (``..``)
    - No query string, fragment, percent-encoding, whitespace or control
      characters: these could change which path the server actually routes
      to after the allowlist check (query parameters belong in ``params``)
    - Must match the endpoint allowlist (after a leading ``/`` is ensured)

    Returns:
        The endpoint with a guaranteed leading ``/``.

    Raises:
        ValueError: If any check fails.
    """
    if not isinstance(endpoint, str):
        raise ValueError("Endpoint must be a non-empty string.")
    # Strip before the URL check so leading whitespace cannot hide a scheme.
    endpoint = endpoint.strip()
    if not endpoint:
        raise ValueError("Endpoint must be a non-empty string.")
    if endpoint.lower().startswith(("http://", "https://")):
        raise ValueError("Endpoint must be a path relative to /api, not a full URL.")

    if ".." in endpoint:
        raise ValueError("Path traversal sequences (..) are not allowed in endpoints.")

    # The allowlist's [^/]+ segments would otherwise accept e.g.
    # "/users/x?y=/overview", which the server routes as "/users/x".
    if any(ch in "?#%" or ch.isspace() or not ch.isprintable() for ch in endpoint):
        raise ValueError(
            "Endpoint must be a plain path without query strings, fragments, "
            "percent-encoding, whitespace or control characters; "
            "pass query parameters via params."
        )

    if not endpoint.startswith("/"):
        endpoint = f"/{endpoint}"

    if not _is_endpoint_allowed(endpoint):
        raise ValueError(
            f"Endpoint '{endpoint}' is not in the allowed list. "
            "See ALLOWED_ENDPOINT_PATTERNS for permitted endpoints."
        )

    return endpoint
| |
|
| |
|
def _normalize_params(params: dict[str, Any] | None) -> dict[str, Any]:
    """Convert query parameters into strings suitable for ``urlencode``.

    Keys whose value is None are dropped. List/tuple values become lists of
    strings (so ``urlencode(..., doseq=True)`` repeats the key), and every
    other value is passed through ``str``. Insertion order is preserved.
    """
    def _stringify(value: Any) -> Any:
        if isinstance(value, (list, tuple)):
            return [str(element) for element in value]
        return str(value)

    return {
        name: _stringify(value)
        for name, value in (params or {}).items()
        if value is not None
    }
| |
|
| |
|
def _build_url(endpoint: str, params: dict[str, Any] | None) -> str:
    """Assemble the absolute Hub API URL for *endpoint* plus optional *params*.

    The host comes from ``HF_ENDPOINT`` (default ``https://huggingface.co``,
    trailing slashes removed). The path is validated by
    ``_normalize_endpoint``, and params go through ``_normalize_params``
    before being url-encoded with sequence values repeated.
    """
    host = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
    path = _normalize_endpoint(endpoint)
    query = _normalize_params(params)
    full_url = f"{host}/api{path}"
    return f"{full_url}?{urlencode(query, doseq=True)}" if query else full_url
| |
|
| |
|
def _request_once(
    *,
    url: str,
    method_upper: str,
    json_body: dict[str, Any] | None,
) -> tuple[int, Any]:
    """Perform a single HTTP request against the Hub API and decode the reply.

    Sends ``Accept: application/json``, plus a bearer ``Authorization`` header
    when ``_load_token`` returns a token. For POST, ``json_body`` (``{}`` when
    None) is sent as a UTF-8 JSON body.

    Returns:
        ``(status_code, payload)``. ``payload`` is the parsed JSON value, or
        the body decoded as UTF-8 text (invalid bytes replaced) when it is not
        valid JSON.

    Raises:
        RuntimeError: On an HTTP error status (the message includes the
            response body), or on any transport failure, including timeouts
            while waiting for or reading the response.
    """
    headers = {"Accept": "application/json"}
    token = _load_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"

    data = None
    if method_upper == "POST":
        headers["Content-Type"] = "application/json"
        data = json.dumps(json_body or {}).encode("utf-8")

    request = Request(url, headers=headers, data=data, method=method_upper)

    try:
        with urlopen(request, timeout=DEFAULT_TIMEOUT_SEC) as response:
            raw = response.read()
            status_code = response.status
    except HTTPError as exc:
        # Must precede the OSError clause: HTTPError subclasses URLError/OSError.
        error_body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HF API error {exc.code} for {url}: {error_body}") from exc
    except OSError as exc:
        # Covers URLError (connect failures) and raw OSErrors such as
        # TimeoutError raised while awaiting or reading the response.
        raise RuntimeError(f"HF API request failed for {url}: {exc}") from exc

    try:
        payload = json.loads(raw)
    except ValueError:
        # Both JSONDecodeError and UnicodeDecodeError (non-UTF bytes, which
        # json.loads raises when given bytes) are ValueError subclasses.
        payload = raw.decode("utf-8", errors="replace")

    return status_code, payload
| |
|
| |
|
def _get_nested_value(obj: Any, path: str) -> Any:
    """Look up a dot-notation *path* inside nested dicts/lists.

    Empty path segments are ignored, so an empty *path* returns *obj* itself.
    Dict levels are addressed by key. List levels are addressed by integer
    index, and only indices in ``0 <= i < len(list)`` are accepted. Any
    missing key, bad index, or attempt to descend into a scalar yields None.
    """
    node = obj
    for segment in path.split("."):
        if not segment:
            continue
        if isinstance(node, dict):
            if segment not in node:
                return None
            node = node[segment]
            continue
        if not isinstance(node, list):
            return None
        try:
            position = int(segment)
        except ValueError:
            return None
        if not 0 <= position < len(node):
            return None
        node = node[position]
    return node
| |
|
| |
|
def _set_nested_value(obj: Any, path: str, value: Any) -> Any:
    """Assign *value* at dot-notation *path* inside *obj*, mutating it in place.

    Path segments follow the same rules as ``_get_nested_value``: empty
    segments are ignored, dict levels are addressed by key, and list levels
    by an integer index in ``0 <= i < len(list)``. Existing lists along the
    path are traversed rather than overwritten.

    - An empty *path* returns *value* itself, and nothing is mutated.
    - A non-dict *obj* is returned unchanged.
    - A missing or non-container intermediate entry (dict value or list
      element) is replaced by a fresh dict so the assignment can proceed.
    - If a list level is addressed with a non-integer or out-of-range index,
      nothing is written.

    Returns:
        *obj* (or *value* when *path* is empty).
    """
    if not path:
        return value
    if not isinstance(obj, dict):
        return obj

    parts = [p for p in path.split(".") if p]
    if not parts:
        return obj

    def _list_index(container: list[Any], part: str) -> int | None:
        """Return *part* as a valid index into *container*, else None."""
        try:
            idx = int(part)
        except ValueError:
            return None
        return idx if 0 <= idx < len(container) else None

    cur: Any = obj
    for part in parts[:-1]:
        # Only dicts and lists are ever descended into, so cur is one of them.
        if isinstance(cur, dict):
            key: Any = part
            nxt = cur.get(part)
        else:
            key = _list_index(cur, part)
            if key is None:
                return obj
            nxt = cur[key]
        if not isinstance(nxt, (dict, list)):
            nxt = {}
            cur[key] = nxt
        cur = nxt

    last = parts[-1]
    if isinstance(cur, dict):
        cur[last] = value
    else:
        idx = _list_index(cur, last)
        if idx is not None:
            cur[idx] = value
    return obj
| |
|
| |
|
def _refine_sort_key(value: Any) -> tuple[int, Any]:
    """Map a non-None sort value to a key comparable across JSON value types.

    Numbers (bools included, as an int subclass) sort before strings, and
    strings sort before everything else (lists/dicts, compared by their
    canonical JSON text). This keeps ``sorted`` from raising TypeError when
    a field holds, e.g., ints on some items and strings on others.
    """
    if isinstance(value, (int, float)):
        return (0, value)
    if isinstance(value, str):
        return (1, value)
    return (2, json.dumps(value, sort_keys=True, ensure_ascii=False, default=str))


def _apply_local_refine(
    payload: Any,
    *,
    data_path: str | None,
    contains: str | None,
    where: dict[str, Any] | None,
    fields: list[str] | None,
    sort_by: str | None,
    sort_desc: bool,
    max_items: int | None,
    offset: int,
) -> tuple[Any, dict[str, Any]]:
    """Filter, sort, project and slice a list inside an API response, client-side.

    The target list is:
    - the payload itself when it is a list (``data_path`` is then ignored);
    - the value at dot-notation ``data_path`` when the payload is a dict and
      ``data_path`` is given;
    - ``payload["recentActivity"]`` when the payload is a dict, no
      ``data_path`` is given and that key holds a list.

    Steps run in this order:
    1. ``where``: exact match on dot-notation keys; non-dict items dropped.
    2. ``contains``: case-insensitive substring of the item's JSON text.
    3. ``sort_by``: items lacking the key always go last, in either direction.
    4. ``fields``: projection; non-dict items dropped.
    5. ``offset`` / ``max_items`` slicing.

    Returns:
        ``(refined_payload, meta)``. When there is no list target, the payload
        is returned unchanged and ``meta`` is
        ``{"refined": False, "reason": ...}``. Otherwise ``meta`` records the
        effective path and the original and returned item counts.
    """
    root_mode = "other"
    target_path = data_path

    if isinstance(payload, list):
        list_data = payload
        root_mode = "list"
    elif isinstance(payload, dict):
        if target_path:
            maybe_list = _get_nested_value(payload, target_path)
            list_data = maybe_list if isinstance(maybe_list, list) else None
        elif isinstance(payload.get("recentActivity"), list):
            # Default target for the /recent-activity feed.
            target_path = "recentActivity"
            list_data = payload.get("recentActivity")
        else:
            list_data = None
        root_mode = "dict"
    else:
        return payload, {"refined": False, "reason": "non-json-or-scalar"}

    if list_data is None:
        return payload, {"refined": False, "reason": "no-list-target"}

    original_count = len(list_data)
    items = list_data

    if where:
        def _matches_where(item: Any) -> bool:
            if not isinstance(item, dict):
                return False
            for key, expected in where.items():
                if _get_nested_value(item, key) != expected:
                    return False
            return True

        items = [item for item in items if _matches_where(item)]

    if contains:
        needle = contains.lower()
        items = [
            item
            for item in items
            if needle in json.dumps(item, ensure_ascii=False).lower()
        ]

    if sort_by:
        keyed = [
            (_get_nested_value(item, sort_by) if isinstance(item, dict) else None, item)
            for item in items
        ]
        present = [pair for pair in keyed if pair[0] is not None]
        missing = [item for value, item in keyed if value is None]
        # Only present values are reversed; missing ones stay at the end.
        present.sort(key=lambda pair: _refine_sort_key(pair[0]), reverse=sort_desc)
        items = [item for _, item in present] + missing

    if fields:
        projected: list[dict[str, Any]] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            projected.append({field: _get_nested_value(item, field) for field in fields})
        items = projected

    start = max(offset, 0)
    if max_items is not None:
        end = start + max(max_items, 0)
        items = items[start:end]
    elif start:
        items = items[start:]

    if root_mode == "list":
        refined_payload: Any = items
        effective_path = "<root>"
    else:
        effective_path = target_path or "recentActivity"
        # NOTE: shallow copy only. Nested containers on effective_path are
        # shared with *payload* and get mutated by _set_nested_value.
        refined_payload = dict(payload)
        _set_nested_value(refined_payload, effective_path, items)

    refine_meta = {
        "refined": True,
        "data_path": effective_path,
        "original_count": original_count,
        "returned_count": len(items),
    }
    return refined_payload, refine_meta
| |
|
| |
|
def hf_api_request(
    endpoint: str,
    method: str = "GET",
    params: dict[str, Any] | None = None,
    json_body: dict[str, Any] | None = None,
    max_results: int | None = None,
    offset: int | None = None,
    auto_paginate: bool | None = False,
    max_pages: int | None = 1,
    data_path: str | None = None,
    contains: str | None = None,
    where: dict[str, Any] | None = None,
    fields: list[str] | None = None,
    sort_by: str | None = None,
    sort_desc: bool | None = False,
    max_items: int | None = None,
) -> dict[str, Any]:
    """
    Primary Hub community API tool (GET/POST only).

    When to use:
    - User/org intelligence: /users/*, /organizations/*
    - Collaboration flows: /{repo_type}s/{repo_id}/discussions and discussion details
    - Gated access workflows: user-access-request endpoints
    - Collections list/get/create/add-item
    - Recent activity feed via /recent-activity

    When NOT to use:
    - Model/dataset semantic search/ranking
    - PATCH/DELETE operations (unsupported)

    Intent-to-parameter guidance:
    - "latest" or "recent": add params limit and sort_by time if needed
    - "top N": use max_items or max_results
    - "mentioning X": use contains
    - "only fields A/B": use fields projection
    - Cursor feeds: use auto_paginate=True with max_pages guard

    Args:
        endpoint: Endpoint path relative to /api (allowlisted).
        method: GET or POST only.
        params: Query parameters.
        json_body: JSON body for POST.
        max_results: Client-side cap for list (root-array) responses, applied
            before any refine step. Defaults to HF_MAX_RESULTS, else 20.
        offset: Client-side list offset, applied exactly once. For list
            responses it is applied with max_results. For dict responses it is
            applied to the refined target list.
        auto_paginate: Follow cursor-based pages for GET responses.
        max_pages: Max pages when auto_paginate=True.
        data_path: Dot path to target list (e.g. recentActivity).
        contains: Case-insensitive text match on serialized items.
        where: Exact-match dict using dot notation keys.
        fields: Return only selected fields (dot notation supported).
        sort_by: Dot-notation sort key (items missing it are listed last).
        sort_desc: Descending sort flag.
        max_items: Post-filter cap for returned list.

    Returns:
        A dict containing request URL, HTTP status, response data, and refine/pagination metadata.

    Raises:
        ValueError: Unsupported method, json_body on GET, auto_paginate on a
            non-GET request, max_pages < 1, or a rejected endpoint.
        RuntimeError: The HTTP request failed.
    """
    method_upper = method.upper()

    # Tool callers may pass explicit nulls; coerce them to the documented defaults.
    auto_paginate = bool(auto_paginate) if auto_paginate is not None else False
    sort_desc = bool(sort_desc) if sort_desc is not None else False
    if max_pages is None:
        max_pages = 1
    if method_upper not in {"GET", "POST"}:
        raise ValueError("Only GET and POST are allowed for hf_api_request.")

    if method_upper == "GET" and json_body is not None:
        raise ValueError("GET requests do not accept json_body.")

    if auto_paginate and method_upper != "GET":
        raise ValueError("auto_paginate is only supported for GET requests.")

    if max_pages < 1:
        raise ValueError("max_pages must be >= 1.")

    # Private copy: pagination below writes the cursor into it.
    req_params = dict(params or {})
    url = _build_url(endpoint, req_params)
    status_code, payload = _request_once(
        url=url,
        method_upper=method_upper,
        json_body=json_body,
    )

    pages_fetched = 1

    # Cursor pagination: follow the top-level "cursor" key, appending each
    # page's items to the target list until no cursor remains or max_pages
    # pages have been fetched.
    if auto_paginate and isinstance(payload, dict):
        list_key: str | None = None
        if data_path:
            maybe_list = _get_nested_value(payload, data_path)
            if isinstance(maybe_list, list):
                list_key = data_path
        elif isinstance(payload.get("recentActivity"), list):
            list_key = "recentActivity"

        cursor = payload.get("cursor")
        while list_key and cursor and pages_fetched < max_pages:
            req_params["cursor"] = cursor
            page_url = _build_url(endpoint, req_params)
            _, next_payload = _request_once(
                url=page_url,
                method_upper="GET",
                json_body=None,
            )

            if not isinstance(next_payload, dict):
                break

            current_items = _get_nested_value(payload, list_key)
            next_items = _get_nested_value(next_payload, list_key)
            if not isinstance(current_items, list) or not isinstance(next_items, list):
                break

            _set_nested_value(payload, list_key, current_items + next_items)
            cursor = next_payload.get("cursor")
            payload["cursor"] = cursor
            pages_fetched += 1

    list_offset = max(offset or 0, 0)
    root_is_list = isinstance(payload, list)

    # Client-side cap for list responses. The offset is consumed here, so it
    # must not be applied a second time by the refine step below.
    if root_is_list:
        limit = max_results if max_results is not None else _max_results_from_env()
        payload = payload[list_offset : list_offset + max(limit, 0)]

    refine_requested = any(
        [
            data_path is not None,
            contains is not None,
            where is not None,
            fields is not None,
            sort_by is not None,
            max_items is not None,
            # Dict responses can only receive an offset via the refine step.
            offset is not None and not root_is_list,
        ]
    )

    refine_meta: dict[str, Any] | None = None
    if refine_requested:
        payload, refine_meta = _apply_local_refine(
            payload,
            data_path=data_path,
            contains=contains,
            where=where,
            fields=fields,
            sort_by=sort_by,
            sort_desc=sort_desc,
            max_items=max_items,
            offset=0 if root_is_list else list_offset,
        )

    result = {
        "url": url,
        "status": status_code,
        "data": payload,
        "pages_fetched": pages_fetched,
    }
    if refine_meta is not None:
        result["refine"] = refine_meta
    return result
| |
|