# LandmarkDiff / landmarkdiff / experiment_tracker.py
# dreamlessx's picture
# Update landmarkdiff/experiment_tracker.py to v0.3.2
# e3a75a0 verified
"""Local experiment tracker for training reproducibility.

Tracks all training runs with their configs, metrics, and results.
Each experiment gets a unique ID and timestamp.

Usage::

    tracker = ExperimentTracker("experiments/")

    # Start a new experiment
    exp_id = tracker.start(
        name="phaseA_v2",
        config={
            "phase": "A", "lr": 1e-5, "batch": 4,
            "steps": 100000, "data": "training_combined",
        },
    )

    # Log metrics during training
    tracker.log_metric(exp_id, step=1000, loss=0.045, ssim=0.82)

    # Record final results
    tracker.finish(exp_id, results={"fid": 42.3, "ssim": 0.87})

    # List all experiments
    tracker.list_experiments()

    # Compare experiments
    tracker.compare(["exp_001", "exp_002"])
"""
from __future__ import annotations
import json
import os
import socket
import time
from datetime import datetime
from pathlib import Path
class ExperimentTracker:
    """Simple file-based experiment tracker.

    All state lives inside ``experiments_dir``:

    * ``index.json`` -- one record per experiment plus the running ID counter.
    * ``<exp_id>_metrics.jsonl`` -- one JSON object per ``log_metric`` call.

    The index is read once in ``__init__`` and rewritten in full by every
    ``start`` / ``finish`` call.

    NOTE(review): nothing here coordinates several tracker instances that
    point at the same directory; each rewrites the whole index from its own
    in-memory copy. Confirm single-writer usage with callers.
    """

    def __init__(self, experiments_dir: str = "experiments"):
        """Create the directory if needed and load any existing index.

        Args:
            experiments_dir: Directory holding the index and metrics files.
        """
        self.dir = Path(experiments_dir)
        self.dir.mkdir(parents=True, exist_ok=True)
        self._index_path = self.dir / "index.json"
        self._index = self._load_index()

    def _load_index(self) -> dict:
        """Read ``index.json``, or return a fresh empty index if it is absent.

        Missing top-level keys are filled in so later code can rely on both
        ``"experiments"`` and ``"counter"`` being present.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
        """
        if self._index_path.exists():
            with open(self._index_path, encoding="utf-8") as f:
                index = json.load(f)
        else:
            index = {}
        index.setdefault("experiments", {})
        index.setdefault("counter", 0)
        return index

    def _save_index(self) -> None:
        """Persist the in-memory index to ``index.json``.

        The data is written to a sibling temp file and then moved over the
        real index with ``os.replace``. A serialisation error or a crash in
        the middle of the write therefore leaves the previous index intact
        instead of a truncated file.
        """
        tmp_path = self._index_path.with_name(self._index_path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self._index, f, indent=2)
            os.replace(tmp_path, self._index_path)
        except Exception:
            # Do not leave a half-written temp file behind.
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    @staticmethod
    def _id_sort_key(exp_id: str) -> tuple:
        """Sort key ordering ``exp_<n>`` IDs by ``n`` rather than as text.

        IDs without a decimal suffix sort after the numbered ones,
        alphabetically.
        """
        _, _, suffix = exp_id.rpartition("_")
        if suffix.isdecimal():
            return (0, int(suffix), exp_id)
        return (1, 0, exp_id)

    @staticmethod
    def _format_score(value) -> str:
        """Format a result as ``.4f``; fall back to ``str`` if that fails."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    def start(
        self,
        name: str,
        config: dict,
        tags: list[str] | None = None,
    ) -> str:
        """Start a new experiment. Returns experiment ID.

        Args:
            name: Human-readable experiment name.
            config: Configuration to store with the run; must be
                JSON-serialisable because it is written to the index.
            tags: Optional labels stored with the run.

        Returns:
            The new ID, ``exp_<counter>`` zero-padded to at least 3 digits.
        """
        self._index["counter"] += 1
        exp_id = f"exp_{self._index['counter']:03d}"
        exp = {
            "id": exp_id,
            "name": name,
            "config": config,
            # Copy so later mutation of the caller's list does not leak in.
            "tags": list(tags) if tags else [],
            "status": "running",
            "started_at": datetime.now().isoformat(),
            "finished_at": None,
            "hostname": socket.gethostname(),
            "slurm_job_id": os.environ.get("SLURM_JOB_ID"),
            "gpu": os.environ.get("CUDA_VISIBLE_DEVICES"),
            "results": {},
            "metrics_file": f"{exp_id}_metrics.jsonl",
        }
        self._index["experiments"][exp_id] = exp
        self._save_index()

        # Create metrics log file
        metrics_path = self.dir / str(exp["metrics_file"])
        metrics_path.touch()

        print(f"Experiment started: {exp_id} ({name})")
        return exp_id

    def log_metric(self, exp_id: str, step: int | None = None, **metrics) -> None:
        """Log metrics for a training step.

        Appends one JSON line holding the current ``time.time()``, ``step``
        and every keyword in ``metrics``. Unknown ``exp_id`` values are
        ignored.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            return
        entry = {
            "timestamp": time.time(),
            "step": step,
            **metrics,
        }
        metrics_path = self.dir / str(exp["metrics_file"])
        with open(metrics_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")

    def finish(
        self,
        exp_id: str,
        results: dict | None = None,
        status: str = "completed",
    ) -> None:
        """Mark experiment as finished.

        Args:
            exp_id: Experiment to update; unknown IDs are ignored.
            results: Final results; replaces the stored results when
                non-empty. Must be JSON-serialisable.
            status: Status string to record.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            return
        exp["status"] = status
        exp["finished_at"] = datetime.now().isoformat()
        if results:
            exp["results"] = results
        self._save_index()
        print(f"Experiment {exp_id} {status}")

    def get_metrics(self, exp_id: str) -> list[dict]:
        """Load all logged metrics for an experiment.

        Returns:
            One dict per non-blank line of the metrics file, in file order;
            an empty list if the experiment or its file does not exist.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            return []
        metrics_path = self.dir / str(exp["metrics_file"])
        if not metrics_path.exists():
            return []
        with open(metrics_path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [json.loads(line) for line in stripped if line]

    def list_experiments(self) -> list[dict]:
        """List all experiments with summary info.

        Returns:
            Summaries ordered by the numeric part of the experiment ID. Each
            has ``id``, ``name``, ``status``, ``started`` (the first 19
            characters of the start timestamp) and ``tags``, plus whichever
            of ``fid`` / ``ssim`` / ``lpips`` / ``nme`` are in the results.
        """
        experiments = []
        ordered_ids = sorted(self._index["experiments"], key=self._id_sort_key)
        for exp_id in ordered_ids:
            exp = self._index["experiments"][exp_id]
            summary = {
                "id": exp_id,
                "name": exp["name"],
                "status": exp["status"],
                "started": exp["started_at"][:19],
                "tags": exp.get("tags", []),
            }
            results = exp.get("results") or {}
            for key in ("fid", "ssim", "lpips", "nme"):
                if key in results:
                    summary[key] = results[key]
            experiments.append(summary)
        return experiments

    def compare(self, exp_ids: list[str]) -> dict:
        """Compare multiple experiments by their results.

        Returns:
            Mapping of each known ID to its ``name``, ``config`` and
            ``results``; unknown IDs are omitted.
        """
        comparison = {}
        for exp_id in exp_ids:
            exp = self._index["experiments"].get(exp_id)
            if exp:
                comparison[exp_id] = {
                    "name": exp["name"],
                    "config": exp["config"],
                    "results": exp["results"],
                }
        return comparison

    def print_summary(self) -> None:
        """Print a summary table of all experiments."""
        experiments = self.list_experiments()
        if not experiments:
            print("No experiments found.")
            return

        # Header
        print(f"{'ID':<10} {'Name':<20} {'Status':<12} {'FID':>6} {'SSIM':>6} {'LPIPS':>6}")
        print("-" * 70)
        for exp in experiments:
            fid = str(exp["fid"]) if "fid" in exp else "--"
            # Stored values may be None or strings; never let one bad row
            # abort the whole table.
            ssim = self._format_score(exp["ssim"]) if "ssim" in exp else "--"
            lpips = self._format_score(exp["lpips"]) if "lpips" in exp else "--"
            print(f"{exp['id']:<10} {exp['name']:<20} {exp['status']:<12} {fid:>6} {ssim:>6} {lpips:>6}")

    def get_best(self, metric: str = "fid", lower_is_better: bool = True) -> str | None:
        """Get the experiment ID with the best value for a given metric.

        Only experiments whose status is ``"completed"`` are considered.
        Missing, non-numeric and NaN values are skipped. On a tie the
        experiment encountered first wins.

        Returns:
            The best experiment's ID, or ``None`` if none has a usable value.
        """
        best_id = None
        best_val = None
        for exp_id, exp in self._index["experiments"].items():
            if exp["status"] != "completed":
                continue
            val = exp["results"].get(metric)
            if not isinstance(val, (int, float)) or val != val:  # val != val: NaN
                continue
            if best_id is None:
                improved = True  # first usable value seeds the search
            elif lower_is_better:
                improved = val < best_val
            else:
                improved = val > best_val
            if improved:
                best_val = val
                best_id = exp_id
        return best_id