Spaces:
Running on Zero
Kyle Pearson
Add dynamic dtype selection for `get_device_info()` with CUDA/CPU checks, update config logic
df6e003 | """Configuration constants and global settings for SDXL Model Merger.""" | |
| import os | |
| import sys | |
| from pathlib import Path | |
# ──────────────────────────────────────────────
# Paths & Directories
# ──────────────────────────────────────────────

# Base directory is wherever the process was launched from.
SCRIPT_DIR = Path.cwd()

# Local model cache; created eagerly so downstream downloads can assume it exists.
CACHE_DIR = SCRIPT_DIR / ".cache"
CACHE_DIR.mkdir(exist_ok=True)
# ──────────────────────────────────────────────
# Deployment Environment Detection
# ──────────────────────────────────────────────

# Where we are running: "local" or "spaces" (HuggingFace Spaces).
DEPLOYMENT_ENV = os.environ.get("DEPLOYMENT_ENV", "local").lower()
if DEPLOYMENT_ENV not in ("local", "spaces"):
    print(f"⚠️ Unknown DEPLOYMENT_ENV '{DEPLOYMENT_ENV}', defaulting to 'local'")
    # Bug fix: the warning promised a fallback but the unknown value was never
    # replaced, so is_running_on_spaces() and friends saw the bad value.
    DEPLOYMENT_ENV = "local"
def is_running_on_spaces() -> bool:
    """Return True when the app is deployed on HuggingFace Spaces."""
    # DEPLOYMENT_ENV is normalized to lowercase at import time.
    return DEPLOYMENT_ENV == "spaces"
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Default URLs - Use HF models for Spaces compatibility
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Base SDXL checkpoint; overridable via the DEFAULT_CHECKPOINT_URL env var.
DEFAULT_CHECKPOINT_URL = os.environ.get(
    "DEFAULT_CHECKPOINT_URL",
    "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors?download=true"
)
# fp16-safe SDXL VAE; overridable via the DEFAULT_VAE_URL env var.
DEFAULT_VAE_URL = os.environ.get(
    "DEFAULT_VAE_URL",
    "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true"
)
# Default LoRA - using HF instead of CivitAI
DEFAULT_LORA_URLS = os.environ.get(
    "DEFAULT_LORA_URLS",
    "https://huggingface.co/nerijs/pixel-art-xl/resolve/main/pixel-art-xl.safetensors?download=true"
)
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| # PyTorch & Device Settings | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| import torch | |
def get_device_info() -> tuple[str, str]:
    """
    Pick the optimal device for ML inference.

    Returns:
        Tuple of (device_name, device_description) where device_name is
        "cuda" or "cpu" and device_description is a human-readable summary.
    """
    if not torch.cuda.is_available():
        # CPU fallback: report RAM when psutil is importable, otherwise a
        # generic label (psutil is an optional dependency here).
        try:
            import psutil
            ram_gb = psutil.virtual_memory().total / (1024**3)
        except Exception:
            return "cpu", "CPU (no GPU available)"
        if ram_gb < 16.0:
            return "cpu", f"CPU (WARNING: {ram_gb:.1f}GB RAM - may be insufficient)"
        return "cpu", f"CPU ({ram_gb:.1f}GB RAM)"

    gpu_name = torch.cuda.get_device_name(0)
    # Annotate the description when the card has little VRAM; any probe
    # failure silently falls through to the plain CUDA description.
    try:
        vram_total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if vram_total < 8.0:
            return "cuda", f"CUDA (GPU: {gpu_name}, {vram_total:.1f}GB VRAM - low memory)"
    except Exception:
        pass
    return "cuda", f"CUDA (GPU: {gpu_name})"
# Resolved once at import time; every module shares these values.
device, device_description = get_device_info()
# fp16 halves VRAM usage on CUDA; CPU inference stays on float32.
dtype = torch.float16 if device == "cuda" else torch.float32
# Check if we're on low-memory hardware and warn
def check_memory_requirements() -> bool:
    """Check if system meets minimum requirements. Returns True if OK."""
    # CPU inference keeps full weights in system RAM, so it needs more headroom.
    min_ram_gb = 8.0 if device == "cpu" else 4.0
    try:
        import psutil

        total_ram = psutil.virtual_memory().total / (1024**3)
        # On Spaces with CPU, RAM is limited - use float32 for safety
        if is_running_on_spaces() and device == "cpu":
            print("ℹ️ Spaces CPU mode detected: using float32 for stability")
            return True
        if total_ram < min_ram_gb:
            print(f"⚠️ Warning: Low memory ({total_ram:.1f}GB < {min_ram_gb}GB required)")
            return False
    except Exception:
        # psutil missing or the probe failed: assume the host is adequate.
        pass
    return True
# Announce the selected device and run the RAM sanity check at import time.
print(f"π Using device: {device_description}")
check_memory_requirements()
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Global Pipeline State | |
| # | |
| # IMPORTANT: Use get_pipe() / set_pipe() instead of importing `pipe` directly. | |
| # Python's `from .config import pipe` binds the value (None) at import time. | |
| # Subsequent set_pipe() calls update the mutable dict, so all modules that | |
| # call get_pipe() will always see the current pipeline instance. | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
# Mutable holder so every module sees the live pipeline via get_pipe().
_pipeline_state: dict = {"pipe": None}


def get_pipe():
    """Return the currently loaded pipeline instance (always up-to-date)."""
    return _pipeline_state.get("pipe")


def set_pipe(pipeline) -> None:
    """Store *pipeline* as the globally loaded pipeline instance."""
    _pipeline_state["pipe"] = pipeline


# Legacy alias kept for any code that references config.pipe directly.
# Do NOT use this for checking whether the pipeline is loaded — use get_pipe().
pipe = None
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Download Cancellation Flag | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
# Module-wide flag polled by download loops to abort in-flight transfers.
download_cancelled = False


def set_download_cancelled(value: bool) -> None:
    """Flip the module-wide download cancellation flag to *value*."""
    global download_cancelled
    download_cancelled = value
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Generation Defaults
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Seed prompt pre-filled in the UI; panorama keywords target 360 output.
DEFAULT_PROMPT = "Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic"
# Standard quality-suppression terms for the negative-prompt box.
DEFAULT_NEGATIVE_PROMPT = "boring, text, signature, watermark, low quality, bad quality"
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Model Presets (URLs for common models)
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Display-name -> direct download URL map used to populate preset pickers.
# NOTE(review): CivitAI endpoints may require an API token at download time — verify.
MODEL_PRESETS = {
    # Checkpoints
    "DreamShaper XL v2": "https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
    "Realism Engine SDXL": "https://civitai.com/api/download/models/328799?type=Model&format=SafeTensor&size=full&fp=fp16",
    "Juggernaut XL v9": "https://civitai.com/api/download/models/350565?type=Model&format=SafeTensor&size=full&fp=fp16",
    # VAEs
    "VAE-FP16 Fix": "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
    # LoRAs
    "Rainbow Color LoRA": "https://civitai.com/api/download/models/127983?type=Model&format=SafeTensor",
    "More Details LoRA": "https://civitai.com/api/download/models/280590?type=Model&format=SafeTensor",
    "Epic Realism LoRA": "https://civitai.com/api/download/models/346631?type=Model&format=SafeTensor",
}
def get_cached_models():
    """Get list of cached model files."""
    if not CACHE_DIR.exists():
        return []
    # One path string per .safetensors file, sorted for a stable listing order.
    return [str(path) for path in sorted(CACHE_DIR.glob("*.safetensors"))]
def get_cached_model_names():
    """Get display names (file names only) for cached models.

    Bug fix: get_cached_models() returns plain path *strings*, so the old
    ``m.name`` attribute access raised AttributeError on every non-empty
    cache. Wrap each entry in Path to extract its final component.
    """
    return [Path(m).name for m in get_cached_models()]
def get_cached_checkpoints():
    """Get list of cached checkpoint files (model_id_model.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    # Checkpoints are cached with a "_model" suffix before the extension.
    return [str(path) for path in sorted(CACHE_DIR.glob("*_model.safetensors"))]
def get_cached_vaes():
    """Get list of cached VAE files (model_id_vae.safetensors or model_id_*_vae.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    # The "*_vae" glob covers both cache naming schemes:
    #   - model_id_vae.safetensors
    #   - model_id_name_vae.safetensors (backward compatibility)
    return [str(path) for path in sorted(CACHE_DIR.glob("*_vae.safetensors"))]
def get_cached_loras():
    """Get list of cached LoRA files (model_id_lora.safetensors or model_id_*_lora.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    # The "*_lora" glob covers both cache naming schemes:
    #   - model_id_lora.safetensors
    #   - model_id_name_lora.safetensors (backward compatibility)
    return [str(path) for path in sorted(CACHE_DIR.glob("*_lora.safetensors"))]
def validate_cache_file(cache_path: Path, min_size_mb: float = 1.0) -> tuple[bool, str]:
    """
    Validate a cached model file exists and has valid content.

    Args:
        cache_path: Path to the cached .safetensors file
        min_size_mb: Minimum acceptable file size in MB (default: 1MB)

    Returns:
        Tuple of (is_valid, message)
        - is_valid: True if file passes all checks
        - message: Description of validation result
    """
    # Hoisted from mid-body: previously `struct` was imported only right
    # before use, so an `except struct.error` clause could in principle
    # evaluate an unbound name.
    import json
    import struct

    try:
        if not cache_path.exists():
            return False, f"File does not exist: {cache_path.name}"
        if not cache_path.is_file():
            return False, f"Not a regular file: {cache_path.name}"
        file_size = cache_path.stat().st_size
        size_mb = file_size / (1024 * 1024)
        if size_mb < min_size_mb:
            return False, f"File too small ({size_mb:.2f} MB < {min_size_mb} MB): {cache_path.name}"
        # Only .safetensors files get a header check; anything else passes on size alone.
        if cache_path.suffix != ".safetensors":
            return True, f"Valid non-safetensors file: {cache_path.name}"
        try:
            with open(cache_path, "rb") as f:
                # safetensors layout: 8-byte little-endian header length,
                # then a UTF-8 JSON header of that length, then tensor data.
                header_size_bytes = f.read(8)
                if len(header_size_bytes) < 8:
                    return False, f"File too small for safetensors header: {cache_path.name}"
                header_size = struct.unpack("<Q", header_size_bytes)[0]
                if header_size == 0:
                    return False, f"Invalid safetensors header (size=0): {cache_path.name}"
                header = f.read(header_size)
                if len(header) < header_size:
                    return False, f"Incomplete safetensors header: {cache_path.name}"
                json.loads(header.decode("utf-8"))
        except struct.error as e:
            return False, f"Invalid safetensors format: {str(e)}"
        # Bug fix: header.decode() can raise UnicodeDecodeError on corrupt
        # files, which previously escaped uncaught (it is not an OSError).
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            return False, f"Invalid safetensors header JSON: {str(e)}"
        return True, f"Valid cached file ({size_mb:.1f} MB): {cache_path.name}"
    except OSError as e:
        return False, f"File access error: {str(e)}"