Spaces:
Sleeping
Sleeping
| """Download utilities for SDXL Model Merger with Gradio progress integration.""" | |
| import re | |
| import requests | |
| from pathlib import Path | |
| from tqdm import tqdm as TqdmBase | |
| from .config import download_cancelled | |
| def extract_model_id(url: str) -> str | None: | |
| """Extract CivitAI model ID from URL.""" | |
| match = re.search(r'/models/(\d+)', url) | |
| return match.group(1) if match else None | |
def is_huggingface_url(url: str) -> bool:
    """Return True when *url* points at huggingface.co (case-insensitive)."""
    lowered = url.lower()
    return "huggingface.co" in lowered
| def get_safe_filename_from_url( | |
| url: str, | |
| default_name: str = "model.safetensors", | |
| suffix: str = "", | |
| type_prefix: str | None = None | |
| ) -> str: | |
| """ | |
| Generate a safe filename with model ID from URL. | |
| For CivitAI URLs like https://civitai.com/api/download/models/12345?type=... | |
| Naming patterns: | |
| - Checkpoint (type_prefix='model'): 12345_model.safetensors or 12345_model_anime_style.safetensors | |
| - VAE (suffix='_vae'): 12345_vae.safetensors (no name extraction to avoid double suffix) | |
| - LoRA (suffix='_lora'): 12345_lora.safetensors (no name extraction to avoid double suffix) | |
| For HuggingFace URLs without model IDs, attempts to extract name from path or uses suffix-based naming. | |
| Args: | |
| url: The download URL | |
| default_name: Fallback filename if extraction fails | |
| suffix: Optional suffix to append before .safetensors (e.g., '_vae', '_lora') | |
| type_prefix: Optional prefix after model_id (e.g., 'model' -> 12345_model.safetensors) | |
| """ | |
| model_id = extract_model_id(url) | |
| # If no CivitAI model ID, try to generate a name from HuggingFace path | |
| if not model_id and "huggingface.co" in url: | |
| # Try to extract name from URL path (e.g., sdxl-vae-fp16-fix -> fp16_fix) | |
| try: | |
| parts = url.split("huggingface.co/")[1] if "huggingface.co/" in url else "" | |
| if parts: | |
| # Get the repo name (second part after org/) | |
| path_parts = [p for p in parts.split("/") if p] | |
| if len(path_parts) >= 2: | |
| repo_name = path_parts[1] | |
| # Clean up and create a simple identifier | |
| clean_repo = re.sub(r'[^a-zA-Z0-9]', '_', repo_name)[:30].strip('_') | |
| if clean_repo: | |
| model_id = f"hf_{clean_repo}" | |
| except Exception: | |
| pass | |
| if not model_id: | |
| return default_name | |
| # Special handling for VAE/LoRA with HuggingFace URLs to avoid double suffix | |
| is_special_type = suffix in ("_vae", "_lora") | |
| # Strip common suffixes from model_id when adding corresponding suffix | |
| # (e.g., "sdxl_vae_fp16_fix" + "_vae" -> "sdxl_fp16_fix" + "_vae") | |
| if is_special_type: | |
| strip_suffix = suffix.lstrip('_') # "vae" or "lora" | |
| model_id_lower = model_id.lower() | |
| # Check if model_id contains the type (with underscore boundaries) | |
| if f"_{strip_suffix}_" in model_id_lower or model_id_lower.endswith(f"_{strip_suffix}"): | |
| # Remove the suffix from model_id | |
| if model_id_lower.endswith(f"_{strip_suffix}"): | |
| model_id = model_id[:-len(strip_suffix)-1] | |
| else: | |
| # Find and remove _suffix_ pattern | |
| pattern = f"_{strip_suffix}_" | |
| idx = model_id_lower.find(pattern) | |
| if idx >= 0: | |
| model_id = model_id[:idx] + model_id[idx+len(pattern):] | |
| # Build the name portion: either clean name from URL or fallback | |
| name_part = "" | |
| # For VAE/LoRA types, skip Content-Disposition parsing to avoid double naming | |
| # (e.g., sdxl_vae_vae instead of just vae) | |
| if not is_special_type: | |
| try: | |
| response = requests.head(url, timeout=10, allow_redirects=True) | |
| cd = response.headers.get('Content-Disposition', '') | |
| match = re.search(r'filename="([^"]+)"', cd) | |
| if match: | |
| filename = match.group(1) | |
| # Extract base name without extension | |
| base_name = Path(filename).stem | |
| # Clean up the name (remove special chars) | |
| clean_name = re.sub(r'[^\w\s-]', '', base_name)[:50] | |
| clean_name = re.sub(r'[-\s]+', '_', clean_name.strip('-_')) | |
| if clean_name: | |
| name_part = clean_name | |
| except Exception: | |
| pass | |
| # Build filename with model_id, optional type_prefix, optional name_part, and suffix | |
| parts = [model_id] | |
| if type_prefix: | |
| parts.append(type_prefix) | |
| if name_part: | |
| parts.append(name_part) | |
| # Handle suffix - for VAE/LoRA we only add the suffix, not double naming | |
| if suffix: | |
| if is_special_type: | |
| # For _vae and _lora: just use model_id + suffix directly | |
| return f"{model_id}{suffix}.safetensors" | |
| else: | |
| # For other types (checkpoint), append suffix after name_part | |
| parts.append(suffix.lstrip('_')) | |
| return '_'.join(p for p in parts if p).replace('__', '_') + '.safetensors' | |
class TqdmGradio(TqdmBase):
    """tqdm subclass that mirrors progress into a Gradio ``gr.Progress`` callback.

    Also polls the module-level cancel flag on every update so a user
    cancellation aborts the surrounding download promptly.
    """

    def __init__(self, *args, gradio_prog=None, **kwargs):
        """
        Args:
            gradio_prog: Optional ``gr.Progress()`` callable, invoked with a
                fraction in [0, 1]. All other arguments pass through to tqdm.
        """
        super().__init__(*args, **kwargs)
        self.gradio_prog = gradio_prog
        self.last_pct = 0  # last whole-percent value pushed to the UI

    def update(self, n=1):
        # Re-import on each call so we read the *current* flag value rather
        # than a stale snapshot captured at module import time.
        from .config import download_cancelled
        if download_cancelled:
            raise KeyboardInterrupt("Download cancelled by user")
        super().update(n)
        if self.gradio_prog and self.total:
            pct = int(100 * self.n / self.total)
            # Throttle UI updates to ~5% steps. BUG FIX: the old condition
            # (pct % 5 == 0) silently dropped updates whenever a chunk
            # jumped past an exact multiple of 5, so the Gradio bar could
            # stall; compare against the last reported value instead.
            if pct - self.last_pct >= 5 or (pct >= 100 and self.last_pct < 100):
                self.last_pct = pct
                self.gradio_prog(pct / 100)
def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
    """
    Look up a previously downloaded file for *url* in the cache directory.

    The expected cache filename is derived with the same logic the download
    path uses (get_safe_filename_from_url), so a hit here means the download
    step can be skipped entirely.

    Args:
        url: The download URL to check for a cached file.
        suffix: Optional suffix (e.g., '_vae', '_lora') for special file types.
        type_prefix: Optional prefix after model_id (e.g., 'model').

    Returns:
        (path, size) for a non-empty cached file, otherwise (None, None).
    """
    from .config import CACHE_DIR

    # Mirror the per-type fallback names used by the download code.
    if suffix == "_vae":
        fallback = "vae.safetensors"
    elif suffix == "_lora":
        fallback = "lora.safetensors"
    else:
        fallback = "model.safetensors"

    candidate = CACHE_DIR / get_safe_filename_from_url(
        url,
        default_name=fallback,
        suffix=suffix,
        type_prefix=type_prefix,
    )

    if candidate.is_file():
        try:
            size = candidate.stat().st_size
            # A zero-byte file is a broken download; treat it as missing.
            if size > 0:
                return candidate, size
        except OSError:
            pass
    return None, None
def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
    """
    Download a file with Gradio-synced progress bar + cancel support.

    Checks for an existing cached file before downloading: if the file exists
    and its size matches the server-reported content-length, the download is
    skipped. Supports http(s) URLs and local file:// URLs (copied into place).

    Args:
        url: File URL to download (http/https/file)
        output_path: Destination path for downloaded file
        progress_bar: Optional gr.Progress() object for UI updates

    Returns:
        Path to the downloaded (or cached) file

    Raises:
        KeyboardInterrupt: If the download is cancelled by the user
        FileNotFoundError: If a file:// source does not exist
        requests.RequestException: If the download fails
        OSError: If the downloaded file fails validation
    """
    # BUG FIX: import the module, not the flag. `from .config import
    # download_cancelled` binds a snapshot at function entry, so the in-loop
    # cancel check could never observe a cancellation made mid-download.
    from . import config

    # Handle local file:// URLs by copying straight into the cache.
    if url.startswith("file://"):
        local_path = Path(url[7:])  # strip the "file://" scheme prefix
        if not local_path.exists():
            raise FileNotFoundError(f"Local file not found: {local_path}")
        import shutil
        output_path.parent.mkdir(parents=True, exist_ok=True)
        print(f" 📁 Copying from cache: {local_path.name} → {output_path.name}")
        shutil.copy2(str(local_path), str(output_path))
        if progress_bar:
            progress_bar(1.0)  # cached copy counts as fully "downloaded"
        return output_path

    print(f" 📥 Downloading to cache: {output_path.name}")

    # Early cache check: if the file already exists with the expected size,
    # skip the re-download entirely.
    expected_size = None
    try:
        # BUG FIX: follow redirects — CDN-backed hosts (CivitAI/HuggingFace)
        # only report the real content-length after redirecting, and the
        # other HEAD request in this module already uses allow_redirects.
        head = requests.head(url, timeout=10, allow_redirects=True)
        size = int(head.headers.get('content-length', 0))
        # BUG FIX: a missing/zero header means "size unknown", not "size 0";
        # leaving it at 0 previously produced a spurious size-mismatch
        # warning after every download.
        expected_size = size if size > 0 else None
    except Exception:
        pass  # Header fetch is best-effort; fall through to a full download

    if output_path.exists() and expected_size is not None:
        try:
            cached_size = output_path.stat().st_size
            if cached_size == expected_size:
                print(f" ✅ Cache hit: {output_path.name} ({cached_size / (1024**2):.1f} MB)")
                if progress_bar:
                    progress_bar(1.0)
                return output_path  # Skip re-download!
        except OSError:
            pass  # File access error, proceed with download

    output_path.parent.mkdir(parents=True, exist_ok=True)
    session = requests.Session()
    response = session.get(url, stream=True, timeout=30)
    response.raise_for_status()

    total_size = expected_size or int(response.headers.get('content-length', 0))
    block_size = 8192

    # TqdmGradio mirrors terminal progress into the Gradio UI.
    tqdm_kwargs = {
        'unit': 'B',
        'unit_scale': True,
        'desc': f"Downloading {output_path.name}",
        'gradio_prog': progress_bar,
        'disable': False,
        'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]',
    }
    with open(output_path, "wb") as f:
        try:
            for data in TqdmGradio(
                response.iter_content(block_size),
                # NOTE: total is counted in *chunks*, not bytes; the Gradio
                # percentage stays correct because n is chunk-counted too.
                total=total_size // block_size if total_size else 0,
                **tqdm_kwargs,
            ):
                # Read the live flag each iteration (see import note above).
                if config.download_cancelled:
                    raise KeyboardInterrupt("Download cancelled by user")
                f.write(data)
        except KeyboardInterrupt:
            output_path.unlink(missing_ok=True)  # drop the partial file
            raise

    # Verify the downloaded file is complete and structurally valid.
    try:
        actual_size = output_path.stat().st_size
        if output_path.suffix == ".safetensors":
            # A safetensors file begins with an 8-byte little-endian header
            # length followed by a JSON header; parsing both catches
            # truncated downloads and HTML error pages saved as the file.
            import struct
            with open(output_path, "rb") as f:
                header_size_bytes = f.read(8)
                if len(header_size_bytes) < 8:
                    raise OSError(f"Safetensors file too small: {output_path.name}")
                header_size = struct.unpack("<Q", header_size_bytes)[0]
                header = f.read(header_size)
                if len(header) < header_size:
                    raise OSError(f"Incomplete safetensors header in {output_path.name}")
                import json
                json.loads(header.decode("utf-8"))
        # Warn (but keep the file) when the size differs from the expectation.
        if expected_size is not None and actual_size != expected_size:
            print(f" ⚠️ Size mismatch: expected {expected_size}, got {actual_size}")
    except Exception as e:
        output_path.unlink(missing_ok=True)
        raise OSError(f"Invalid downloaded file {output_path.name}: {str(e)}")

    return output_path
| def clear_cache(cache_dir: Path = None, keep_extensions: list[str] = None): | |
| """ | |
| Remove old cache files. | |
| Args: | |
| cache_dir: Cache directory path (defaults to config.CACHE_DIR) | |
| keep_extensions: File extensions to preserve (default: ['.safetensors']) | |
| """ | |
| if cache_dir is None: | |
| from .config import CACHE_DIR | |
| cache_dir = CACHE_DIR | |
| if keep_extensions is None: | |
| keep_extensions = ['.safetensors'] | |
| # Remove temp files | |
| for file in cache_dir.glob("*.tmp"): | |
| file.unlink() | |
| # Optional: age-based cleanup (7 days) | |
| # import time | |
| # cutoff = time.time() - 86400 * 7 | |
| # for f in cache_dir.iterdir(): | |
| # if f.is_file() and f.stat().st_mtime < cutoff: | |
| # f.unlink() | |