Spaces:
Running
Running
| """YAML-based experiment configuration for reproducible training and evaluation. | |
| Provides typed dataclasses that can be loaded from YAML files, enabling | |
| reproducible experiments with version-tracked configs. | |
| Usage: | |
| from landmarkdiff.config import ExperimentConfig | |
| config = ExperimentConfig.from_yaml("configs/rhinoplasty_phaseA.yaml") | |
| print(config.training.learning_rate) | |
| # Or create programmatically | |
| config = ExperimentConfig( | |
| experiment_name="rhino_v1", | |
| training=TrainingConfig(phase="A", learning_rate=1e-5), | |
| ) | |
| config.to_yaml("configs/rhino_v1.yaml") | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| import yaml | |
@dataclass
class ModelConfig:
    """ControlNet and base model configuration."""

    # Hugging Face model id (or local path) of the frozen base diffusion model.
    base_model: str = "runwayml/stable-diffusion-v1-5"
    # Number of channels in the ControlNet conditioning image.
    controlnet_conditioning_channels: int = 3
    controlnet_conditioning_scale: float = 1.0
    # Exponential moving average of weights for more stable sampling.
    use_ema: bool = True
    ema_decay: float = 0.9999
    # Trade compute for memory during training.
    gradient_checkpointing: bool = True
@dataclass
class TrainingConfig:
    """Training hyperparameters."""

    phase: str = "A"  # "A" or "B"
    learning_rate: float = 1e-5
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    max_train_steps: int = 50000
    warmup_steps: int = 500
    mixed_precision: str = "bf16"
    seed: int = 42
    ema_decay: float = 0.9999

    # Optimizer
    optimizer: str = "adamw"  # "adamw", "adam8bit", "prodigy"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    weight_decay: float = 1e-2
    max_grad_norm: float = 1.0

    # LR scheduler
    lr_scheduler: str = "cosine"
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)

    # Logging intervals
    log_every: int = 100
    sample_every: int = 1000

    # Phase B specific
    identity_loss_weight: float = 0.1
    perceptual_loss_weight: float = 0.05
    use_differentiable_arcface: bool = False
    arcface_weights_path: str | None = None

    # Loss weights (alternative to individual weights)
    loss_weights: dict[str, float] = field(default_factory=dict)

    # Checkpointing
    save_every_n_steps: int = 5000
    resume_from_checkpoint: str | None = None
    resume_phase_a: str | None = None

    # Validation
    validate_every_n_steps: int = 2500
    num_validation_samples: int = 4
@dataclass
class DataConfig:
    """Dataset configuration."""

    train_dir: str = "data/training_combined"
    val_dir: str = "data/splits/val"
    test_dir: str = "data/splits/test"
    image_size: int = 512
    num_workers: int = 4
    pin_memory: bool = True

    # Augmentation
    random_flip: bool = True
    random_rotation: float = 5.0  # degrees
    color_jitter: float = 0.1
    clinical_augment: bool = False
    geometric_augment: bool = True

    # Procedure filtering
    procedures: list[str] = field(
        default_factory=lambda: [
            "rhinoplasty",
            "blepharoplasty",
            "rhytidectomy",
            "orthognathic",
            "brow_lift",
            "mentoplasty",
        ]
    )
    # (min, max) procedure intensity in percent.
    intensity_range: tuple[float, float] = (30.0, 100.0)

    # Data-driven displacement
    displacement_model_path: str | None = None
    noise_scale: float = 0.1
@dataclass
class InferenceConfig:
    """Inference / generation configuration."""

    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    scheduler: str = "dpmsolver++"  # "ddpm", "ddim", "dpmsolver++"
    controlnet_conditioning_scale: float = 1.0

    # Post-processing
    use_neural_postprocess: bool = False
    restore_mode: str = "codeformer"
    codeformer_fidelity: float = 0.7
    use_realesrgan: bool = True
    use_laplacian_blend: bool = True
    sharpen_strength: float = 0.25

    # Identity verification
    verify_identity: bool = True
    identity_threshold: float = 0.5
@dataclass
class EvaluationConfig:
    """Evaluation configuration."""

    compute_fid: bool = True
    compute_lpips: bool = True
    compute_nme: bool = True
    compute_identity: bool = True
    compute_ssim: bool = True
    # Report metrics stratified by skin type / procedure.
    stratify_fitzpatrick: bool = True
    stratify_procedure: bool = True
    max_eval_samples: int = 0  # 0 = all
@dataclass
class WandbConfig:
    """Weights & Biases logging configuration."""

    enabled: bool = True
    project: str = "landmarkdiff"
    entity: str | None = None
    run_name: str | None = None
    tags: list[str] = field(default_factory=list)
    mode: str = "online"  # "online", "offline", "disabled"
@dataclass
class SlurmConfig:
    """SLURM job submission parameters."""

    partition: str = "batch_gpu"
    account: str = ""  # Set via YAML or SLURM_ACCOUNT env var
    gpu_type: str = "nvidia_rtx_a6000"
    num_gpus: int = 1
    mem: str = "48G"
    cpus_per_task: int = 8
    time_limit: str = "48:00:00"
    job_prefix: str = "surgery_"
@dataclass
class SafetyConfig:
    """Clinical safety and responsible AI parameters."""

    # Minimum cosine similarity for identity preservation.
    identity_threshold: float = 0.5
    # Cap on landmark displacement as a fraction of face size.
    max_displacement_fraction: float = 0.05
    watermark_enabled: bool = True
    watermark_text: str = "AI-GENERATED PREDICTION"
    ood_detection_enabled: bool = True
    ood_confidence_threshold: float = 0.3
    min_face_confidence: float = 0.5
    max_yaw_degrees: float = 45.0
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration.

    Aggregates all sub-configs and supports round-tripping to/from YAML
    via :meth:`from_yaml` / :meth:`to_yaml`.
    """

    experiment_name: str = "default"
    description: str = ""
    version: str = "0.3.2"

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    slurm: SlurmConfig = field(default_factory=SlurmConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)

    # Output
    output_dir: str = "outputs"

    @classmethod
    def from_yaml(cls, path: str | Path) -> ExperimentConfig:
        """Load config from a YAML file.

        Args:
            path: Path to the YAML config file.

        Returns:
            ExperimentConfig populated from the file; missing keys fall
            back to dataclass defaults, unknown keys are ignored.
        """
        path = Path(path)
        with open(path) as f:
            raw = yaml.safe_load(f)
        # An empty YAML file parses to None -> return all defaults.
        if raw is None:
            return cls()
        return cls(
            experiment_name=raw.get("experiment_name", "default"),
            description=raw.get("description", ""),
            version=raw.get("version", "0.3.2"),
            model=_from_dict(ModelConfig, raw.get("model", {})),
            training=_from_dict(TrainingConfig, raw.get("training", {})),
            data=_from_dict(DataConfig, raw.get("data", {})),
            inference=_from_dict(InferenceConfig, raw.get("inference", {})),
            evaluation=_from_dict(EvaluationConfig, raw.get("evaluation", {})),
            wandb=_from_dict(WandbConfig, raw.get("wandb", {})),
            slurm=_from_dict(SlurmConfig, raw.get("slurm", {})),
            safety=_from_dict(SafetyConfig, raw.get("safety", {})),
            output_dir=raw.get("output_dir", "outputs"),
        )

    def to_yaml(self, path: str | Path) -> None:
        """Save config to a YAML file, creating parent directories as needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # asdict keeps tuples as tuples; convert so YAML stays plain lists.
        d = _convert_tuples(asdict(self))
        with open(path, "w") as f:
            yaml.dump(d, f, default_flow_style=False, sort_keys=False)

    def to_dict(self) -> dict:
        """Convert to a plain (nested) dictionary."""
        return asdict(self)
_FIELD_ALIASES: dict[str, str] = {
    # YAML name -> dataclass field name
    "max_steps": "max_train_steps",
    "save_interval": "save_every_n_steps",
    "sample_interval": "sample_every",
    "log_interval": "log_every",
    "adam_weight_decay": "weight_decay",
    "lr_warmup_steps": "warmup_steps",
    "resume_from": "resume_from_checkpoint",
}


def _from_dict(cls: type, d: dict) -> Any:
    """Create a dataclass instance from a dict, ignoring unknown keys.

    Supports field aliases so YAML configs using train_controlnet.py-style
    names (e.g. max_steps) map to dataclass fields (max_train_steps).

    Args:
        cls: Dataclass type to instantiate.
        d: Raw mapping, e.g. one YAML section.

    Returns:
        An instance of ``cls`` built from the recognized keys.
    """
    import dataclasses

    field_map = {f.name: f for f in dataclasses.fields(cls)}
    filtered: dict[str, Any] = {}
    for k, v in d.items():
        # Resolve aliases to the canonical field name.
        canonical = _FIELD_ALIASES.get(k, k)
        if canonical not in field_map:
            continue
        # An explicitly set canonical key always wins over an alias,
        # regardless of dict ordering (the old first-seen-wins guard let
        # an earlier alias silently shadow the canonical key).
        if k != canonical and canonical in d:
            continue
        # If several aliases map to the same field, keep the first one.
        if canonical in filtered:
            continue
        f = field_map[canonical]
        # YAML has no tuple type; convert lists back for tuple-typed fields.
        if isinstance(v, list) and "tuple" in str(f.type):
            v = tuple(v)
        filtered[canonical] = v
    return cls(**filtered)
| def _convert_tuples(obj: Any) -> Any: | |
| """Recursively convert tuples to lists for YAML serialization.""" | |
| if isinstance(obj, dict): | |
| return {k: _convert_tuples(v) for k, v in obj.items()} | |
| if isinstance(obj, (list, tuple)): | |
| return [_convert_tuples(item) for item in obj] | |
| return obj | |
def load_config(
    config_path: str | Path | None = None,
    overrides: dict[str, object] | None = None,
) -> ExperimentConfig:
    """Load config with optional dot-notation overrides.

    Args:
        config_path: Path to YAML config. None returns defaults.
        overrides: Dict of "section.key" -> value overrides.
            E.g., {"training.learning_rate": 5e-6}

    Returns:
        ExperimentConfig with overrides applied.
    """
    if config_path:
        config = ExperimentConfig.from_yaml(config_path)
    else:
        config = ExperimentConfig()

    for dotted, value in (overrides or {}).items():
        *section_path, attr = dotted.split(".")
        # Walk the sections; give up silently on an unresolvable path.
        target: object | None = config
        for name in section_path:
            if not hasattr(target, name):
                target = None
                break
            target = getattr(target, name)
        if target is not None and hasattr(target, attr):
            setattr(target, attr, value)
    return config
def validate_config(config: ExperimentConfig) -> list[str]:
    """Validate config and return list of warnings."""
    issues: list[str] = []
    tr = config.training

    if tr.phase == "B" and not tr.resume_from_checkpoint:
        issues.append("Phase B should resume from a Phase A checkpoint")

    effective = tr.batch_size * tr.gradient_accumulation_steps
    if effective < 8:
        issues.append(f"Effective batch size {effective} < 8 may cause instability")

    if tr.learning_rate > 1e-4:
        issues.append("Learning rate > 1e-4 is unusually high for fine-tuning")

    size = config.data.image_size
    if size != 512:
        issues.append(f"Image size {size} != 512; SD1.5 expects 512")

    if config.safety.identity_threshold < 0.3:
        issues.append("Identity threshold < 0.3 may pass poor quality outputs")

    return issues