# Source: LandmarkDiff / landmarkdiff/model_registry.py
# Last commit (dreamlessx): "Update landmarkdiff/model_registry.py to v0.3.2"
# Revision: 30cc2b8 (verified)
"""Model registry for checkpoint discovery and management.

Provides a unified interface for finding, loading, and comparing model
checkpoints across local directories and remote sources.

Usage:

    from landmarkdiff.model_registry import ModelRegistry

    registry = ModelRegistry("checkpoints/")

    # Discover all checkpoints
    models = registry.list_models()

    # Get best checkpoint by metric
    best = registry.get_best("loss")

    # Load a specific checkpoint
    state = registry.load("checkpoint-5000")

    # Compare multiple checkpoints
    comparison = registry.compare(["checkpoint-1000", "checkpoint-5000"])
"""
from __future__ import annotations

import json
import math
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import torch


@dataclass
class ModelEntry:
    """Describes a single checkpoint directory tracked by the registry.

    Attributes:
        name: Key under which the checkpoint is registered.
        path: Directory that holds the checkpoint files.
        step: Training step associated with the checkpoint (0 if unknown).
        phase: Training phase label (empty string if unknown).
        metrics: Mapping of metric name to recorded value.
        size_mb: Checkpoint size in megabytes (0.0 if unknown).
        has_ema: Whether a ``controlnet_ema`` entry was seen in ``path``.
        has_training_state: Whether ``training_state.pt`` was seen in ``path``.
    """

    name: str
    path: Path
    step: int = 0
    phase: str = ""
    metrics: dict[str, float] = field(default_factory=dict)
    size_mb: float = 0.0
    has_ema: bool = False
    has_training_state: bool = False

    @property
    def inference_path(self) -> Path | None:
        """Return where inference weights live, or None if nothing is on disk.

        The filesystem is checked at access time: the ``controlnet_ema``
        entry wins when it exists, otherwise ``training_state.pt`` is used.
        """
        # Ordered by preference: EMA weights first, raw training state second.
        candidates = (
            self.path / "controlnet_ema",
            self.path / "training_state.pt",
        )
        return next((p for p in candidates if p.exists()), None)
class ModelRegistry:
    """Central registry for discovering and managing model checkpoints.

    Each scanned directory is searched for ``checkpoint-*`` subdirectories
    and for the special ``final`` / ``best`` / ``latest`` subdirectories.
    Entries are indexed by name; a later entry with the same name replaces
    an earlier one.

    Args:
        checkpoint_dirs: One or more directories to scan for checkpoints.
        scan_on_init: Whether to scan directories immediately on creation.
    """

    # Non-numbered checkpoint directories that are registered as
    # "<base directory name>/<special name>".
    _SPECIAL_NAMES = ("final", "best", "latest")

    def __init__(
        self,
        *checkpoint_dirs: str | Path,
        scan_on_init: bool = True,
    ) -> None:
        self.checkpoint_dirs = [Path(d) for d in checkpoint_dirs]
        self._models: dict[str, ModelEntry] = {}
        if scan_on_init:
            self.scan()

    def scan(self) -> int:
        """Scan checkpoint directories and register all found models.

        Any previously registered entries are discarded first. Paths that
        are not existing directories are skipped.

        Returns:
            Number of models found.
        """
        self._models.clear()
        for base_dir in self.checkpoint_dirs:
            if not base_dir.is_dir():
                continue
            self._scan_directory(base_dir)
        return len(self._models)

    def _scan_directory(self, base_dir: Path) -> None:
        """Scan a single directory for checkpoint subdirectories."""
        # Numbered checkpoints are registered under their directory name.
        for ckpt_dir in sorted(base_dir.glob("checkpoint-*")):
            if not ckpt_dir.is_dir():
                continue
            entry = self._load_entry(ckpt_dir)
            if entry is not None:
                self._models[entry.name] = entry

        # "final", "best" and "latest" directories/symlinks are prefixed
        # with the base directory name to form their registry key.
        for special in self._SPECIAL_NAMES:
            special_dir = base_dir / special
            if not special_dir.is_dir():
                continue
            entry = self._load_entry(special_dir)
            if entry is not None:
                entry.name = f"{base_dir.name}/{special}"
                self._models[entry.name] = entry

    def _load_entry(self, ckpt_dir: Path) -> ModelEntry | None:
        """Load metadata for a single checkpoint directory.

        A directory only counts as a checkpoint if it contains
        ``training_state.pt`` and/or ``controlnet_ema``. Values come from
        ``metadata.json`` when it is present and readable; a missing
        ``step`` or ``size_mb`` is derived from the directory itself.

        Returns:
            The populated entry, or None if the directory is not a checkpoint.
        """
        has_training = (ckpt_dir / "training_state.pt").exists()
        has_ema = (ckpt_dir / "controlnet_ema").exists()
        if not has_training and not has_ema:
            return None

        meta = self._read_metadata(ckpt_dir / "metadata.json")

        step = meta.get("step")
        if step is None:
            step = self._step_from_name(ckpt_dir.name)

        size_mb = meta.get("size_mb")
        if size_mb is None:
            size_mb = self._directory_size_mb(ckpt_dir)

        return ModelEntry(
            name=ckpt_dir.name,
            path=ckpt_dir,
            step=step,
            phase=meta.get("phase") or "",
            # "metrics": null in the JSON must not leak None into the entry.
            metrics=dict(meta.get("metrics") or {}),
            size_mb=size_mb,
            has_ema=has_ema,
            has_training_state=has_training,
        )

    @staticmethod
    def _read_metadata(meta_path: Path) -> dict[str, Any]:
        """Return the JSON object stored at *meta_path*, or {} if unusable."""
        if not meta_path.exists():
            return {}
        try:
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            # One bad file should not abort the scan of every checkpoint.
            warnings.warn(
                f"Ignoring unreadable checkpoint metadata {meta_path}: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
            return {}
        if not isinstance(meta, dict):
            warnings.warn(
                f"Ignoring checkpoint metadata {meta_path}: "
                "top-level JSON value is not an object",
                RuntimeWarning,
                stacklevel=2,
            )
            return {}
        return meta

    @staticmethod
    def _step_from_name(dir_name: str) -> int:
        """Extract the trailing ``-<digits>`` step from a directory name."""
        _, sep, tail = dir_name.rpartition("-")
        # isdecimal() only accepts characters that int() can parse.
        if sep and tail.isdecimal():
            return int(tail)
        return 0

    @staticmethod
    def _directory_size_mb(ckpt_dir: Path) -> float:
        """Total size of all files under *ckpt_dir* in MB, to one decimal."""
        total_bytes = sum(
            f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
        )
        return round(total_bytes / (1024 * 1024), 1)

    @staticmethod
    def _usable_metric(model: ModelEntry, metric: str) -> Any:
        """Return the model's value for *metric*, or None if missing or NaN."""
        value = model.metrics.get(metric)
        if isinstance(value, float) and math.isnan(value):
            return None
        return value

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def list_models(self, sort_by: str = "step") -> list[ModelEntry]:
        """List all registered models.

        Args:
            sort_by: Sort key — "step", "name", or a metric name. When
                sorting by metric, models whose value is missing or NaN
                are placed last.

        Returns:
            Sorted list of ModelEntry objects (ascending).
        """
        models = list(self._models.values())
        if sort_by == "step":
            models.sort(key=lambda m: m.step)
        elif sort_by == "name":
            models.sort(key=lambda m: m.name)
        else:

            def metric_key(m: ModelEntry) -> Any:
                value = self._usable_metric(m, sort_by)
                return float("inf") if value is None else value

            models.sort(key=metric_key)
        return models

    def get(self, name: str) -> ModelEntry | None:
        """Get a model entry by name, or None if it is not registered."""
        return self._models.get(name)

    def get_best(
        self,
        metric: str = "loss",
        lower_is_better: bool = True,
    ) -> ModelEntry | None:
        """Get the best model by a specific metric.

        Models that lack the metric, or whose value is NaN, are ignored;
        a NaN would otherwise make the min/max result depend on iteration
        order.

        Args:
            metric: Metric name to rank by.
            lower_is_better: If True, lower values are better.

        Returns:
            Best ModelEntry, or None if no models have a usable value.
        """
        candidates = [
            m for m in self._models.values()
            if self._usable_metric(m, metric) is not None
        ]
        if not candidates:
            return None
        pick = min if lower_is_better else max
        return pick(candidates, key=lambda m: m.metrics[metric])

    def get_by_step(self, step: int) -> ModelEntry | None:
        """Get the first registered model whose training step equals *step*."""
        for model in self._models.values():
            if model.step == step:
                return model
        return None

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load(
        self,
        name: str,
        map_location: str = "cpu",
    ) -> dict[str, Any]:
        """Load training state from a checkpoint.

        Args:
            name: Checkpoint name (e.g. "checkpoint-5000").
            map_location: Device to load tensors to.

        Returns:
            The deserialized contents of ``training_state.pt``.
            ``load_controlnet`` reads its "controlnet" / "ema_controlnet"
            entries.

        Raises:
            KeyError: If checkpoint not found.
            FileNotFoundError: If training_state.pt missing.
        """
        entry = self._models.get(name)
        if entry is None:
            raise KeyError(f"Checkpoint '{name}' not found in registry")
        state_path = entry.path / "training_state.pt"
        if not state_path.exists():
            raise FileNotFoundError(f"No training_state.pt in {entry.path}")
        return torch.load(state_path, map_location=map_location, weights_only=True)

    def load_controlnet(
        self,
        name: str,
        use_ema: bool = True,
        torch_dtype: torch.dtype | None = None,
    ) -> Any:
        """Load a ControlNet model from checkpoint.

        If ``use_ema`` is set and the checkpoint has a ``controlnet_ema``
        directory, the model is loaded directly from it. Otherwise the
        base ControlNet is instantiated and the weights stored in
        ``training_state.pt`` are loaded into it.

        Args:
            name: Checkpoint name.
            use_ema: If True, load EMA weights (preferred for inference).
            torch_dtype: Weight dtype (e.g. torch.float16). Defaults to
                float16 when CUDA is available, float32 otherwise.

        Returns:
            ControlNetModel instance.

        Raises:
            KeyError: If the checkpoint is not registered, or the training
                state does not contain the requested weights.
            FileNotFoundError: If the fallback path is taken and
                training_state.pt is missing.
        """
        entry = self._models.get(name)
        if entry is None:
            raise KeyError(f"Checkpoint '{name}' not found in registry")

        from diffusers import ControlNetModel

        if torch_dtype is None:
            torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        if use_ema and entry.has_ema:
            return ControlNetModel.from_pretrained(
                str(entry.path / "controlnet_ema"),
                torch_dtype=torch_dtype,
            )

        # Fallback: load from training state. Check the key before
        # instantiating the base model so a missing entry fails fast.
        state = self.load(name)
        key = "ema_controlnet" if use_ema else "controlnet"
        if key not in state:
            raise KeyError(
                f"Checkpoint '{name}' has no '{key}' weights in training_state.pt"
            )
        model = ControlNetModel.from_pretrained(
            "CrucibleAI/ControlNetMediaPipeFace",
            subfolder="diffusion_sd15",
            torch_dtype=torch_dtype,
        )
        model.load_state_dict(state[key])
        return model

    # ------------------------------------------------------------------
    # Comparison
    # ------------------------------------------------------------------

    def compare(
        self,
        names: list[str],
        metrics: list[str] | None = None,
    ) -> dict[str, Any]:
        """Compare multiple checkpoints side-by-side.

        Names that are not registered are skipped.

        Args:
            names: List of checkpoint names to compare.
            metrics: Specific metrics to include. None = all available.

        Returns:
            Dict with "metrics" (column names), "rows" (one dict per found
            checkpoint; absent metrics are None) and "count". If none of
            the names are registered, ``{"error": "No valid checkpoints
            found"}`` is returned instead.
        """
        entries = [
            self._models[name] for name in names if name in self._models
        ]
        if not entries:
            return {"error": "No valid checkpoints found"}

        if metrics is None:
            # Union of every metric recorded on the selected checkpoints.
            metric_names = sorted({key for e in entries for key in e.metrics})
        else:
            # Copy so the result does not alias the caller's list.
            metric_names = list(metrics)

        rows = []
        for e in entries:
            row: dict[str, Any] = {
                "name": e.name,
                "step": e.step,
                "phase": e.phase,
                "size_mb": e.size_mb,
            }
            for m in metric_names:
                row[m] = e.metrics.get(m)
            rows.append(row)

        return {
            "metrics": metric_names,
            "rows": rows,
            "count": len(rows),
        }

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------

    def summary(self) -> str:
        """Return a human-readable summary.

        The summary lists the checkpoint count and total size, the range
        of steps, and the min–max range of every recorded metric (NaN
        values are left out of the ranges).
        """
        models = self.list_models()
        if not models:
            return "No models registered."

        total_size = sum(m.size_mb for m in models)
        # list_models() sorts by step, so the ends are the step range.
        lines = [
            f"Model Registry: {len(models)} checkpoints ({total_size:.0f} MB)",
            f" Steps: {models[0].step}–{models[-1].step}",
        ]

        # Show metrics ranges
        metric_names = sorted({key for m in models for key in m.metrics})
        for metric in metric_names:
            values = [
                v
                for v in (self._usable_metric(m, metric) for m in models)
                if v is not None
            ]
            if values:
                lines.append(
                    f" {metric}: {min(values):.4f}–{max(values):.4f}"
                )
        return "\n".join(lines)

    def __len__(self) -> int:
        return len(self._models)

    def __contains__(self, name: str) -> bool:
        return name in self._models