"""HuggingFace Hub API wrapper for model discovery and info retrieval.""" import json import time from dataclasses import dataclass, field from typing import Optional from functools import lru_cache import requests HF_API = "https://huggingface.co/api" _session = requests.Session() _session.headers.update({"Accept": "application/json"}) # Simple in-memory cache with TTL _cache: dict[str, tuple[float, any]] = {} CACHE_TTL = 300 # 5 minutes def _cached_get(url: str, token: Optional[str] = None, ttl: int = CACHE_TTL) -> dict: """GET with caching and rate-limit handling.""" now = time.time() if url in _cache and (now - _cache[url][0]) < ttl: return _cache[url][1] headers = {} if token: headers["Authorization"] = f"Bearer {token}" resp = _session.get(url, headers=headers, timeout=15) if resp.status_code == 429: retry = int(resp.headers.get("Retry-After", 5)) time.sleep(retry) resp = _session.get(url, headers=headers, timeout=15) resp.raise_for_status() data = resp.json() _cache[url] = (now, data) return data @dataclass class ModelInfo: """Parsed model information from HF Hub.""" model_id: str model_type: str = "unknown" architectures: list[str] = field(default_factory=list) vocab_size: int = 0 hidden_size: int = 0 intermediate_size: int = 0 num_hidden_layers: int = 0 num_attention_heads: int = 0 num_key_value_heads: int = 0 max_position_embeddings: int = 0 torch_dtype: str = "unknown" pipeline_tag: str = "" tags: list[str] = field(default_factory=list) downloads: int = 0 likes: int = 0 size_bytes: int = 0 gated: bool = False private: bool = False trust_remote_code: bool = False error: Optional[str] = None @property def param_estimate(self) -> str: """Rough parameter count estimate based on architecture.""" if self.size_bytes > 0: # Rough: model files in bf16 ≈ 2 bytes per param params = self.size_bytes / 2 if params > 1e9: return f"{params/1e9:.1f}B" elif params > 1e6: return f"{params/1e6:.0f}M" return "unknown" @property def arch_signature(self) -> str: """Unique signature for architecture matching.""" return f"{self.model_type}|{self.hidden_size}|{self.intermediate_size}" @property def display_name(self) -> str: """Short display name (without org prefix).""" return self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id @property def ram_estimate_gb(self) -> float: """Estimated RAM needed for merging (roughly 2.5x model size for bf16 merge).""" if self.size_bytes > 0: return round(self.size_bytes * 2.5 / (1024**3), 1) return 0.0 def to_dict(self) -> dict: return { "model_id": self.model_id, "model_type": self.model_type, "architectures": self.architectures, "vocab_size": self.vocab_size, "hidden_size": self.hidden_size, "intermediate_size": self.intermediate_size, "num_hidden_layers": self.num_hidden_layers, "num_attention_heads": self.num_attention_heads, "torch_dtype": self.torch_dtype, "pipeline_tag": self.pipeline_tag, "downloads": self.downloads, "likes": self.likes, "param_estimate": self.param_estimate, "ram_estimate_gb": self.ram_estimate_gb, "gated": self.gated, "private": self.private, } def fetch_model_info(model_id: str, token: Optional[str] = None) -> ModelInfo: """Fetch comprehensive model information from HF Hub. 

def fetch_model_info(model_id: str, token: Optional[str] = None) -> ModelInfo:
    """Fetch comprehensive model information from HF Hub.

    Args:
        model_id: Full model ID (e.g., "Qwen/Qwen2.5-Coder-7B-Instruct")
        token: Optional HF API token for gated/private models

    Returns:
        ModelInfo dataclass with all available information
    """
    info = ModelInfo(model_id=model_id)

    # Fetch main model info
    try:
        data = _cached_get(f"{HF_API}/models/{model_id}", token=token)
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 401:
            info.error = "Gated or private model — HF token required"
            info.gated = True
        elif e.response.status_code == 404:
            info.error = f"Model not found: {model_id}"
        else:
            info.error = f"API error: {e.response.status_code}"
        return info
    except Exception as e:
        info.error = f"Connection error: {str(e)}"
        return info

    # Parse basic metadata
    info.pipeline_tag = data.get("pipeline_tag", "")
    info.tags = data.get("tags", [])
    info.downloads = data.get("downloads", 0)
    info.likes = data.get("likes", 0)
    # The API reports "gated" as False or a string mode, so normalize to bool
    info.gated = data.get("gated", False) not in (False, None)
    info.private = data.get("private", False)

    # Parse config (architecture details)
    config = data.get("config", {})
    if config:
        info.model_type = config.get("model_type", "unknown")
        info.architectures = config.get("architectures", [])

    # Fetch the full config.json for detailed architecture info
    # (the API endpoint only returns basic config fields)
    try:
        full_config = _cached_get(
            f"https://huggingface.co/{model_id}/resolve/main/config.json",
            token=token,
        )
        info.model_type = full_config.get("model_type", info.model_type)
        info.architectures = full_config.get("architectures", info.architectures)
        info.vocab_size = full_config.get("vocab_size", 0)
        info.hidden_size = full_config.get("hidden_size", 0)
        info.intermediate_size = full_config.get("intermediate_size", 0)
        info.num_hidden_layers = full_config.get("num_hidden_layers", 0)
        info.num_attention_heads = full_config.get("num_attention_heads", 0)
        info.num_key_value_heads = full_config.get("num_key_value_heads", 0)
        info.max_position_embeddings = full_config.get("max_position_embeddings", 0)
        info.torch_dtype = full_config.get("torch_dtype", "unknown")
        if "auto_map" in full_config:
            info.trust_remote_code = True
    except Exception:
        # Fall back to basic config from the API
        if config:
            info.vocab_size = config.get("vocab_size", 0)
            info.hidden_size = config.get("hidden_size", 0)
        else:
            info.error = "Could not fetch config.json — model may need trust_remote_code=True"
            info.trust_remote_code = True

    # Estimate total model size from siblings (repo files)
    siblings = data.get("siblings", [])
    total_size = 0
    for f in siblings:
        fname = f.get("rfilename", "")
        size = f.get("size", 0) or 0
        # Count only model weight files
        if any(fname.endswith(ext) for ext in (".safetensors", ".bin", ".pt", ".pth", ".gguf")):
            total_size += size
    info.size_bytes = total_size

    return info
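

# Hedged usage sketch (the model ID is illustrative; a real call needs network
# access, and gated repos additionally need a token):
#
#     info = fetch_model_info("Qwen/Qwen2.5-Coder-7B-Instruct")
#     if info.error:
#         print(f"lookup failed: {info.error}")
#     else:
#         print(info.display_name, info.param_estimate,
#               f"~{info.ram_estimate_gb} GB to merge")
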

def search_models(
    query: str = "",
    author: str = "",
    architecture: str = "",
    limit: int = 20,
    sort: str = "downloads",
    token: Optional[str] = None,
) -> list[dict]:
    """Search HuggingFace Hub for models.

    Args:
        query: Search query string
        author: Filter by author/organization
        architecture: Filter by model_type (e.g., "llama", "qwen2")
        limit: Max results to return
        sort: Sort by "downloads", "likes", "created", "modified"
        token: Optional HF API token

    Returns:
        List of dicts with basic model info
    """
    params = {
        "limit": min(limit, 100),
        "sort": sort,
        "direction": -1,
        "config": True,
    }
    if query:
        params["search"] = query
    if author:
        params["author"] = author

    try:
        # urlencode escapes spaces and special characters in the query
        data = _cached_get(
            f"{HF_API}/models?{urlencode(params)}",
            token=token,
            ttl=60,  # shorter cache for search
        )
    except Exception as e:
        return [{"error": str(e)}]

    results = []
    for m in data:
        config = m.get("config", {}) or {}
        model_type = config.get("model_type", "")
        # Filter by architecture if specified
        if architecture and model_type.lower() != architecture.lower():
            continue
        results.append({
            "model_id": m.get("modelId", ""),
            "model_type": model_type,
            "pipeline_tag": m.get("pipeline_tag", ""),
            "downloads": m.get("downloads", 0),
            "likes": m.get("likes", 0),
            "tags": m.get("tags", [])[:5],
        })
    return results[:limit]


def get_popular_base_models(architecture: str = "", token: Optional[str] = None) -> list[dict]:
    """Get popular base models for a given architecture type.

    Useful for suggesting base_model in merge configs.
    """
    # Common base models by architecture
    known_bases = {
        "llama": [
            "meta-llama/Llama-3.1-8B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-2-7b-hf",
        ],
        "mistral": [
            "mistralai/Mistral-7B-Instruct-v0.3",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
        ],
        "qwen2": [
            "Qwen/Qwen2.5-7B-Instruct",
            "Qwen/Qwen2.5-14B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-72B-Instruct",
        ],
        "gemma2": [
            "google/gemma-2-9b-it",
            "google/gemma-2-27b-it",
        ],
        "phi3": [
            "microsoft/Phi-3-mini-4k-instruct",
            "microsoft/Phi-3-medium-4k-instruct",
        ],
    }

    if architecture.lower() in known_bases:
        return [{"model_id": m} for m in known_bases[architecture.lower()]]

    # Fallback: search for popular instruct models
    return search_models(
        query=f"{architecture} instruct",
        limit=5,
        sort="downloads",
        token=token,
    )
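

if __name__ == "__main__":
    # Minimal smoke test, not part of the API: assumes network access to
    # huggingface.co and no token (gated/private repos will report an error
    # on the returned ModelInfo rather than raising). The query and model ID
    # are illustrative.
    for hit in search_models(query="qwen2.5 coder", limit=3):
        print(f"{hit.get('model_id')}  downloads={hit.get('downloads')}")

    info = fetch_model_info("Qwen/Qwen2.5-Coder-7B-Instruct")
    print(json.dumps(info.to_dict(), indent=2))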