File size: 11,011 Bytes
134246c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
134246c
 
 
 
 
 
 
 
 
433e26f
134246c
 
 
 
 
 
 
 
 
 
 
433e26f
134246c
 
 
 
 
 
5b7e166
134246c
5b7e166
134246c
 
 
 
 
 
 
 
 
 
 
 
5b7e166
 
 
 
134246c
 
 
 
 
 
5b7e166
 
 
134246c
 
 
5b7e166
134246c
 
 
 
 
 
 
 
 
433e26f
5b7e166
 
 
134246c
 
 
 
 
 
 
 
5b7e166
 
134246c
 
433e26f
 
 
 
 
 
5b7e166
 
433e26f
 
134246c
 
 
 
 
 
 
 
 
 
433e26f
134246c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b7e166
134246c
 
 
 
 
433e26f
134246c
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
134246c
 
 
 
 
5b7e166
134246c
 
 
 
 
433e26f
134246c
5b7e166
134246c
 
 
 
 
 
 
 
 
 
 
433e26f
5b7e166
134246c
 
 
 
 
 
 
 
 
 
 
 
433e26f
134246c
 
5b7e166
134246c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b7e166
134246c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b7e166
 
 
 
 
 
 
 
 
 
 
 
433e26f
5b7e166
 
 
 
 
134246c
433e26f
134246c
 
 
5b7e166
 
 
 
 
 
134246c
 
5b7e166
134246c
 
5b7e166
134246c
 
 
433e26f
134246c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
134246c
 
 
 
 
5b7e166
134246c
 
 
 
5b7e166
134246c
5b7e166
134246c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
"""YAML-based experiment configuration for reproducible training and evaluation.

Provides typed dataclasses that can be loaded from YAML files, enabling
reproducible experiments with version-tracked configs.

Usage:
    from landmarkdiff.config import ExperimentConfig
    config = ExperimentConfig.from_yaml("configs/rhinoplasty_phaseA.yaml")
    print(config.training.learning_rate)

    # Or create programmatically
    config = ExperimentConfig(
        experiment_name="rhino_v1",
        training=TrainingConfig(phase="A", learning_rate=1e-5),
    )
    config.to_yaml("configs/rhino_v1.yaml")
"""

from __future__ import annotations

from dataclasses import asdict, dataclass, field, fields
from pathlib import Path
from typing import Any

import yaml


@dataclass
class ModelConfig:
    """ControlNet and base model configuration."""

    # Hugging Face model id (or local path) of the base diffusion model.
    base_model: str = "runwayml/stable-diffusion-v1-5"
    # Number of channels in the ControlNet conditioning input.
    controlnet_conditioning_channels: int = 3
    # Strength multiplier for the ControlNet conditioning signal.
    controlnet_conditioning_scale: float = 1.0
    # Keep an exponential moving average (EMA) copy of the weights.
    use_ema: bool = True
    # EMA decay factor; values near 1.0 average over many steps.
    ema_decay: float = 0.9999
    # Trade compute for memory by recomputing activations in backward.
    gradient_checkpointing: bool = True


@dataclass
class TrainingConfig:
    """Training hyperparameters."""

    phase: str = "A"  # "A" or "B"
    learning_rate: float = 1e-5
    batch_size: int = 4  # per optimizer step, before gradient accumulation
    gradient_accumulation_steps: int = 4
    max_train_steps: int = 50000
    warmup_steps: int = 500
    mixed_precision: str = "bf16"
    seed: int = 42
    # NOTE(review): ema_decay also exists on ModelConfig -- confirm which one
    # the training loop actually reads, or drop the duplicate.
    ema_decay: float = 0.9999

    # Optimizer
    optimizer: str = "adamw"  # "adamw", "adam8bit", "prodigy"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    weight_decay: float = 1e-2
    max_grad_norm: float = 1.0  # gradient-norm clipping threshold

    # LR scheduler
    lr_scheduler: str = "cosine"
    # Extra keyword arguments forwarded to the scheduler -- TODO confirm consumer.
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)

    # Logging intervals (in training steps)
    log_every: int = 100
    sample_every: int = 1000

    # Phase B specific
    identity_loss_weight: float = 0.1
    perceptual_loss_weight: float = 0.05
    use_differentiable_arcface: bool = False
    arcface_weights_path: str | None = None

    # Loss weights (alternative to individual weights)
    loss_weights: dict[str, float] = field(default_factory=dict)

    # Checkpointing
    save_every_n_steps: int = 5000
    resume_from_checkpoint: str | None = None
    resume_phase_a: str | None = None  # presumably a Phase A checkpoint path -- verify against trainer

    # Validation
    validate_every_n_steps: int = 2500
    num_validation_samples: int = 4


@dataclass
class DataConfig:
    """Dataset configuration."""

    # Dataset locations (paths are interpreted relative to the working dir).
    train_dir: str = "data/training_combined"
    val_dir: str = "data/splits/val"
    test_dir: str = "data/splits/test"
    image_size: int = 512  # square side length in pixels
    num_workers: int = 4  # number of data-loading worker processes
    pin_memory: bool = True

    # Augmentation
    random_flip: bool = True
    random_rotation: float = 5.0  # degrees
    color_jitter: float = 0.1
    clinical_augment: bool = False
    geometric_augment: bool = True

    # Procedure filtering: which procedure labels to include in training.
    procedures: list[str] = field(
        default_factory=lambda: [
            "rhinoplasty",
            "blepharoplasty",
            "rhytidectomy",
            "orthognathic",
            "brow_lift",
            "mentoplasty",
        ]
    )
    # (min, max); YAML lists for this field are coerced back to a tuple by
    # _from_dict() when loading.
    intensity_range: tuple[float, float] = (30.0, 100.0)

    # Data-driven displacement
    displacement_model_path: str | None = None
    noise_scale: float = 0.1


@dataclass
class InferenceConfig:
    """Inference / generation configuration."""

    num_inference_steps: int = 30  # denoising steps
    guidance_scale: float = 7.5  # classifier-free guidance strength
    scheduler: str = "dpmsolver++"  # "ddpm", "ddim", "dpmsolver++"
    controlnet_conditioning_scale: float = 1.0

    # Post-processing
    use_neural_postprocess: bool = False
    restore_mode: str = "codeformer"
    # Fidelity/quality trade-off passed to CodeFormer -- TODO confirm range.
    codeformer_fidelity: float = 0.7
    use_realesrgan: bool = True
    use_laplacian_blend: bool = True
    sharpen_strength: float = 0.25

    # Identity verification
    verify_identity: bool = True
    identity_threshold: float = 0.5


@dataclass
class EvaluationConfig:
    """Evaluation configuration."""

    # Which metrics to compute during evaluation.
    compute_fid: bool = True
    compute_lpips: bool = True
    compute_nme: bool = True
    compute_identity: bool = True
    compute_ssim: bool = True
    # Report metrics broken down by subgroup as well as in aggregate.
    stratify_fitzpatrick: bool = True
    stratify_procedure: bool = True
    max_eval_samples: int = 0  # 0 = all


@dataclass
class WandbConfig:
    """Weights & Biases logging configuration."""

    enabled: bool = True
    project: str = "landmarkdiff"
    entity: str | None = None  # None lets W&B use the default entity
    run_name: str | None = None  # None lets W&B auto-generate a name
    tags: list[str] = field(default_factory=list)
    mode: str = "online"  # "online", "offline", "disabled"


@dataclass
class SlurmConfig:
    """SLURM job submission parameters."""

    partition: str = "batch_gpu"
    account: str = ""  # Set via YAML or SLURM_ACCOUNT env var
    gpu_type: str = "nvidia_rtx_a6000"
    num_gpus: int = 1
    mem: str = "48G"  # memory request, SLURM --mem format
    cpus_per_task: int = 8
    time_limit: str = "48:00:00"  # HH:MM:SS wall-clock limit
    job_prefix: str = "surgery_"  # prepended to generated job names


@dataclass
class SafetyConfig:
    """Clinical safety and responsible AI parameters."""

    # Minimum identity-similarity score for an output to pass verification.
    identity_threshold: float = 0.5
    # Cap on landmark displacement, as a fraction -- TODO confirm of what (image size?).
    max_displacement_fraction: float = 0.05
    watermark_enabled: bool = True
    watermark_text: str = "AI-GENERATED PREDICTION"
    # Out-of-distribution input detection.
    ood_detection_enabled: bool = True
    ood_confidence_threshold: float = 0.3
    # Face-detection gating on inputs.
    min_face_confidence: float = 0.5
    max_yaw_degrees: float = 45.0


@dataclass
class ExperimentConfig:
    """Top-level experiment configuration aggregating all sub-configs."""

    experiment_name: str = "default"
    description: str = ""
    version: str = "0.3.2"

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    slurm: SlurmConfig = field(default_factory=SlurmConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)

    # Output
    output_dir: str = "outputs"

    @classmethod
    def from_yaml(cls, path: str | Path) -> ExperimentConfig:
        """Load config from a YAML file.

        Missing top-level keys and missing sections fall back to the
        dataclass defaults above (single source of truth -- no duplicated
        default literals here). A section key that is present but empty
        (YAML null, e.g. a bare ``model:`` line) is treated as an empty
        mapping instead of crashing inside _from_dict.
        """
        path = Path(path)
        with open(path) as f:
            raw = yaml.safe_load(f)

        if raw is None:  # empty file -> all defaults
            return cls()

        # Section name -> dataclass type for the nested configs.
        section_types: dict[str, type] = {
            "model": ModelConfig,
            "training": TrainingConfig,
            "data": DataConfig,
            "inference": InferenceConfig,
            "evaluation": EvaluationConfig,
            "wandb": WandbConfig,
            "slurm": SlurmConfig,
            "safety": SafetyConfig,
        }
        # `or {}` guards against sections that parse to None.
        kwargs: dict[str, Any] = {
            name: _from_dict(typ, raw.get(name) or {})
            for name, typ in section_types.items()
        }
        # Scalar keys are passed through only when present so the field
        # defaults declared on the class remain authoritative.
        for key in ("experiment_name", "description", "version", "output_dir"):
            if key in raw:
                kwargs[key] = raw[key]
        return cls(**kwargs)

    def to_yaml(self, path: str | Path) -> None:
        """Save config to a YAML file, creating parent directories as needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Tuples (e.g. DataConfig.intensity_range) serialize as YAML lists.
        d = _convert_tuples(asdict(self))
        with open(path, "w") as f:
            yaml.dump(d, f, default_flow_style=False, sort_keys=False)

    def to_dict(self) -> dict:
        """Convert to a plain nested dictionary."""
        return asdict(self)


# Backward-compatibility aliases: legacy YAML key -> current dataclass field.
# Resolved by _from_dict() so older configs keep loading unchanged.
_FIELD_ALIASES: dict[str, str] = {
    # YAML name -> dataclass field name
    "max_steps": "max_train_steps",
    "save_interval": "save_every_n_steps",
    "sample_interval": "sample_every",
    "log_interval": "log_every",
    "adam_weight_decay": "weight_decay",
    "lr_warmup_steps": "warmup_steps",
    "resume_from": "resume_from_checkpoint",
}


def _from_dict(cls: type, d: dict) -> Any:
    """Instantiate dataclass ``cls`` from a dict, ignoring unknown keys.

    Supports field aliases so YAML configs using train_controlnet.py-style
    names (e.g. ``max_steps``) map to dataclass fields (``max_train_steps``).

    Precedence: a value given under the canonical field name always wins
    over a value given under an alias, regardless of key order in the
    mapping. (The previous implementation let whichever key iterated first
    win, so an alias listed before the canonical key silently shadowed it.)
    """
    field_map = {f.name: f for f in fields(cls)}
    filtered: dict[str, Any] = {}
    for key, value in d.items():
        # Resolve aliases to the canonical dataclass field name.
        canonical = _FIELD_ALIASES.get(key, key)
        if canonical not in field_map:
            continue  # silently drop keys the dataclass does not define
        # An aliased key must never shadow an explicitly-given canonical key.
        if key != canonical and canonical in d:
            continue
        # If several aliases map to the same field, the first one seen wins.
        if canonical in filtered:
            continue
        # YAML has no tuple type: coerce lists destined for tuple fields.
        if isinstance(value, list) and "tuple" in str(field_map[canonical].type):
            value = tuple(value)
        filtered[canonical] = value
    return cls(**filtered)


def _convert_tuples(obj: Any) -> Any:
    """Recursively convert tuples to lists for YAML serialization.

    Dicts are rebuilt with converted values; tuples and lists both come
    back as lists; every other value is returned untouched.
    """
    if isinstance(obj, (list, tuple)):
        return [_convert_tuples(element) for element in obj]
    if isinstance(obj, dict):
        return {key: _convert_tuples(value) for key, value in obj.items()}
    return obj


def load_config(
    config_path: str | Path | None = None,
    overrides: dict[str, object] | None = None,
) -> ExperimentConfig:
    """Load config with optional dot-notation overrides.

    Args:
        config_path: Path to YAML config. None returns defaults.
        overrides: Dict of "section.key" -> value overrides.
            E.g., {"training.learning_rate": 5e-6}

    Returns:
        ExperimentConfig with overrides applied.
    """
    config = ExperimentConfig.from_yaml(config_path) if config_path else ExperimentConfig()

    for dotted, value in (overrides or {}).items():
        *parents, leaf = dotted.split(".")
        missing = object()  # sentinel: distinguishes "no such attr" from None
        target: object = config
        for attr in parents:
            target = getattr(target, attr, missing)
            if target is missing:
                break  # unknown path segment -> silently skip this override
        else:
            if hasattr(target, leaf):
                setattr(target, leaf, value)

    return config


def validate_config(config: ExperimentConfig) -> list[str]:
    """Validate config and return list of warnings.

    Purely advisory: nothing is raised, callers decide what to do with
    the returned warning strings (empty list = no issues found).
    """
    issues: list[str] = []
    train = config.training

    if train.phase == "B" and not train.resume_from_checkpoint:
        issues.append("Phase B should resume from a Phase A checkpoint")

    # The optimizer effectively sees batch_size * accumulation samples.
    effective_batch = train.batch_size * train.gradient_accumulation_steps
    if effective_batch < 8:
        issues.append(
            f"Effective batch size {effective_batch} < 8 may cause instability"
        )

    if train.learning_rate > 1e-4:
        issues.append("Learning rate > 1e-4 is unusually high for fine-tuning")

    if config.data.image_size != 512:
        issues.append(
            f"Image size {config.data.image_size} != 512; SD1.5 expects 512"
        )

    if config.safety.identity_threshold < 0.3:
        issues.append("Identity threshold < 0.3 may pass poor quality outputs")

    return issues