Spaces:
Running
Running
| """Checkpoint management with metadata tracking, best-model selection, and pruning. | |
| Provides a central manager for training checkpoints that: | |
| - Tracks per-checkpoint metadata (step, metrics, timestamps) | |
| - Maintains symlinks to best/latest checkpoints | |
| - Prunes old checkpoints to save disk space | |
| - Supports multiple ranking metrics (loss, FID, SSIM, etc.) | |
| Usage: | |
| manager = CheckpointManager( | |
| output_dir="checkpoints/phaseA", | |
| keep_best=3, | |
| keep_latest=5, | |
| metric="loss", | |
| lower_is_better=True, | |
| ) | |
| # During training loop: | |
| manager.save( | |
| step=1000, | |
| controlnet=controlnet, | |
| ema_controlnet=ema_controlnet, | |
| optimizer=optimizer, | |
| scheduler=scheduler, | |
| metrics={"loss": 0.0123, "val_ssim": 0.87}, | |
| ) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import shutil | |
| import time | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
@dataclass
class CheckpointMetadata:
    """Metadata for a single checkpoint.

    Mirrors what is written to each checkpoint's metadata.json.
    Note: ``@dataclass`` is required here — the class relies on
    ``field(default_factory=...)``, ``asdict``, and ``__dataclass_fields__``,
    none of which work on a plain class.
    """

    # Training step at which the checkpoint was saved.
    step: int
    # Unix timestamp (time.time()) of the save.
    timestamp: float
    # Metric values recorded at save time, e.g. {"loss": 0.0123}.
    metrics: dict[str, float] = field(default_factory=dict)
    # Optional epoch number, if the trainer tracks epochs.
    epoch: int | None = None
    # Training phase label (e.g. "A", "B").
    phase: str = ""
    # Maintained by the manager's best-checkpoint bookkeeping.
    is_best: bool = False
    # Total on-disk size of the checkpoint directory, in megabytes.
    size_mb: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict of all fields."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> CheckpointMetadata:
        """Build metadata from a dict, silently ignoring unknown keys.

        Unknown keys are dropped so that older code can read index files
        written by newer versions with extra fields.
        """
        return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
class CheckpointManager:
    """Manages training checkpoints with pruning and best-model tracking.

    Args:
        output_dir: Base directory for checkpoints.
        keep_best: Number of best checkpoints to retain.
        keep_latest: Number of most recent checkpoints to retain.
        metric: Metric name used to determine "best" checkpoint.
        lower_is_better: If True, lower metric values are better (e.g. loss, FID).
        prefix: Checkpoint directory prefix (default: "checkpoint").
    """

    INDEX_FILE = "checkpoint_index.json"

    def __init__(
        self,
        output_dir: str | Path,
        keep_best: int = 3,
        keep_latest: int = 5,
        metric: str = "loss",
        lower_is_better: bool = True,
        prefix: str = "checkpoint",
    ) -> None:
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.keep_best = keep_best
        self.keep_latest = keep_latest
        self.metric = metric
        self.lower_is_better = lower_is_better
        self.prefix = prefix
        # Index schema: {"checkpoints": {name: metadata_dict}}
        self._index: dict[str, Any] = {"checkpoints": {}}
        self._load_index()

    # ------------------------------------------------------------------
    # Index persistence
    # ------------------------------------------------------------------
    def _index_path(self) -> Path:
        return self.output_dir / self.INDEX_FILE

    def _scan_checkpoint_metadata(self) -> dict[str, Any]:
        """Rebuild index entries from per-checkpoint metadata.json files.

        Each checkpoint directory carries its own copy of the metadata
        (written by save()), so the central index can be reconstructed
        when it is unreadable.
        """
        entries: dict[str, Any] = {}
        for meta_file in sorted(self.output_dir.glob(f"{self.prefix}-*/metadata.json")):
            try:
                with open(meta_file) as f:
                    entries[meta_file.parent.name] = json.load(f)
            except (OSError, json.JSONDecodeError):
                continue  # skip unreadable entries rather than abort recovery
        return entries

    def _load_index(self) -> None:
        """Load the index from disk, repairing corruption and stale entries."""
        path = self._index_path()
        if path.exists():
            try:
                with open(path) as f:
                    self._index = json.load(f)
            except (OSError, json.JSONDecodeError):
                # Index unreadable (e.g. crash mid-write by an older,
                # non-atomic writer): rebuild from per-checkpoint metadata.
                self._index = {"checkpoints": self._scan_checkpoint_metadata()}
            if "checkpoints" not in self._index:
                self._index["checkpoints"] = {}
            # Remove entries whose directories no longer exist on disk
            # (can happen after a crash during pruning)
            missing = [
                name
                for name in list(self._index["checkpoints"])
                if not (self.output_dir / name).exists()
            ]
            if missing:
                for name in missing:
                    self._index["checkpoints"].pop(name, None)
                self._update_best()
                self._save_index()
                self._update_symlinks()

    def _save_index(self) -> None:
        """Write the index atomically (temp file + rename).

        A plain open(path, "w") truncates before writing, so a crash
        mid-write would destroy the index; rename is atomic on POSIX.
        """
        tmp = self._index_path().with_name(self.INDEX_FILE + ".tmp")
        with open(tmp, "w") as f:
            json.dump(self._index, f, indent=2)
        tmp.replace(self._index_path())

    # ------------------------------------------------------------------
    # Save checkpoint
    # ------------------------------------------------------------------
    def save(
        self,
        step: int,
        controlnet: torch.nn.Module,
        ema_controlnet: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Any = None,
        metrics: dict[str, float] | None = None,
        epoch: int | None = None,
        phase: str = "",
        extra_state: dict[str, Any] | None = None,
    ) -> Path:
        """Save a checkpoint with metadata.

        Args:
            step: Current training step.
            controlnet: ControlNet model (or any nn.Module).
            ema_controlnet: EMA copy of the model.
            optimizer: Optimizer state.
            scheduler: Optional LR scheduler.
            metrics: Dict of metric values at this step.
            epoch: Optional epoch number.
            phase: Training phase label (e.g. "A", "B").
            extra_state: Any additional state to save.

        Returns:
            Path to the saved checkpoint directory.
        """
        ckpt_name = f"{self.prefix}-{step}"
        ckpt_dir = self.output_dir / ckpt_name
        ckpt_dir.mkdir(exist_ok=True)
        # Save EMA weights (used for inference); duck-typed so any
        # HuggingFace-style model with save_pretrained() works.
        if hasattr(ema_controlnet, "save_pretrained"):
            ema_controlnet.save_pretrained(ckpt_dir / "controlnet_ema")
        # Save training state for resume
        state = {
            "controlnet": _get_state_dict(controlnet),
            "ema_controlnet": _get_state_dict(ema_controlnet),
            "optimizer": optimizer.state_dict(),
            "global_step": step,
        }
        if scheduler is not None:
            state["scheduler"] = scheduler.state_dict()
        if extra_state:
            state.update(extra_state)
        torch.save(state, ckpt_dir / "training_state.pt")
        # Compute checkpoint size (all files under the checkpoint dir)
        size_mb = sum(
            f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
        ) / (1024 * 1024)
        # Create metadata
        meta = CheckpointMetadata(
            step=step,
            timestamp=time.time(),
            metrics=metrics or {},
            epoch=epoch,
            phase=phase,
            size_mb=round(size_mb, 1),
        )
        # Save metadata alongside checkpoint so the index can be rebuilt
        with open(ckpt_dir / "metadata.json", "w") as f:
            json.dump(meta.to_dict(), f, indent=2)
        # Update index
        self._index["checkpoints"][ckpt_name] = meta.to_dict()
        self._update_best()
        self._save_index()
        # Update symlinks
        self._update_symlinks()
        # Prune old checkpoints
        self._prune()
        return ckpt_dir

    # ------------------------------------------------------------------
    # Best / latest tracking
    # ------------------------------------------------------------------
    def _update_best(self) -> None:
        """Recompute which checkpoints are 'best' by the tracked metric.

        Checkpoints with no value for the tracked metric are never marked
        best but keep their is_best flag cleared.
        """
        entries = []
        for name, meta in self._index["checkpoints"].items():
            val = meta.get("metrics", {}).get(self.metric)
            if val is not None:
                entries.append((name, val, meta))
        if not entries:
            return
        # Sort by metric (ascending when lower is better)
        entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)
        # Mark best top-k; clear the flag on everything else
        best_names = {e[0] for e in entries[:self.keep_best]}
        for name, meta in self._index["checkpoints"].items():
            meta["is_best"] = name in best_names

    def _update_symlinks(self) -> None:
        """Update 'latest' and 'best' symlinks in the output directory."""
        checkpoints = self._sorted_by_step()
        if not checkpoints:
            return
        # Latest symlink -> highest step
        latest_name = checkpoints[-1]
        latest_link = self.output_dir / "latest"
        _force_symlink(self.output_dir / latest_name, latest_link)
        # Best symlink -> best tracked-metric value (if any checkpoint has one)
        best_name = self.get_best_checkpoint_name()
        if best_name:
            best_link = self.output_dir / "best"
            _force_symlink(self.output_dir / best_name, best_link)

    def get_best_checkpoint_name(self) -> str | None:
        """Return the name of the best checkpoint by tracked metric.

        Returns None when no checkpoint has a value for the metric.
        """
        best = None
        best_val = None
        for name, meta in self._index["checkpoints"].items():
            val = meta.get("metrics", {}).get(self.metric)
            if val is None:
                continue
            if best_val is None:
                best, best_val = name, val
            elif self.lower_is_better and val < best_val:
                best, best_val = name, val
            elif not self.lower_is_better and val > best_val:
                best, best_val = name, val
        return best

    def get_best_metric_value(self) -> float | None:
        """Return the best value of the tracked metric, or None."""
        name = self.get_best_checkpoint_name()
        if name is None:
            return None
        return self._index["checkpoints"][name]["metrics"].get(self.metric)

    # ------------------------------------------------------------------
    # Pruning
    # ------------------------------------------------------------------
    def _sorted_by_step(self) -> list[str]:
        """Return checkpoint names sorted by step (ascending)."""
        items = list(self._index["checkpoints"].items())
        items.sort(key=lambda x: x[1].get("step", 0))
        return [name for name, _ in items]

    def _prune(self) -> None:
        """Remove old checkpoints, keeping best N and latest M."""
        all_names = self._sorted_by_step()
        if len(all_names) <= self.keep_latest:
            return
        # Determine which to keep: union of latest-M and best-flagged
        keep = set()
        for name in all_names[-self.keep_latest:]:
            keep.add(name)
        for name, meta in self._index["checkpoints"].items():
            if meta.get("is_best", False):
                keep.add(name)
        # Delete the rest (dir first, then index entry, so a crash here is
        # recoverable by the stale-entry cleanup in _load_index)
        for name in all_names:
            if name not in keep:
                ckpt_dir = self.output_dir / name
                if ckpt_dir.exists():
                    shutil.rmtree(ckpt_dir)
                self._index["checkpoints"].pop(name, None)
        self._save_index()

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def list_checkpoints(self) -> list[dict[str, Any]]:
        """Return metadata for all tracked checkpoints, sorted by step."""
        result = []
        for name in self._sorted_by_step():
            meta = self._index["checkpoints"][name]
            result.append({"name": name, **meta})
        return result

    def get_checkpoint_path(self, name: str) -> Path:
        """Return the filesystem path for a checkpoint by name."""
        return self.output_dir / name

    def get_latest_step(self) -> int:
        """Return the step of the most recent checkpoint, or 0."""
        names = self._sorted_by_step()
        if not names:
            return 0
        return self._index["checkpoints"][names[-1]].get("step", 0)

    def total_size_mb(self) -> float:
        """Return total recorded disk size of all tracked checkpoints."""
        return sum(
            meta.get("size_mb", 0.0)
            for meta in self._index["checkpoints"].values()
        )

    def summary(self) -> str:
        """Return a human-readable summary of checkpoint state."""
        ckpts = self.list_checkpoints()
        if not ckpts:
            return "No checkpoints saved."
        lines = [
            f"Checkpoints: {len(ckpts)} saved ({self.total_size_mb():.0f} MB total)",
            f"Latest: step {self.get_latest_step()}",
        ]
        best_name = self.get_best_checkpoint_name()
        best_val = self.get_best_metric_value()
        if best_name and best_val is not None:
            lines.append(f"Best ({self.metric}): {best_val:.6f} @ {best_name}")
        return "\n".join(lines)
| # ------------------------------------------------------------------ | |
| # Helpers | |
| # ------------------------------------------------------------------ | |
| def _get_state_dict(module: torch.nn.Module) -> dict: | |
| """Extract state dict, handling DDP wrapper.""" | |
| if hasattr(module, "module"): | |
| return module.module.state_dict() | |
| return module.state_dict() | |
| def _force_symlink(target: Path, link: Path) -> None: | |
| """Create or replace a symlink.""" | |
| if link.is_symlink() or link.exists(): | |
| link.unlink() | |
| link.symlink_to(target.name) | |