# SDXL-Model-Merger / src/config.py
# Author: Kyle Pearson
# Commit df6e003: Add dynamic dtype selection for `get_device_info()` with
# CUDA/CPU checks, update config logic
"""Configuration constants and global settings for SDXL Model Merger."""
import os
import sys
from pathlib import Path
# ──────────────────────────────────────────────
# Paths & Directories
# ──────────────────────────────────────────────
# Cache is rooted at the *current working directory*, not the module's own
# location. NOTE(review): confirm this is intended when the app is launched
# from a directory other than the project root.
SCRIPT_DIR = Path.cwd()
CACHE_DIR = SCRIPT_DIR / ".cache"
# Created eagerly at import time so download code can assume it exists.
CACHE_DIR.mkdir(exist_ok=True)
# ──────────────────────────────────────────────
# Deployment Environment Detection
# ──────────────────────────────────────────────
# Deployment environment: "local" (default) or "spaces" (HuggingFace Spaces).
DEPLOYMENT_ENV = os.environ.get("DEPLOYMENT_ENV", "local").lower()
if DEPLOYMENT_ENV not in ("local", "spaces"):
    print(f"⚠️ Unknown DEPLOYMENT_ENV '{DEPLOYMENT_ENV}', defaulting to 'local'")
    # Bug fix: the warning claimed we default to 'local' but the old code
    # never actually applied the fallback, leaving an invalid value in place.
    DEPLOYMENT_ENV = "local"


def is_running_on_spaces() -> bool:
    """Check if running on HuggingFace Spaces."""
    return DEPLOYMENT_ENV == "spaces"
# ──────────────────────────────────────────────
# Default URLs - Use HF models for Spaces compatibility
# ──────────────────────────────────────────────
# Base SDXL checkpoint used when the user supplies no URL; override via env var.
DEFAULT_CHECKPOINT_URL = os.environ.get(
    "DEFAULT_CHECKPOINT_URL",
    "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors?download=true"
)
# fp16-safe VAE (madebyollin's fix avoids NaNs when decoding in half precision).
DEFAULT_VAE_URL = os.environ.get(
    "DEFAULT_VAE_URL",
    "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true"
)
# Default LoRA - using HF instead of CivitAI
DEFAULT_LORA_URLS = os.environ.get(
    "DEFAULT_LORA_URLS",
    "https://huggingface.co/nerijs/pixel-art-xl/resolve/main/pixel-art-xl.safetensors?download=true"
)
# ──────────────────────────────────────────────
# PyTorch & Device Settings
# ──────────────────────────────────────────────
import torch
def get_device_info() -> tuple[str, str]:
"""
Detect and return the optimal device for ML inference.
Returns:
Tuple of (device_name, device_description)
"""
if torch.cuda.is_available():
device_name = "cuda"
gpu_name = torch.cuda.get_device_name(0)
# Check available VRAM
try:
vram_total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
if vram_total < 8.0:
return device_name, f"CUDA (GPU: {gpu_name}, {vram_total:.1f}GB VRAM - low memory)"
except Exception:
pass
return device_name, f"CUDA (GPU: {gpu_name})"
else:
# CPU fallback - check available RAM
try:
import psutil
ram_gb = psutil.virtual_memory().total / (1024**3)
if ram_gb < 16.0:
return "cpu", f"CPU (WARNING: {ram_gb:.1f}GB RAM - may be insufficient)"
return "cpu", f"CPU ({ram_gb:.1f}GB RAM)"
except Exception:
return "cpu", "CPU (no GPU available)"
# Resolve the device once at import time; everything downstream reads these.
device, device_description = get_device_info()
# Half precision on GPU for speed/memory; full precision on CPU for stability.
dtype = torch.float16 if device == "cuda" else torch.float32
# Check if we're on low-memory hardware and warn
def check_memory_requirements() -> bool:
    """Check if system meets minimum requirements. Returns True if OK."""
    # CPU-only inference needs more headroom than GPU-assisted runs.
    required_gb = 8.0 if device == "cpu" else 4.0
    try:
        import psutil
        total_gb = psutil.virtual_memory().total / (1024**3)
        # Spaces CPU hardware is known-constrained; float32 is already chosen
        # at module level, so just inform and report OK.
        if is_running_on_spaces() and device == "cpu":
            print(f"ℹ️ Spaces CPU mode detected: using float32 for stability")
            return True
        if total_gb < required_gb:
            print(f"⚠️ Warning: Low memory ({total_gb:.1f}GB < {required_gb}GB required)")
            return False
    except Exception:
        # psutil missing or probe failed — best effort, assume OK.
        pass
    return True
# Announce the selected device and run the non-fatal memory sanity check at
# import time. Bug fix: the rocket emoji was mojibake ("πŸš€", UTF-8 bytes
# read as cp1252) in the user-facing banner.
print(f"🚀 Using device: {device_description}")
check_memory_requirements()
# ──────────────────────────────────────────────
# Global Pipeline State
#
# IMPORTANT: Use get_pipe() / set_pipe() instead of importing `pipe` directly.
# Python's `from .config import pipe` binds the value (None) at import time.
# Subsequent set_pipe() calls update the mutable dict, so all modules that
# call get_pipe() will always see the current pipeline instance.
# ──────────────────────────────────────────────
# Mutable holder for the current pipeline. Indirection matters: a plain
# module-level variable would be frozen at its import-time value by
# `from .config import pipe`, whereas this dict is shared by reference, so
# get_pipe() always observes the latest set_pipe() call from any module.
_pipeline_state: dict = {"pipe": None}


def get_pipe():
    """Get the currently loaded pipeline instance (always up-to-date)."""
    return _pipeline_state["pipe"]


def set_pipe(pipeline) -> None:
    """Set the globally loaded pipeline instance."""
    _pipeline_state["pipe"] = pipeline


# Legacy alias for code that still references config.pipe directly. Do NOT
# use it to test whether the pipeline is loaded — call get_pipe() instead.
pipe = None
# ──────────────────────────────────────────────
# Download Cancellation Flag
# ──────────────────────────────────────────────
# Polled by download loops between chunks; flip it via set_download_cancelled().
download_cancelled = False


def set_download_cancelled(value: bool) -> None:
    """Set the global download cancellation flag."""
    global download_cancelled
    download_cancelled = value
# ──────────────────────────────────────────────
# Generation Defaults
# ──────────────────────────────────────────────
# Prompt pre-filled in the generation UI (worded for 360° equirectangular output).
DEFAULT_PROMPT = "Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic"
# Negative prompt pre-filled alongside it to suppress common artifacts.
DEFAULT_NEGATIVE_PROMPT = "boring, text, signature, watermark, low quality, bad quality"
# ──────────────────────────────────────────────
# Model Presets (URLs for common models)
# ──────────────────────────────────────────────
# Display-name → download-URL map for common models; presumably surfaced as
# preset choices in the UI (verify against the caller). CivitAI entries
# require network access and may need an API token depending on the model.
MODEL_PRESETS = {
    # Checkpoints
    "DreamShaper XL v2": "https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
    "Realism Engine SDXL": "https://civitai.com/api/download/models/328799?type=Model&format=SafeTensor&size=full&fp=fp16",
    "Juggernaut XL v9": "https://civitai.com/api/download/models/350565?type=Model&format=SafeTensor&size=full&fp=fp16",
    # VAEs
    "VAE-FP16 Fix": "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
    # LoRAs
    "Rainbow Color LoRA": "https://civitai.com/api/download/models/127983?type=Model&format=SafeTensor",
    "More Details LoRA": "https://civitai.com/api/download/models/280590?type=Model&format=SafeTensor",
    "Epic Realism LoRA": "https://civitai.com/api/download/models/346631?type=Model&format=SafeTensor",
}
def get_cached_models():
    """Get list of cached model files."""
    if not CACHE_DIR.exists():
        return []
    # Sorted for a stable display order; paths returned as strings.
    return [str(path) for path in sorted(CACHE_DIR.glob("*.safetensors"))]
def get_cached_model_names():
    """Get display names for cached models.

    Returns:
        List of bare filenames (no directory component), one per cached model.
    """
    # Bug fix: get_cached_models() returns path *strings*, and the old code
    # called `.name` on them — a str has no .name attribute, so this raised
    # AttributeError whenever the cache was non-empty. Wrap in Path first.
    return [Path(m).name for m in get_cached_models()]
def get_cached_checkpoints():
    """Get list of cached checkpoint files (model_id_model.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    # Checkpoints are cached with a "*_model.safetensors" suffix convention.
    return [str(path) for path in sorted(CACHE_DIR.glob("*_model.safetensors"))]
def get_cached_vaes():
    """Get list of cached VAE files (model_id_vae.safetensors or model_id_*_vae.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    # One glob covers both naming schemes:
    #   model_id_vae.safetensors
    #   model_id_name_vae.safetensors (legacy / backward compatibility)
    return [str(path) for path in sorted(CACHE_DIR.glob("*_vae.safetensors"))]
def get_cached_loras():
    """Get list of cached LoRA files (model_id_lora.safetensors or model_id_*_lora.safetensors)."""
    if not CACHE_DIR.exists():
        return []
    # One glob covers both naming schemes:
    #   model_id_lora.safetensors
    #   model_id_name_lora.safetensors (legacy / backward compatibility)
    return [str(path) for path in sorted(CACHE_DIR.glob("*_lora.safetensors"))]
def validate_cache_file(cache_path: Path, min_size_mb: float = 1.0) -> tuple[bool, str]:
    """
    Validate a cached model file exists and has valid content.

    Args:
        cache_path: Path to the cached .safetensors file
        min_size_mb: Minimum acceptable file size in MB (default: 1MB)

    Returns:
        Tuple of (is_valid, message)
        - is_valid: True if file passes all checks
        - message: Description of validation result
    """
    # Bug fix: these were imported lazily inside the inner `try`, so an
    # OSError raised before the imports ran (e.g. in open()/read()) made the
    # `except struct.error` clause itself fail with NameError.
    import json
    import struct

    try:
        if not cache_path.exists():
            return False, f"File does not exist: {cache_path.name}"
        if not cache_path.is_file():
            return False, f"Not a regular file: {cache_path.name}"
        file_size = cache_path.stat().st_size
        size_mb = file_size / (1024 * 1024)
        if size_mb < min_size_mb:
            return False, f"File too small ({size_mb:.2f} MB < {min_size_mb} MB): {cache_path.name}"
        # Only .safetensors files get the header check; others pass on size alone.
        if not cache_path.suffix == ".safetensors":
            return True, f"Valid non-safetensors file: {cache_path.name}"
        try:
            with open(cache_path, "rb") as f:
                # safetensors layout: 8-byte little-endian header length,
                # followed by a JSON header of that length.
                header_size_bytes = f.read(8)
                if len(header_size_bytes) < 8:
                    return False, f"File too small for safetensors header: {cache_path.name}"
                header_size = struct.unpack("<Q", header_size_bytes)[0]
                if header_size == 0:
                    return False, f"Invalid safetensors header (size=0): {cache_path.name}"
                # Robustness: bound the read before allocating — a corrupt
                # length field could otherwise request a huge buffer.
                if header_size > file_size - 8:
                    return False, f"Incomplete safetensors header: {cache_path.name}"
                header = f.read(header_size)
                if len(header) < header_size:
                    return False, f"Incomplete safetensors header: {cache_path.name}"
                json.loads(header.decode("utf-8"))
        except struct.error as e:
            return False, f"Invalid safetensors format: {str(e)}"
        # Bug fix: non-UTF-8 header bytes previously raised an uncaught
        # UnicodeDecodeError; treat them as an invalid header too.
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            return False, f"Invalid safetensors header JSON: {str(e)}"
        return True, f"Valid cached file ({size_mb:.1f} MB): {cache_path.name}"
    except OSError as e:
        return False, f"File access error: {str(e)}"