# landmarkdiff/config.py — v0.3.2 (commit 5b7e166)
"""YAML-based experiment configuration for reproducible training and evaluation.
Provides typed dataclasses that can be loaded from YAML files, enabling
reproducible experiments with version-tracked configs.
Usage:
from landmarkdiff.config import ExperimentConfig
config = ExperimentConfig.from_yaml("configs/rhinoplasty_phaseA.yaml")
print(config.training.learning_rate)
# Or create programmatically
config = ExperimentConfig(
experiment_name="rhino_v1",
training=TrainingConfig(phase="A", learning_rate=1e-5),
)
config.to_yaml("configs/rhino_v1.yaml")
"""
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
import yaml
@dataclass
class ModelConfig:
    """ControlNet and base model configuration."""

    # Hugging Face model id (or local path) of the base diffusion model.
    base_model: str = "runwayml/stable-diffusion-v1-5"
    # Channel count of the ControlNet conditioning input image.
    controlnet_conditioning_channels: int = 3
    # Weight applied to the ControlNet conditioning residuals.
    controlnet_conditioning_scale: float = 1.0
    # Maintain an exponential moving average of the weights.
    use_ema: bool = True
    ema_decay: float = 0.9999
    # Trade compute for memory during backprop.
    gradient_checkpointing: bool = True
@dataclass
class TrainingConfig:
    """Training hyperparameters."""

    phase: str = "A"  # "A" or "B"
    learning_rate: float = 1e-5
    batch_size: int = 4
    # Effective batch = batch_size * gradient_accumulation_steps;
    # validate_config() warns when the product is < 8.
    gradient_accumulation_steps: int = 4
    max_train_steps: int = 50000
    warmup_steps: int = 500
    mixed_precision: str = "bf16"
    seed: int = 42
    # NOTE(review): duplicates ModelConfig.ema_decay — confirm which one the
    # training loop actually reads.
    ema_decay: float = 0.9999
    # Optimizer
    optimizer: str = "adamw"  # "adamw", "adam8bit", "prodigy"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    weight_decay: float = 1e-2
    max_grad_norm: float = 1.0
    # LR scheduler
    lr_scheduler: str = "cosine"
    # Extra kwargs forwarded to the scheduler constructor.
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
    # Logging intervals
    log_every: int = 100
    sample_every: int = 1000
    # Phase B specific
    identity_loss_weight: float = 0.1
    perceptual_loss_weight: float = 0.05
    use_differentiable_arcface: bool = False
    arcface_weights_path: str | None = None
    # Loss weights (alternative to individual weights)
    loss_weights: dict[str, float] = field(default_factory=dict)
    # Checkpointing
    save_every_n_steps: int = 5000
    resume_from_checkpoint: str | None = None
    # Presumably a Phase A checkpoint path used when starting Phase B —
    # TODO confirm against the trainer.
    resume_phase_a: str | None = None
    # Validation
    validate_every_n_steps: int = 2500
    num_validation_samples: int = 4
@dataclass
class DataConfig:
    """Dataset configuration."""

    train_dir: str = "data/training_combined"
    val_dir: str = "data/splits/val"
    test_dir: str = "data/splits/test"
    # validate_config() warns when this differs from 512.
    image_size: int = 512
    num_workers: int = 4
    pin_memory: bool = True
    # Augmentation
    random_flip: bool = True
    random_rotation: float = 5.0  # degrees
    color_jitter: float = 0.1
    clinical_augment: bool = False
    geometric_augment: bool = True
    # Procedure filtering: which procedure labels to include.
    procedures: list[str] = field(
        default_factory=lambda: [
            "rhinoplasty",
            "blepharoplasty",
            "rhytidectomy",
            "orthognathic",
            "brow_lift",
            "mentoplasty",
        ]
    )
    # Presumably the min/max procedure intensity to include — TODO confirm
    # units. Serialized as a YAML list; _from_dict converts it back to tuple.
    intensity_range: tuple[float, float] = (30.0, 100.0)
    # Data-driven displacement
    displacement_model_path: str | None = None
    noise_scale: float = 0.1
@dataclass
class InferenceConfig:
    """Inference / generation configuration."""

    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    scheduler: str = "dpmsolver++"  # "ddpm", "ddim", "dpmsolver++"
    controlnet_conditioning_scale: float = 1.0
    # Post-processing
    use_neural_postprocess: bool = False
    restore_mode: str = "codeformer"
    codeformer_fidelity: float = 0.7
    use_realesrgan: bool = True
    use_laplacian_blend: bool = True
    sharpen_strength: float = 0.25
    # Identity verification
    verify_identity: bool = True
    # NOTE(review): SafetyConfig also defines identity_threshold — confirm
    # which value the inference path uses.
    identity_threshold: float = 0.5
@dataclass
class EvaluationConfig:
    """Evaluation configuration."""

    # Metric toggles
    compute_fid: bool = True
    compute_lpips: bool = True
    compute_nme: bool = True
    compute_identity: bool = True
    compute_ssim: bool = True
    # Report metrics stratified by skin-type / procedure subgroup.
    stratify_fitzpatrick: bool = True
    stratify_procedure: bool = True
    max_eval_samples: int = 0  # 0 = all
@dataclass
class WandbConfig:
    """Weights & Biases logging configuration."""

    enabled: bool = True
    project: str = "landmarkdiff"
    # None lets W&B fall back to its own defaults for entity / run name.
    entity: str | None = None
    run_name: str | None = None
    tags: list[str] = field(default_factory=list)
    mode: str = "online"  # "online", "offline", "disabled"
@dataclass
class SlurmConfig:
    """SLURM job submission parameters."""

    partition: str = "batch_gpu"
    account: str = ""  # Set via YAML or SLURM_ACCOUNT env var
    gpu_type: str = "nvidia_rtx_a6000"
    num_gpus: int = 1
    mem: str = "48G"
    cpus_per_task: int = 8
    time_limit: str = "48:00:00"  # HH:MM:SS, SLURM --time format
    # Prepended to generated job names.
    job_prefix: str = "surgery_"
@dataclass
class SafetyConfig:
    """Clinical safety and responsible AI parameters."""

    # validate_config() warns when this drops below 0.3.
    identity_threshold: float = 0.5
    # Presumably caps landmark displacement as a fraction of face size —
    # TODO confirm against the displacement code.
    max_displacement_fraction: float = 0.05
    watermark_enabled: bool = True
    watermark_text: str = "AI-GENERATED PREDICTION"
    # Out-of-distribution input detection
    ood_detection_enabled: bool = True
    ood_confidence_threshold: float = 0.3
    min_face_confidence: float = 0.5
    max_yaw_degrees: float = 45.0
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration.

    Aggregates all section configs and round-trips to/from YAML via
    ``from_yaml`` / ``to_yaml``.
    """

    experiment_name: str = "default"
    description: str = ""
    version: str = "0.3.2"
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    slurm: SlurmConfig = field(default_factory=SlurmConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)
    # Output
    output_dir: str = "outputs"

    @classmethod
    def from_yaml(cls, path: str | Path) -> ExperimentConfig:
        """Load config from a YAML file.

        Unknown keys inside each section are ignored and aliases are
        resolved by ``_from_dict``; missing sections fall back to their
        dataclass defaults.

        Args:
            path: Path to the YAML config file.

        Returns:
            A populated ExperimentConfig.
        """
        path = Path(path)
        with open(path) as f:
            raw = yaml.safe_load(f)
        # An empty YAML file parses to None -> return pure defaults.
        if raw is None:
            return cls()
        return cls(
            experiment_name=raw.get("experiment_name", "default"),
            description=raw.get("description", ""),
            version=raw.get("version", "0.3.2"),
            model=_from_dict(ModelConfig, raw.get("model", {})),
            training=_from_dict(TrainingConfig, raw.get("training", {})),
            data=_from_dict(DataConfig, raw.get("data", {})),
            inference=_from_dict(InferenceConfig, raw.get("inference", {})),
            evaluation=_from_dict(EvaluationConfig, raw.get("evaluation", {})),
            wandb=_from_dict(WandbConfig, raw.get("wandb", {})),
            slurm=_from_dict(SlurmConfig, raw.get("slurm", {})),
            safety=_from_dict(SafetyConfig, raw.get("safety", {})),
            output_dir=raw.get("output_dir", "outputs"),
        )

    def to_yaml(self, path: str | Path) -> None:
        """Save config to a YAML file, creating parent directories as needed.

        Tuples are converted to lists first so the output stays plain YAML.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        d = _convert_tuples(asdict(self))
        with open(path, "w") as f:
            yaml.dump(d, f, default_flow_style=False, sort_keys=False)

    def to_dict(self) -> dict:
        """Convert to a (nested) plain dictionary."""
        return asdict(self)
_FIELD_ALIASES: dict[str, str] = {
# YAML name -> dataclass field name
"max_steps": "max_train_steps",
"save_interval": "save_every_n_steps",
"sample_interval": "sample_every",
"log_interval": "log_every",
"adam_weight_decay": "weight_decay",
"lr_warmup_steps": "warmup_steps",
"resume_from": "resume_from_checkpoint",
}
def _from_dict(cls: type, d: dict) -> Any:
"""Create a dataclass from a dict, ignoring unknown keys.
Supports field aliases so YAML configs using train_controlnet.py-style
names (e.g. max_steps) map to dataclass fields (max_train_steps).
"""
import dataclasses
field_map = {f.name: f for f in dataclasses.fields(cls)}
filtered = {}
for k, v in d.items():
# Resolve aliases
canonical = _FIELD_ALIASES.get(k, k)
if canonical not in field_map:
continue
# Don't overwrite if the canonical name was already set explicitly
if canonical in filtered:
continue
# Convert lists back to tuples where the field type is tuple
f = field_map[canonical]
if isinstance(v, list) and "tuple" in str(f.type):
v = tuple(v)
filtered[canonical] = v
return cls(**filtered)
def _convert_tuples(obj: Any) -> Any:
"""Recursively convert tuples to lists for YAML serialization."""
if isinstance(obj, dict):
return {k: _convert_tuples(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_convert_tuples(item) for item in obj]
return obj
def load_config(
    config_path: str | Path | None = None,
    overrides: dict[str, object] | None = None,
) -> ExperimentConfig:
    """Load config with optional dot-notation overrides.

    Args:
        config_path: Path to YAML config. None returns defaults.
        overrides: Dict of "section.key" -> value overrides.
            E.g., {"training.learning_rate": 5e-6}

    Returns:
        ExperimentConfig with overrides applied. Override keys whose path
        does not resolve to an existing attribute are silently skipped.
    """
    if config_path:
        config = ExperimentConfig.from_yaml(config_path)
    else:
        config = ExperimentConfig()

    _MISSING = object()  # sentinel: distinguishes "no such attr" from None values
    for dotted_key, value in (overrides or {}).items():
        *section_path, attr = dotted_key.split(".")
        target: Any = config
        for section in section_path:
            target = getattr(target, section, _MISSING)
            if target is _MISSING:
                break
        if target is not _MISSING and hasattr(target, attr):
            setattr(target, attr, value)
    return config
def validate_config(config: ExperimentConfig) -> list[str]:
    """Validate config and return a list of human-readable warnings.

    Checks are advisory only: nothing is raised, and an empty list means
    the config passed every heuristic.
    """
    warnings: list[str] = []
    training = config.training

    # Phase B fine-tunes on top of Phase A weights.
    if training.phase == "B" and not training.resume_from_checkpoint:
        warnings.append("Phase B should resume from a Phase A checkpoint")

    effective_batch = training.batch_size * training.gradient_accumulation_steps
    if effective_batch < 8:
        warnings.append(f"Effective batch size {effective_batch} < 8 may cause instability")

    if training.learning_rate > 1e-4:
        warnings.append("Learning rate > 1e-4 is unusually high for fine-tuning")

    if config.data.image_size != 512:
        warnings.append(f"Image size {config.data.image_size} != 512; SD1.5 expects 512")

    if config.safety.identity_threshold < 0.3:
        warnings.append("Identity threshold < 0.3 may pass poor quality outputs")

    return warnings