Spaces:
Running on Zero
Running on Zero
File size: 11,572 Bytes
6a07ce1 459ac47 6a07ce1 459ac47 6a07ce1 459ac47 6a07ce1 459ac47 6a07ce1 459ac47 6a07ce1 df6e003 6a07ce1 459ac47 6a07ce1 459ac47 6a07ce1 570384a 6a07ce1 570384a 6a07ce1 570384a 6a07ce1 570384a 6a07ce1 8cdb001 6a07ce1 8cdb001 6a07ce1 8cdb001 6a07ce1 8cdb001 6a07ce1 b1e7bdb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 | """Configuration constants and global settings for SDXL Model Merger."""
import os
import sys
from pathlib import Path
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Paths & Directories
# ββββββββββββββββββββββββββββββββββββββββββββββ
SCRIPT_DIR = Path.cwd()
CACHE_DIR = SCRIPT_DIR / ".cache"
CACHE_DIR.mkdir(exist_ok=True)
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Deployment Environment Detection
# ββββββββββββββββββββββββββββββββββββββββββββββ
DEPLOYMENT_ENV = os.environ.get("DEPLOYMENT_ENV", "local").lower()
if DEPLOYMENT_ENV not in ("local", "spaces"):
print(f"β οΈ Unknown DEPLOYMENT_ENV '{DEPLOYMENT_ENV}', defaulting to 'local'")
def is_running_on_spaces() -> bool:
"""Check if running on HuggingFace Spaces."""
return DEPLOYMENT_ENV == "spaces"
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Default URLs - Use HF models for Spaces compatibility
# ββββββββββββββββββββββββββββββββββββββββββββββ
DEFAULT_CHECKPOINT_URL = os.environ.get(
"DEFAULT_CHECKPOINT_URL",
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors?download=true"
)
DEFAULT_VAE_URL = os.environ.get(
"DEFAULT_VAE_URL",
"https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true"
)
# Default LoRA - using HF instead of CivitAI
DEFAULT_LORA_URLS = os.environ.get(
"DEFAULT_LORA_URLS",
"https://huggingface.co/nerijs/pixel-art-xl/resolve/main/pixel-art-xl.safetensors?download=true"
)
# ββββββββββββββββββββββββββββββββββββββββββββββ
# PyTorch & Device Settings
# ββββββββββββββββββββββββββββββββββββββββββββββ
import torch
def get_device_info() -> tuple[str, str]:
"""
Detect and return the optimal device for ML inference.
Returns:
Tuple of (device_name, device_description)
"""
if torch.cuda.is_available():
device_name = "cuda"
gpu_name = torch.cuda.get_device_name(0)
# Check available VRAM
try:
vram_total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
if vram_total < 8.0:
return device_name, f"CUDA (GPU: {gpu_name}, {vram_total:.1f}GB VRAM - low memory)"
except Exception:
pass
return device_name, f"CUDA (GPU: {gpu_name})"
else:
# CPU fallback - check available RAM
try:
import psutil
ram_gb = psutil.virtual_memory().total / (1024**3)
if ram_gb < 16.0:
return "cpu", f"CPU (WARNING: {ram_gb:.1f}GB RAM - may be insufficient)"
return "cpu", f"CPU ({ram_gb:.1f}GB RAM)"
except Exception:
return "cpu", "CPU (no GPU available)"
device, device_description = get_device_info()
dtype = torch.float16 if device == "cuda" else torch.float32
# Check if we're on low-memory hardware and warn
def check_memory_requirements() -> bool:
"""Check if system meets minimum requirements. Returns True if OK."""
min_ram_gb = 8.0 if device == "cpu" else 4.0
try:
import psutil
total_ram = psutil.virtual_memory().total / (1024**3)
# On Spaces with CPU, RAM is limited - use float32 for safety
if is_running_on_spaces() and device == "cpu":
print(f"βΉοΈ Spaces CPU mode detected: using float32 for stability")
return True
if total_ram < min_ram_gb:
print(f"β οΈ Warning: Low memory ({total_ram:.1f}GB < {min_ram_gb}GB required)")
return False
except Exception:
pass
return True
print(f"π Using device: {device_description}")
check_memory_requirements()
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Global Pipeline State
#
# IMPORTANT: Use get_pipe() / set_pipe() instead of importing `pipe` directly.
# Python's `from .config import pipe` binds the value (None) at import time.
# Subsequent set_pipe() calls update the mutable dict, so all modules that
# call get_pipe() will always see the current pipeline instance.
# ββββββββββββββββββββββββββββββββββββββββββββββ
_pipeline_state: dict = {"pipe": None}
def get_pipe():
"""Get the currently loaded pipeline instance (always up-to-date)."""
return _pipeline_state["pipe"]
def set_pipe(pipeline) -> None:
"""Set the globally loaded pipeline instance."""
_pipeline_state["pipe"] = pipeline
# Legacy alias kept for any code that references config.pipe directly.
# Do NOT use this for checking whether the pipeline is loaded β use get_pipe().
pipe = None
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Download Cancellation Flag
# ββββββββββββββββββββββββββββββββββββββββββββββ
download_cancelled = False
def set_download_cancelled(value: bool) -> None:
"""Set the global download cancellation flag."""
global download_cancelled
download_cancelled = value
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Generation Defaults
# ββββββββββββββββββββββββββββββββββββββββββββββ
DEFAULT_PROMPT = "Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic"
DEFAULT_NEGATIVE_PROMPT = "boring, text, signature, watermark, low quality, bad quality"
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Model Presets (URLs for common models)
# ββββββββββββββββββββββββββββββββββββββββββββββ
MODEL_PRESETS = {
# Checkpoints
"DreamShaper XL v2": "https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
"Realism Engine SDXL": "https://civitai.com/api/download/models/328799?type=Model&format=SafeTensor&size=full&fp=fp16",
"Juggernaut XL v9": "https://civitai.com/api/download/models/350565?type=Model&format=SafeTensor&size=full&fp=fp16",
# VAEs
"VAE-FP16 Fix": "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
# LoRAs
"Rainbow Color LoRA": "https://civitai.com/api/download/models/127983?type=Model&format=SafeTensor",
"More Details LoRA": "https://civitai.com/api/download/models/280590?type=Model&format=SafeTensor",
"Epic Realism LoRA": "https://civitai.com/api/download/models/346631?type=Model&format=SafeTensor",
}
def get_cached_models():
"""Get list of cached model files."""
if not CACHE_DIR.exists():
return []
models = []
for file in sorted(CACHE_DIR.glob("*.safetensors")):
models.append(str(file))
return models
def get_cached_model_names():
"""Get display names for cached models."""
models = get_cached_models()
return [str(m.name) for m in models]
def get_cached_checkpoints():
"""Get list of cached checkpoint files (model_id_model.safetensors)."""
if not CACHE_DIR.exists():
return []
models = []
for file in sorted(CACHE_DIR.glob("*_model.safetensors")):
models.append(str(file))
return models
def get_cached_vaes():
"""Get list of cached VAE files (model_id_vae.safetensors or model_id_*_vae.safetensors)."""
if not CACHE_DIR.exists():
return []
models = []
# Match both patterns:
# - model_id_vae.safetensors
# - model_id_name_vae.safetensors (for backward compatibility)
for file in sorted(CACHE_DIR.glob("*_vae.safetensors")):
models.append(str(file))
return models
def get_cached_loras():
"""Get list of cached LoRA files (model_id_lora.safetensors or model_id_*_lora.safetensors)."""
if not CACHE_DIR.exists():
return []
models = []
# Match both patterns:
# - model_id_lora.safetensors
# - model_id_name_lora.safetensors (for backward compatibility)
for file in sorted(CACHE_DIR.glob("*_lora.safetensors")):
models.append(str(file))
return models
def validate_cache_file(cache_path: Path, min_size_mb: float = 1.0) -> tuple[bool, str]:
"""
Validate a cached model file exists and has valid content.
Args:
cache_path: Path to the cached .safetensors file
min_size_mb: Minimum acceptable file size in MB (default: 1MB)
Returns:
Tuple of (is_valid, message)
- is_valid: True if file passes all checks
- message: Description of validation result
"""
try:
if not cache_path.exists():
return False, f"File does not exist: {cache_path.name}"
if not cache_path.is_file():
return False, f"Not a regular file: {cache_path.name}"
file_size = cache_path.stat().st_size
size_mb = file_size / (1024 * 1024)
if size_mb < min_size_mb:
return False, f"File too small ({size_mb:.2f} MB < {min_size_mb} MB): {cache_path.name}"
# Check if it's a valid safetensors file by reading the header
if not cache_path.suffix == ".safetensors":
return True, f"Valid non-safetensors file: {cache_path.name}"
try:
with open(cache_path, "rb") as f:
# Read first 8 bytes (header size)
header_size_bytes = f.read(8)
if len(header_size_bytes) < 8:
return False, f"File too small for safetensors header: {cache_path.name}"
import struct
header_size = struct.unpack("<Q", header_size_bytes)[0]
if header_size == 0:
return False, f"Invalid safetensors header (size=0): {cache_path.name}"
# Read and parse header JSON
header = f.read(header_size)
if len(header) < header_size:
return False, f"Incomplete safetensors header: {cache_path.name}"
import json
json.loads(header.decode("utf-8"))
except struct.error as e:
return False, f"Invalid safetensors format: {str(e)}"
except json.JSONDecodeError as e:
return False, f"Invalid safetensors header JSON: {str(e)}"
return True, f"Valid cached file ({size_mb:.1f} MB): {cache_path.name}"
except OSError as e:
return False, f"File access error: {str(e)}"
|