| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| from pathlib import Path |
| from typing import Any |
| from urllib.error import HTTPError, URLError |
| from urllib.parse import urlencode |
| from urllib.request import Request, urlopen |
|
|
# Fallback cap for list payloads; _max_results_from_env() returns it when
# HF_MAX_RESULTS is unset, not an integer, or not positive.
DEFAULT_MAX_RESULTS = 20
# Timeout (seconds) passed to urlopen() for every request in _request_once().
DEFAULT_TIMEOUT_SEC = 30
|
|
| |
| |
| |
| |
# Allowlist of Hub API paths (relative to /api) that this module may call.
# Each entry is a regular expression that must match the entire endpoint.
ALLOWED_ENDPOINT_PATTERNS: list[str] = [
    # Current user and user profiles
    r"^/whoami-v2$",
    r"^/users/[^/]+/overview$",
    r"^/users/[^/]+/likes$",
    r"^/users/[^/]+/followers$",
    r"^/users/[^/]+/following$",
    # Organizations
    r"^/organizations/[^/]+/overview$",
    r"^/organizations/[^/]+/members$",
    r"^/organizations/[^/]+/followers$",
    # Repository discussions: list, detail, comments, comment edit/hide, status
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+/comment$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+/comment/[^/]+/edit$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+/comment/[^/]+/hide$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/discussions/\d+/status$",
    # Gated-access (user-access-request) workflows
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/pending$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/accepted$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/rejected$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/handle$",
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/user-access-request/grant$",
    # Collections
    r"^/collections$",
    r"^/collections/[^/]+$",
    r"^/collections/[^/]+/items$",
    # Repository auth check
    r"^/(models|datasets|spaces)/[^/]+/[^/]+/auth-check$",
    # Recent activity feed
    r"^/recent-activity$",
]


# Compiled once at import time so allowlist checks do not recompile patterns.
_COMPILED_PATTERNS: list[re.Pattern[str]] = [
    re.compile(p) for p in ALLOWED_ENDPOINT_PATTERNS
]


def _is_endpoint_allowed(endpoint: str) -> bool:
    """Return True if ``endpoint`` matches any allowed pattern in full.

    ``fullmatch`` is used instead of ``match`` because ``$`` also matches just
    before a trailing newline, so ``match`` would accept ``"/whoami-v2\\n"``.

    Args:
        endpoint: Path relative to ``/api``, expected to start with ``/``.

    Returns:
        True when the whole string matches at least one allowlist pattern.
    """
    return any(pattern.fullmatch(endpoint) for pattern in _COMPILED_PATTERNS)
|
|
|
|
def _load_token() -> str | None:
    """Resolve the Hub bearer token, or return None when none is available.

    Sources are tried in order:
      1. ``request_bearer_token`` from ``fast_agent.mcp.auth.context`` when that
         package is importable (its ``.get()`` result is used if truthy).
      2. The ``HF_TOKEN`` environment variable.
      3. The file ``~/.cache/huggingface/token``.

    Tokens from the environment and from the file are stripped of surrounding
    whitespace; a blank value counts as "not set" so the next source is tried.

    Returns:
        The token string, or None if no source yields a non-empty token.
    """
    try:
        from fast_agent.mcp.auth.context import request_bearer_token
    except ImportError:
        # fast_agent is optional; without it there is no per-request token.
        pass
    else:
        ctx_token = request_bearer_token.get()
        if ctx_token:
            return ctx_token

    # Strip so a stray newline/space never ends up in the Authorization header.
    env_token = (os.getenv("HF_TOKEN") or "").strip()
    if env_token:
        return env_token

    token_path = Path.home() / ".cache" / "huggingface" / "token"
    try:
        file_token = token_path.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        # Missing, unreadable, or undecodable token file: treat as "no token"
        # so the request is simply sent unauthenticated.
        return None
    return file_token or None
|
|
|
|
def _max_results_from_env() -> int:
    """Return the list cap configured via ``HF_MAX_RESULTS``.

    Falls back to ``DEFAULT_MAX_RESULTS`` when the variable is unset or empty,
    is not an integer, or is not strictly positive.
    """
    configured = os.getenv("HF_MAX_RESULTS")
    if configured:
        try:
            parsed = int(configured)
        except ValueError:
            parsed = 0
        if parsed > 0:
            return parsed
    return DEFAULT_MAX_RESULTS
|
|
|
|
def _normalize_endpoint(endpoint: str) -> str:
    """Normalize and validate an endpoint path.

    Checks, in order (surrounding whitespace is stripped first):
    - Must be non-empty
    - Must be a relative path, not a full URL (scheme check is case-insensitive)
    - No path traversal sequences (..)
    - No '?', '#', whitespace, or control characters
    - Must match the endpoint allowlist

    The '?' / '#' check matters because the allowlist wildcards (``[^/]+``)
    would otherwise accept e.g. ``/users/a?b/overview``; once placed in a URL
    everything after '?' becomes the query string, so the request would really
    target ``/api/users/a``, which is not an allowlisted endpoint.

    Args:
        endpoint: Path relative to ``/api``, with or without a leading slash.

    Returns:
        The stripped endpoint, guaranteed to start with ``/``.

    Raises:
        ValueError: If any of the checks above fails.
    """
    endpoint = endpoint.strip()
    if not endpoint:
        raise ValueError("Endpoint must be a non-empty string.")

    # Strip before this check so " https://..." cannot slip past it.
    if endpoint.lower().startswith(("http://", "https://")):
        raise ValueError("Endpoint must be a path relative to /api, not a full URL.")

    if ".." in endpoint:
        raise ValueError("Path traversal sequences (..) are not allowed in endpoints.")

    if re.search(r"[?#\s\x00-\x1f\x7f]", endpoint):
        raise ValueError(
            "Endpoint must not contain '?', '#', whitespace, or control characters; "
            "pass query parameters via params."
        )

    if not endpoint.startswith("/"):
        endpoint = f"/{endpoint}"

    if not _is_endpoint_allowed(endpoint):
        raise ValueError(
            f"Endpoint '{endpoint}' is not in the allowed list. "
            "See ALLOWED_ENDPOINT_PATTERNS for permitted endpoints."
        )

    return endpoint
|
|
|
|
def _normalize_params(params: dict[str, Any] | None) -> dict[str, Any]:
    """Convert query parameters to strings suitable for URL encoding.

    Entries whose value is None are dropped. List and tuple values become
    lists of strings (one per element); every other value is passed through
    ``str()``. A missing or empty mapping yields an empty dict.
    """
    cleaned: dict[str, Any] = {}
    for name, raw in (params or {}).items():
        if raw is None:
            continue
        cleaned[name] = (
            [str(element) for element in raw]
            if isinstance(raw, (list, tuple))
            else str(raw)
        )
    return cleaned
|
|
|
|
def _build_url(endpoint: str, params: dict[str, Any] | None) -> str:
    """Assemble the full request URL for an allowlisted endpoint.

    The base comes from ``HF_ENDPOINT`` (default ``https://huggingface.co``)
    with trailing slashes removed, followed by ``/api`` and the validated
    endpoint. Normalized query parameters, if any, are appended with
    ``urlencode(..., doseq=True)`` so list values repeat the key.

    Raises:
        ValueError: Propagated from ``_normalize_endpoint`` for invalid or
            non-allowlisted endpoints.
    """
    root = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
    path = _normalize_endpoint(endpoint)
    query = _normalize_params(params)
    target = f"{root}/api{path}"
    if not query:
        return target
    return f"{target}?{urlencode(query, doseq=True)}"
|
|
|
|
def _request_once(
    *,
    url: str,
    method_upper: str,
    json_body: dict[str, Any] | None,
) -> tuple[int, Any]:
    """Perform a single HTTP request and return ``(status_code, payload)``.

    An ``Authorization: Bearer`` header is added when ``_load_token()`` yields
    a token. POST requests always carry a JSON body (``{}`` when ``json_body``
    is None). The response body is parsed as JSON; if it is not valid JSON it
    is returned as text decoded as UTF-8 with replacement characters.

    Args:
        url: Fully built request URL.
        method_upper: Upper-case HTTP method.
        json_body: JSON-serializable body, only sent for POST.

    Returns:
        Tuple of the HTTP status and the decoded payload.

    Raises:
        RuntimeError: On an HTTP error status (message includes the response
            body) or on any transport failure, including timeouts or resets
            that occur while the body is being read.
    """
    headers = {"Accept": "application/json"}
    token = _load_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"

    data = None
    if method_upper == "POST":
        headers["Content-Type"] = "application/json"
        data = json.dumps(json_body or {}).encode("utf-8")

    request = Request(url, headers=headers, data=data, method=method_upper)

    try:
        with urlopen(request, timeout=DEFAULT_TIMEOUT_SEC) as response:
            raw = response.read()
            status_code = response.status
    except HTTPError as exc:
        error_body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HF API error {exc.code} for {url}: {error_body}") from exc
    except (URLError, OSError) as exc:
        # URLError is itself an OSError; plain OSError additionally covers
        # socket timeouts / connection resets raised from response.read(),
        # which urllib does not wrap in URLError.
        raise RuntimeError(f"HF API request failed for {url}: {exc}") from exc

    try:
        payload = json.loads(raw)
    except ValueError:
        # ValueError covers json.JSONDecodeError and also UnicodeDecodeError,
        # which json.loads raises for bytes that are not valid UTF-8/16/32.
        payload = raw.decode("utf-8", errors="replace")

    return status_code, payload
|
|
|
|
def _get_nested_value(obj: Any, path: str) -> Any:
    """Look up a value inside nested dicts/lists using a dot-separated path.

    Empty path segments are ignored, so an empty path returns ``obj`` itself.
    Dicts are indexed by the segment as a key; lists by the segment parsed as
    a non-negative integer index. Returns None when a key is missing, an index
    is not an integer or is out of range, or a scalar is reached before the
    path is exhausted.
    """
    node = obj
    for key in filter(None, path.split(".")):
        if isinstance(node, dict):
            if key not in node:
                return None
            node = node[key]
            continue
        if not isinstance(node, list):
            return None
        try:
            position = int(key)
        except ValueError:
            return None
        if not 0 <= position < len(node):
            return None
        node = node[position]
    return node
|
|
|
|
def _set_nested_value(obj: Any, path: str, value: Any) -> Any:
    """Set ``value`` at a dot-separated ``path`` inside ``obj``, in place.

    Traversal mirrors ``_get_nested_value``: dicts are indexed by key and lists
    by integer index. Previously a list met along the path was overwritten
    with an empty dict, destroying its contents; now it is descended into.

    Rules:
    - An empty ``path`` returns ``value`` itself.
    - A non-dict root, or a path made only of dots, returns ``obj`` unchanged.
    - Missing or scalar intermediate nodes are replaced by new dicts.
    - A list segment that is not a valid in-range index leaves ``obj``
      untouched (lists are never grown or converted).

    Args:
        obj: Root container; only dict roots are modified.
        path: Dot-separated path; empty segments are ignored.
        value: Value to store at the final segment.

    Returns:
        ``obj`` (mutated in place), or ``value`` when ``path`` is empty.
    """
    if not path:
        return value
    if not isinstance(obj, dict):
        return obj

    parts = [p for p in path.split(".") if p]
    if not parts:
        return obj

    cur: Any = obj
    for part in parts[:-1]:
        if isinstance(cur, list):
            try:
                key: Any = int(part)
            except ValueError:
                return obj
            if not 0 <= key < len(cur):
                return obj
            nxt = cur[key]
        else:
            key = part
            nxt = cur.get(part)
        if not isinstance(nxt, (dict, list)):
            # Missing or scalar node: create a dict so the path can continue.
            nxt = {}
            cur[key] = nxt
        cur = nxt

    last = parts[-1]
    if isinstance(cur, list):
        try:
            idx = int(last)
        except ValueError:
            return obj
        if 0 <= idx < len(cur):
            cur[idx] = value
    else:
        cur[last] = value
    return obj
|
|
|
|
def _replace_at_path(container: Any, parts: list[str], value: Any) -> Any:
    """Return a copy of ``container`` with the node at ``parts`` replaced.

    Only the dicts/lists along the path are shallow-copied, so the input is
    never mutated. Dicts are indexed by key and lists by integer index. If the
    path cannot be followed, the container is returned unchanged.
    """
    if not parts:
        return value
    head, rest = parts[0], parts[1:]
    if isinstance(container, dict):
        if head not in container:
            return container
        clone = dict(container)
        clone[head] = _replace_at_path(container[head], rest, value)
        return clone
    if isinstance(container, list):
        try:
            idx = int(head)
        except ValueError:
            return container
        if not 0 <= idx < len(container):
            return container
        clone_list = list(container)
        clone_list[idx] = _replace_at_path(container[idx], rest, value)
        return clone_list
    return container


def _apply_local_refine(
    payload: Any,
    *,
    data_path: str | None,
    contains: str | None,
    where: dict[str, Any] | None,
    fields: list[str] | None,
    sort_by: str | None,
    sort_desc: bool,
    max_items: int | None,
    offset: int,
) -> tuple[Any, dict[str, Any]]:
    """Filter, sort, project and slice the list inside an API payload.

    The target list is the payload itself when it is a list. For a dict it is
    the list found at ``data_path``, or the ``recentActivity`` key when no
    ``data_path`` is given. Steps run in this order: ``where`` exact match,
    ``contains`` text match, sort, ``fields`` projection, then
    ``offset``/``max_items`` slicing.

    The input payload is never mutated; for dict payloads a copy is returned
    with only the containers along the target path duplicated.

    Args:
        payload: Decoded response body.
        data_path: Dot path to the list inside a dict payload.
        contains: Case-insensitive substring matched against each item's JSON
            serialization.
        where: Dot-path -> expected value; non-dict items never match.
        fields: Dot paths to keep; non-dict items are dropped when projecting.
        sort_by: Dot path of the sort key. Items whose key is missing or None
            are placed last regardless of direction. If the values cannot be
            compared with each other, they are ordered by ``str(value)``.
        sort_desc: Sort descending when True.
        max_items: Maximum number of items to return after filtering.
        offset: Number of items to skip after filtering (negative -> 0).

    Returns:
        ``(refined_payload, meta)``. ``meta["refined"]`` is False, with a
        ``reason``, when the payload is a scalar or no target list was found;
        otherwise it carries ``data_path``, ``original_count`` and
        ``returned_count``.
    """
    target_path = data_path

    if isinstance(payload, list):
        root_is_list = True
        list_data = payload
    elif isinstance(payload, dict):
        root_is_list = False
        if target_path:
            candidate = _get_nested_value(payload, target_path)
            list_data = candidate if isinstance(candidate, list) else None
        elif isinstance(payload.get("recentActivity"), list):
            target_path = "recentActivity"
            list_data = payload["recentActivity"]
        else:
            list_data = None
    else:
        return payload, {"refined": False, "reason": "non-json-or-scalar"}

    if list_data is None:
        return payload, {"refined": False, "reason": "no-list-target"}

    original_count = len(list_data)
    items = list_data

    if where:
        def _matches_where(item: Any) -> bool:
            if not isinstance(item, dict):
                return False
            for key, expected in where.items():
                if _get_nested_value(item, key) != expected:
                    return False
            return True

        items = [item for item in items if _matches_where(item)]

    if contains:
        needle = contains.lower()
        items = [
            item
            for item in items
            if needle in json.dumps(item, ensure_ascii=False).lower()
        ]

    if sort_by:
        keyed = [
            (_get_nested_value(item, sort_by) if isinstance(item, dict) else None, item)
            for item in items
        ]
        present = [pair for pair in keyed if pair[0] is not None]
        # Items without a sort value always go last; with a plain
        # reverse=True sort they would float to the top of "top N" queries.
        missing = [item for value, item in keyed if value is None]
        try:
            ordered = sorted(present, key=lambda pair: pair[0], reverse=sort_desc)
        except TypeError:
            # Mixed or unorderable value types: fall back to text ordering
            # instead of failing the whole request.
            ordered = sorted(present, key=lambda pair: str(pair[0]), reverse=sort_desc)
        items = [item for _, item in ordered] + missing

    if fields:
        items = [
            {field: _get_nested_value(item, field) for field in fields}
            for item in items
            if isinstance(item, dict)
        ]

    start = max(offset, 0)
    if max_items is not None:
        items = items[start:start + max(max_items, 0)]
    elif start:
        items = items[start:]

    if root_is_list:
        refined_payload: Any = items
        effective_path = "<root>"
    else:
        effective_path = target_path or "recentActivity"
        # Copy-on-write along the path: the caller's payload stays intact and
        # paths that pass through lists keep their structure.
        refined_payload = _replace_at_path(
            payload, [p for p in effective_path.split(".") if p], items
        )

    refine_meta = {
        "refined": True,
        "data_path": effective_path,
        "original_count": original_count,
        "returned_count": len(items),
    }
    return refined_payload, refine_meta
|
|
|
|
def hf_api_request(
    endpoint: str,
    method: str = "GET",
    params: dict[str, Any] | None = None,
    json_body: dict[str, Any] | None = None,
    max_results: int | None = None,
    offset: int | None = None,
    auto_paginate: bool | None = False,
    max_pages: int | None = 1,
    data_path: str | None = None,
    contains: str | None = None,
    where: dict[str, Any] | None = None,
    fields: list[str] | None = None,
    sort_by: str | None = None,
    sort_desc: bool | None = False,
    max_items: int | None = None,
) -> dict[str, Any]:
    """
    Primary Hub community API tool (GET/POST only).

    When to use:
    - User/org intelligence: /users/*, /organizations/*
    - Collaboration flows: /{repo_type}s/{repo_id}/discussions and discussion details
    - Gated access workflows: user-access-request endpoints
    - Collections list/get/create/add-item
    - Recent activity feed via /recent-activity

    When NOT to use:
    - Model/dataset semantic search/ranking
    - PATCH/DELETE operations (unsupported)

    Intent-to-parameter guidance:
    - "latest" or "recent": add params limit and sort_by time if needed
    - "top N": use max_items or max_results
    - "mentioning X": use contains
    - "only fields A/B": use fields projection
    - Cursor feeds: use auto_paginate=True with max_pages guard

    Args:
        endpoint: Endpoint path relative to /api (allowlisted).
        method: GET or POST only.
        params: Query parameters.
        json_body: JSON body for POST.
        max_results: Client-side list cap.
        offset: Client-side list offset.
        auto_paginate: Follow cursor-based pages for GET responses.
        max_pages: Max pages when auto_paginate=True.
        data_path: Dot path to target list (e.g. recentActivity).
        contains: Case-insensitive text match on serialized items.
        where: Exact-match dict using dot notation keys.
        fields: Return only selected fields (dot notation supported).
        sort_by: Dot-notation sort key.
        sort_desc: Descending sort flag.
        max_items: Post-filter cap for returned list.

    Returns:
        A dict containing request URL, HTTP status, response data, and refine/pagination metadata.

    Raises:
        ValueError: For an unsupported method, a GET with json_body,
            auto_paginate on a non-GET request, max_pages < 1, or an endpoint
            that fails validation.
        RuntimeError: When the underlying HTTP request fails.
    """
    # Callers may pass explicit nulls for optional arguments; treat them as
    # the documented defaults.
    method_upper = (method or "GET").upper()
    auto_paginate = bool(auto_paginate)
    sort_desc = bool(sort_desc)
    if max_pages is None:
        max_pages = 1

    if method_upper not in {"GET", "POST"}:
        raise ValueError("Only GET and POST are allowed for hf_api_request.")

    if method_upper == "GET" and json_body is not None:
        raise ValueError("GET requests do not accept json_body.")

    if auto_paginate and method_upper != "GET":
        raise ValueError("auto_paginate is only supported for GET requests.")

    if max_pages < 1:
        raise ValueError("max_pages must be >= 1.")

    req_params = dict(params or {})
    url = _build_url(endpoint, req_params)
    status_code, payload = _request_once(
        url=url,
        method_upper=method_upper,
        json_body=json_body,
    )

    pages_fetched = 1

    # Cursor pagination: merge each following page's list into the first
    # payload until the cursor runs out or max_pages is reached.
    if auto_paginate and isinstance(payload, dict):
        list_key: str | None = None
        if data_path:
            if isinstance(_get_nested_value(payload, data_path), list):
                list_key = data_path
        elif isinstance(payload.get("recentActivity"), list):
            list_key = "recentActivity"

        cursor = payload.get("cursor")
        while list_key and cursor and pages_fetched < max_pages:
            req_params["cursor"] = cursor
            page_url = _build_url(endpoint, req_params)
            _, next_payload = _request_once(
                url=page_url,
                method_upper="GET",
                json_body=None,
            )

            if not isinstance(next_payload, dict):
                break

            current_items = _get_nested_value(payload, list_key)
            next_items = _get_nested_value(next_payload, list_key)
            if not isinstance(current_items, list) or not isinstance(next_items, list):
                break

            _set_nested_value(payload, list_key, current_items + next_items)
            cursor = next_payload.get("cursor")
            payload["cursor"] = cursor
            pages_fetched += 1

    requested_offset = max(offset or 0, 0)
    offset_consumed = False

    # Top-level list responses are capped and offset on the client side.
    if isinstance(payload, list):
        limit = max_results if max_results is not None else _max_results_from_env()
        payload = payload[requested_offset:requested_offset + max(limit, 0)]
        offset_consumed = True

    refine_requested = any(
        option is not None
        for option in (data_path, contains, where, fields, sort_by, max_items)
    )

    refine_meta: dict[str, Any] | None = None
    if refine_requested:
        payload, refine_meta = _apply_local_refine(
            payload,
            data_path=data_path,
            contains=contains,
            where=where,
            fields=fields,
            sort_by=sort_by,
            sort_desc=sort_desc,
            max_items=max_items,
            # The offset was already applied to list payloads above; passing
            # it again would skip the same number of items a second time.
            offset=0 if offset_consumed else requested_offset,
        )

    result: dict[str, Any] = {
        "url": url,
        "status": status_code,
        "data": payload,
        "pages_fetched": pages_fetched,
    }
    if refine_meta is not None:
        result["refine"] = refine_meta
    return result
|
|