"""Configuration constants and global settings for SDXL Model Merger."""

import os
import sys
from pathlib import Path

import torch

# ──────────────────────────────────────────────
# Paths & Directories
# ──────────────────────────────────────────────

# NOTE: despite its name, SCRIPT_DIR is the current working directory at import
# time (not the directory containing this file), so the cache lives wherever
# the process was launched from.
SCRIPT_DIR = Path.cwd()
CACHE_DIR = SCRIPT_DIR / ".cache"
CACHE_DIR.mkdir(exist_ok=True)

# ──────────────────────────────────────────────
# Deployment Environment Detection
# ──────────────────────────────────────────────

DEPLOYMENT_ENV = os.environ.get("DEPLOYMENT_ENV", "local").strip().lower()
if DEPLOYMENT_ENV not in ("local", "spaces"):
    print(f"⚠️ Unknown DEPLOYMENT_ENV '{DEPLOYMENT_ENV}', defaulting to 'local'")
    # The warning promises a fallback, so actually apply it; otherwise the
    # unrecognised value would leak to every module reading DEPLOYMENT_ENV.
    DEPLOYMENT_ENV = "local"


def is_running_on_spaces() -> bool:
    """Check if running on HuggingFace Spaces (DEPLOYMENT_ENV == "spaces")."""
    return DEPLOYMENT_ENV == "spaces"


# ──────────────────────────────────────────────
# Default URLs - Use HF models for Spaces compatibility
# ──────────────────────────────────────────────

DEFAULT_CHECKPOINT_URL = os.environ.get(
    "DEFAULT_CHECKPOINT_URL",
    "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors?download=true",
)
DEFAULT_VAE_URL = os.environ.get(
    "DEFAULT_VAE_URL",
    "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
)
# Default LoRA - using HF instead of CivitAI
DEFAULT_LORA_URLS = os.environ.get(
    "DEFAULT_LORA_URLS",
    "https://huggingface.co/nerijs/pixel-art-xl/resolve/main/pixel-art-xl.safetensors?download=true",
)

# ──────────────────────────────────────────────
# PyTorch & Device Settings
# ──────────────────────────────────────────────


def get_device_info() -> tuple[str, str]:
    """
    Detect and return the optimal device for ML inference.

    Returns:
        Tuple of (device_name, device_description). ``device_name`` is
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise
        ``"cpu"``. The description names the GPU (adding total VRAM when it
        is below 8 GB), or reports total system RAM on CPU when ``psutil``
        is importable.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
        gpu_name = torch.cuda.get_device_name(0)
        # Check available VRAM. This is best effort: if the property query
        # fails we fall through to the plain description.
        try:
            vram_total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        except Exception:
            vram_total = None
        if vram_total is not None and vram_total < 8.0:
            return device_name, f"CUDA (GPU: {gpu_name}, {vram_total:.1f}GB VRAM - low memory)"
        return device_name, f"CUDA (GPU: {gpu_name})"

    # CPU fallback - check available RAM (psutil is an optional dependency).
    try:
        import psutil

        ram_gb = psutil.virtual_memory().total / (1024**3)
    except Exception:
        return "cpu", "CPU (no GPU available)"
    if ram_gb < 16.0:
        return "cpu", f"CPU (WARNING: {ram_gb:.1f}GB RAM - may be insufficient)"
    return "cpu", f"CPU ({ram_gb:.1f}GB RAM)"


device, device_description = get_device_info()
# Half precision only on CUDA; CPU inference uses float32.
dtype = torch.float16 if device == "cuda" else torch.float32


# Check if we're on low-memory hardware and warn
def check_memory_requirements() -> bool:
    """Check if system meets minimum requirements. Returns True if OK.

    Compares total system RAM (via ``psutil``) against 8 GB when running on
    CPU, or 4 GB otherwise. It prints a warning and returns False when RAM
    is below that. On HuggingFace Spaces with CPU the threshold check is
    skipped and an info message is printed instead. If ``psutil`` cannot be
    imported or queried, the check is treated as passed.
    """
    min_ram_gb = 8.0 if device == "cpu" else 4.0
    try:
        import psutil

        total_ram = psutil.virtual_memory().total / (1024**3)
    except Exception:
        # Best effort: without a measurement, don't block startup.
        return True

    # On Spaces with CPU, RAM is limited - use float32 for safety
    if is_running_on_spaces() and device == "cpu":
        print("ℹ️ Spaces CPU mode detected: using float32 for stability")
        return True

    if total_ram < min_ram_gb:
        print(f"⚠️ Warning: Low memory ({total_ram:.1f}GB < {min_ram_gb}GB required)")
        return False
    return True


print(f"🚀 Using device: {device_description}")
check_memory_requirements()

# ──────────────────────────────────────────────
# Global Pipeline State
#
# IMPORTANT: Use get_pipe() / set_pipe() instead of importing `pipe` directly.
# Python's `from .config import pipe` binds the value (None) at import time.
# Subsequent set_pipe() calls update the mutable dict, so all modules that
# call get_pipe() will always see the current pipeline instance.
# ────────────────────────────────────────────── _pipeline_state: dict = {"pipe": None} def get_pipe(): """Get the currently loaded pipeline instance (always up-to-date).""" return _pipeline_state["pipe"] def set_pipe(pipeline) -> None: """Set the globally loaded pipeline instance.""" _pipeline_state["pipe"] = pipeline # Legacy alias kept for any code that references config.pipe directly. # Do NOT use this for checking whether the pipeline is loaded — use get_pipe(). pipe = None # ────────────────────────────────────────────── # Download Cancellation Flag # ────────────────────────────────────────────── download_cancelled = False def set_download_cancelled(value: bool) -> None: """Set the global download cancellation flag.""" global download_cancelled download_cancelled = value # ────────────────────────────────────────────── # Generation Defaults # ────────────────────────────────────────────── DEFAULT_PROMPT = "Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic" DEFAULT_NEGATIVE_PROMPT = "boring, text, signature, watermark, low quality, bad quality" # ────────────────────────────────────────────── # Model Presets (URLs for common models) # ────────────────────────────────────────────── MODEL_PRESETS = { # Checkpoints "DreamShaper XL v2": "https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16", "Realism Engine SDXL": "https://civitai.com/api/download/models/328799?type=Model&format=SafeTensor&size=full&fp=fp16", "Juggernaut XL v9": "https://civitai.com/api/download/models/350565?type=Model&format=SafeTensor&size=full&fp=fp16", # VAEs "VAE-FP16 Fix": "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true", # LoRAs "Rainbow Color LoRA": "https://civitai.com/api/download/models/127983?type=Model&format=SafeTensor", "More Details LoRA": "https://civitai.com/api/download/models/280590?type=Model&format=SafeTensor", "Epic Realism 
LoRA": "https://civitai.com/api/download/models/346631?type=Model&format=SafeTensor", } def get_cached_models(): """Get list of cached model files.""" if not CACHE_DIR.exists(): return [] models = [] for file in sorted(CACHE_DIR.glob("*.safetensors")): models.append(str(file)) return models def get_cached_model_names(): """Get display names for cached models.""" models = get_cached_models() return [str(m.name) for m in models] def get_cached_checkpoints(): """Get list of cached checkpoint files (model_id_model.safetensors).""" if not CACHE_DIR.exists(): return [] models = [] for file in sorted(CACHE_DIR.glob("*_model.safetensors")): models.append(str(file)) return models def get_cached_vaes(): """Get list of cached VAE files (model_id_vae.safetensors or model_id_*_vae.safetensors).""" if not CACHE_DIR.exists(): return [] models = [] # Match both patterns: # - model_id_vae.safetensors # - model_id_name_vae.safetensors (for backward compatibility) for file in sorted(CACHE_DIR.glob("*_vae.safetensors")): models.append(str(file)) return models def get_cached_loras(): """Get list of cached LoRA files (model_id_lora.safetensors or model_id_*_lora.safetensors).""" if not CACHE_DIR.exists(): return [] models = [] # Match both patterns: # - model_id_lora.safetensors # - model_id_name_lora.safetensors (for backward compatibility) for file in sorted(CACHE_DIR.glob("*_lora.safetensors")): models.append(str(file)) return models def validate_cache_file(cache_path: Path, min_size_mb: float = 1.0) -> tuple[bool, str]: """ Validate a cached model file exists and has valid content. 
Args: cache_path: Path to the cached .safetensors file min_size_mb: Minimum acceptable file size in MB (default: 1MB) Returns: Tuple of (is_valid, message) - is_valid: True if file passes all checks - message: Description of validation result """ try: if not cache_path.exists(): return False, f"File does not exist: {cache_path.name}" if not cache_path.is_file(): return False, f"Not a regular file: {cache_path.name}" file_size = cache_path.stat().st_size size_mb = file_size / (1024 * 1024) if size_mb < min_size_mb: return False, f"File too small ({size_mb:.2f} MB < {min_size_mb} MB): {cache_path.name}" # Check if it's a valid safetensors file by reading the header if not cache_path.suffix == ".safetensors": return True, f"Valid non-safetensors file: {cache_path.name}" try: with open(cache_path, "rb") as f: # Read first 8 bytes (header size) header_size_bytes = f.read(8) if len(header_size_bytes) < 8: return False, f"File too small for safetensors header: {cache_path.name}" import struct header_size = struct.unpack("