Spaces:
Running
Running
| """Model registry for checkpoint discovery and management. | |
| Provides a unified interface for finding, loading, and comparing model | |
| checkpoints across local directories and remote sources. | |
| Usage: | |
| from landmarkdiff.model_registry import ModelRegistry | |
| registry = ModelRegistry("checkpoints/") | |
| # Discover all checkpoints | |
| models = registry.list_models() | |
| # Get best checkpoint by metric | |
| best = registry.get_best("loss") | |
| # Load a specific checkpoint | |
| state = registry.load("checkpoint-5000") | |
| # Compare multiple checkpoints | |
| comparison = registry.compare(["checkpoint-1000", "checkpoint-5000"]) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
@dataclass
class ModelEntry:
    """Metadata for a registered model checkpoint.

    Attributes:
        name: Registry key for the checkpoint (normally the directory name).
        path: Checkpoint directory on disk.
        step: Training step of the checkpoint; 0 when unknown.
        phase: Training phase label; empty when unknown.
        metrics: Metric name -> value mapping recorded for the checkpoint.
        size_mb: Checkpoint size in megabytes; 0.0 when unknown.
        has_ema: Whether a ``controlnet_ema`` entry was present.
        has_training_state: Whether ``training_state.pt`` was present.
    """

    name: str
    path: Path
    step: int = 0
    phase: str = ""
    # default_factory gives every entry its own dict instead of a shared one.
    metrics: dict[str, float] = field(default_factory=dict)
    size_mb: float = 0.0
    has_ema: bool = False
    has_training_state: bool = False

    def inference_path(self) -> Path | None:
        """Return the path to inference-ready weights (EMA preferred).

        Returns:
            The ``controlnet_ema`` directory if it exists, otherwise
            ``training_state.pt`` if it exists, otherwise None.
        """
        # Checked in order of preference: EMA weights first, then the raw
        # training state as a fallback.
        for candidate in ("controlnet_ema", "training_state.pt"):
            candidate_path = self.path / candidate
            if candidate_path.exists():
                return candidate_path
        return None
class ModelRegistry:
    """Central registry for discovering and managing model checkpoints.

    A directory counts as a checkpoint when it contains ``training_state.pt``
    and/or ``controlnet_ema``. Entries are keyed by name; if two scanned
    directories produce the same name, the one scanned last replaces the
    earlier one.

    Args:
        checkpoint_dirs: One or more directories to scan for checkpoints.
        scan_on_init: Whether to scan directories immediately on creation.
    """

    # Non-numbered checkpoint directories (or symlinks) that are also scanned.
    _SPECIAL_NAMES = ("final", "best", "latest")

    def __init__(
        self,
        *checkpoint_dirs: str | Path,
        scan_on_init: bool = True,
    ) -> None:
        self.checkpoint_dirs = [Path(d) for d in checkpoint_dirs]
        self._models: dict[str, ModelEntry] = {}
        if scan_on_init:
            self.scan()

    def scan(self) -> int:
        """Scan checkpoint directories and register all found models.

        Any previously registered entries are discarded first. Directories
        that do not exist are skipped.

        Returns:
            Number of models found.
        """
        self._models.clear()
        for base_dir in self.checkpoint_dirs:
            if base_dir.exists():
                self._scan_directory(base_dir)
        return len(self._models)

    def _scan_directory(self, base_dir: Path) -> None:
        """Scan a single directory for checkpoint subdirectories."""
        # Numbered checkpoints are registered under their directory name.
        for ckpt_dir in sorted(base_dir.glob("checkpoint-*")):
            if not ckpt_dir.is_dir():
                continue
            entry = self._load_entry(ckpt_dir)
            if entry is not None:
                self._models[entry.name] = entry

        # "final", "best" and "latest" directories/symlinks are registered as
        # "<base_dir name>/<special>" so they stay distinguishable per base
        # directory. is_dir() is already False for a missing path.
        for special in self._SPECIAL_NAMES:
            special_dir = base_dir / special
            if not special_dir.is_dir():
                continue
            entry = self._load_entry(special_dir)
            if entry is not None:
                entry.name = f"{base_dir.name}/{special}"
                self._models[entry.name] = entry

    @staticmethod
    def _step_from_name(dir_name: str) -> int:
        """Return the numeric suffix after the last ``-`` in *dir_name*, else 0."""
        _, dash, suffix = dir_name.rpartition("-")
        # isdecimal() only accepts characters that int() can parse.
        return int(suffix) if dash and suffix.isdecimal() else 0

    @staticmethod
    def _read_metadata(meta_path: Path) -> dict[str, Any] | None:
        """Return the JSON object stored at *meta_path*, or None if unusable.

        A missing file is the normal "no metadata" case and is silent. A file
        that cannot be read, is not valid JSON, or does not hold a JSON object
        triggers a warning so one bad checkpoint does not abort a whole scan.
        """
        import warnings

        try:
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)
        except FileNotFoundError:
            return None
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            warnings.warn(
                f"Ignoring unreadable metadata file {meta_path}: {exc}",
                stacklevel=2,
            )
            return None
        if not isinstance(meta, dict):
            warnings.warn(
                f"Ignoring metadata file {meta_path}: expected a JSON object",
                stacklevel=2,
            )
            return None
        return meta

    def _load_entry(self, ckpt_dir: Path) -> ModelEntry | None:
        """Load metadata for a single checkpoint directory.

        Args:
            ckpt_dir: Candidate checkpoint directory.

        Returns:
            A ModelEntry, or None when the directory holds neither
            ``training_state.pt`` nor ``controlnet_ema``.
        """
        has_training = (ckpt_dir / "training_state.pt").exists()
        has_ema = (ckpt_dir / "controlnet_ema").exists()
        if not (has_training or has_ema):
            return None

        step_from_name = self._step_from_name(ckpt_dir.name)

        # Prefer metadata.json (from CheckpointManager) when it is usable.
        meta = self._read_metadata(ckpt_dir / "metadata.json")
        if meta is not None:
            step = meta.get("step")
            return ModelEntry(
                name=ckpt_dir.name,
                path=ckpt_dir,
                # A missing/null step falls back to the directory name.
                step=step_from_name if step is None else step,
                phase=meta.get("phase", ""),
                # A null "metrics" value would break membership tests later.
                metrics=meta.get("metrics") or {},
                size_mb=meta.get("size_mb", 0.0),
                has_ema=has_ema,
                has_training_state=has_training,
            )

        # Fallback: derive step from the directory name and measure the size.
        size_bytes = sum(
            f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
        )
        return ModelEntry(
            name=ckpt_dir.name,
            path=ckpt_dir,
            step=step_from_name,
            size_mb=round(size_bytes / (1024 * 1024), 1),
            has_ema=has_ema,
            has_training_state=has_training,
        )

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def list_models(self, sort_by: str = "step") -> list[ModelEntry]:
        """List all registered models.

        Args:
            sort_by: Sort key — "step", "name", or a metric name. When sorting
                by a metric, entries lacking it are placed last.

        Returns:
            List of ModelEntry objects sorted in ascending order.
        """
        models = self._models.values()
        if sort_by == "step":
            return sorted(models, key=lambda m: m.step)
        if sort_by == "name":
            return sorted(models, key=lambda m: m.name)
        # Any other key is treated as a metric name.
        return sorted(models, key=lambda m: m.metrics.get(sort_by, float("inf")))

    def get(self, name: str) -> ModelEntry | None:
        """Get a model entry by name, or None if it is not registered."""
        return self._models.get(name)

    def get_best(
        self,
        metric: str = "loss",
        lower_is_better: bool = True,
    ) -> ModelEntry | None:
        """Get the best model by a specific metric.

        Args:
            metric: Metric name to rank by.
            lower_is_better: If True, lower values are better.

        Returns:
            Best ModelEntry, or None if no models have the metric.
        """
        candidates = [m for m in self._models.values() if metric in m.metrics]
        if not candidates:
            return None
        pick = min if lower_is_better else max
        return pick(candidates, key=lambda m: m.metrics[metric])

    def get_by_step(self, step: int) -> ModelEntry | None:
        """Get the first registered model with the given training step."""
        return next((m for m in self._models.values() if m.step == step), None)

    def _require_entry(self, name: str) -> ModelEntry:
        """Return the entry registered under *name* or raise KeyError."""
        entry = self._models.get(name)
        if entry is None:
            raise KeyError(f"Checkpoint '{name}' not found in registry")
        return entry

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load(
        self,
        name: str,
        map_location: str = "cpu",
    ) -> dict[str, Any]:
        """Load training state from a checkpoint.

        Args:
            name: Checkpoint name (e.g. "checkpoint-5000").
            map_location: Device to load tensors to.

        Returns:
            State dict containing controlnet, ema_controlnet, optimizer, etc.

        Raises:
            KeyError: If checkpoint not found.
            FileNotFoundError: If training_state.pt missing.
        """
        entry = self._require_entry(name)
        state_path = entry.path / "training_state.pt"
        if not state_path.exists():
            raise FileNotFoundError(f"No training_state.pt in {entry.path}")
        # weights_only=True restricts unpickling to tensors and plain containers.
        return torch.load(state_path, map_location=map_location, weights_only=True)

    def load_controlnet(
        self,
        name: str,
        use_ema: bool = True,
        torch_dtype: torch.dtype | None = None,
    ) -> Any:
        """Load a ControlNet model from checkpoint.

        Args:
            name: Checkpoint name.
            use_ema: If True, load EMA weights (preferred for inference).
            torch_dtype: Weight dtype (e.g. torch.float16). Defaults to
                float16 when CUDA is available, float32 otherwise.

        Returns:
            ControlNetModel instance.

        Raises:
            KeyError: If the checkpoint is not registered, or the training
                state does not contain the requested weights.
            FileNotFoundError: If the training state is needed but missing.
        """
        # Validate the name before the heavy diffusers import.
        entry = self._require_entry(name)

        from diffusers import ControlNetModel

        if torch_dtype is None:
            torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        if use_ema and entry.has_ema:
            return ControlNetModel.from_pretrained(
                str(entry.path / "controlnet_ema"),
                torch_dtype=torch_dtype,
            )

        # Fallback: load weights out of the training state.
        state = self.load(name)
        key = "ema_controlnet" if use_ema else "controlnet"
        # Check before fetching the base model so a missing key fails fast.
        if key not in state:
            raise KeyError(
                f"Checkpoint '{name}' has no '{key}' weights in training_state.pt"
            )
        model = ControlNetModel.from_pretrained(
            "CrucibleAI/ControlNetMediaPipeFace",
            subfolder="diffusion_sd15",
            torch_dtype=torch_dtype,
        )
        model.load_state_dict(state[key])
        return model

    # ------------------------------------------------------------------
    # Comparison
    # ------------------------------------------------------------------
    def compare(
        self,
        names: list[str],
        metrics: list[str] | None = None,
    ) -> dict[str, Any]:
        """Compare multiple checkpoints side-by-side.

        Names that are not registered are skipped.

        Args:
            names: List of checkpoint names to compare.
            metrics: Specific metrics to include. None = all available.

        Returns:
            Dict with "metrics", "rows" and "count" keys, or a dict with a
            single "error" key when none of the names is registered. A metric
            a checkpoint lacks appears as None in its row.
        """
        entries = [self._models[n] for n in names if n in self._models]
        if not entries:
            return {"error": "No valid checkpoints found"}

        if metrics is None:
            # Union of every metric recorded on the selected entries.
            metrics = sorted({key for e in entries for key in e.metrics})

        rows = []
        for e in entries:
            row: dict[str, Any] = {
                "name": e.name,
                "step": e.step,
                "phase": e.phase,
                "size_mb": e.size_mb,
            }
            for m in metrics:
                row[m] = e.metrics.get(m)
            rows.append(row)

        return {
            "metrics": metrics,
            "rows": rows,
            "count": len(rows),
        }

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    def summary(self) -> str:
        """Return a human-readable summary.

        The summary lists the checkpoint count, total size, step range and
        the min/max of every recorded metric.
        """
        models = self.list_models()
        if not models:
            return "No models registered."

        total_size = sum(m.size_mb for m in models)
        # list_models() sorts by step, so the ends give the step range.
        lines = [
            f"Model Registry: {len(models)} checkpoints ({total_size:.0f} MB)",
            f" Steps: {models[0].step} — {models[-1].step}",
        ]

        metric_names = sorted({key for m in models for key in m.metrics})
        for metric in metric_names:
            values = [m.metrics[metric] for m in models if metric in m.metrics]
            if values:
                lines.append(
                    f" {metric}: {min(values):.4f} — {max(values):.4f}"
                )
        return "\n".join(lines)

    def __len__(self) -> int:
        return len(self._models)

    def __contains__(self, name: str) -> bool:
        return name in self._models