dreamlessx commited on
Commit
134246c
·
verified ·
1 Parent(s): da3a0ac

Upload landmarkdiff/config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/config.py +309 -0
landmarkdiff/config.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """YAML-based experiment configuration for reproducible training and evaluation.
2
+
3
+ Provides typed dataclasses that can be loaded from YAML files, enabling
4
+ reproducible experiments with version-tracked configs.
5
+
6
+ Usage:
7
+ from landmarkdiff.config import ExperimentConfig
8
+ config = ExperimentConfig.from_yaml("configs/rhinoplasty_phaseA.yaml")
9
+ print(config.training.learning_rate)
10
+
11
+ # Or create programmatically
12
+ config = ExperimentConfig(
13
+ experiment_name="rhino_v1",
14
+ training=TrainingConfig(phase="A", learning_rate=1e-5),
15
+ )
16
+ config.to_yaml("configs/rhino_v1.yaml")
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from dataclasses import dataclass, field, asdict
22
+ from pathlib import Path
23
+ from typing import Any
24
+
25
+ import yaml
26
+
27
+
28
@dataclass
class ModelConfig:
    """ControlNet and base model configuration."""
    # Hugging Face repo id of the Stable Diffusion base model.
    base_model: str = "runwayml/stable-diffusion-v1-5"
    # Channel count of the conditioning image fed to the ControlNet (3 = RGB).
    controlnet_conditioning_channels: int = 3
    # Multiplier applied to the ControlNet conditioning residuals.
    controlnet_conditioning_scale: float = 1.0
    # Keep an exponential-moving-average copy of the weights during training.
    use_ema: bool = True
    ema_decay: float = 0.9999
    # Trade compute for memory by recomputing activations in the backward pass.
    gradient_checkpointing: bool = True
37
+
38
+
39
@dataclass
class TrainingConfig:
    """Training hyperparameters."""
    phase: str = "A"  # "A" or "B"
    learning_rate: float = 1e-5
    batch_size: int = 4
    # Effective batch size is batch_size * gradient_accumulation_steps
    # (see validate_config, which warns when it drops below 8).
    gradient_accumulation_steps: int = 4
    max_train_steps: int = 50000
    warmup_steps: int = 500
    mixed_precision: str = "fp16"
    seed: int = 42

    # Optimizer
    optimizer: str = "adamw"  # "adamw", "adam8bit", "prodigy"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    weight_decay: float = 1e-2
    max_grad_norm: float = 1.0

    # LR scheduler
    lr_scheduler: str = "cosine"
    # Extra keyword arguments forwarded to the scheduler — presumably passed
    # through to its constructor by the trainer; confirm against caller.
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)

    # Phase B specific loss weights
    identity_loss_weight: float = 0.1
    perceptual_loss_weight: float = 0.05
    use_differentiable_arcface: bool = False
    arcface_weights_path: str | None = None

    # Checkpointing
    save_every_n_steps: int = 5000
    # Checkpoint path to resume from, or None to start fresh.  Phase B runs
    # are expected to set this (validate_config warns otherwise).
    resume_from_checkpoint: str | None = None

    # Validation
    validate_every_n_steps: int = 2500
    num_validation_samples: int = 4
75
+
76
+
77
@dataclass
class DataConfig:
    """Dataset configuration."""
    train_dir: str = "data/training"
    val_dir: str = "data/validation"
    test_dir: str = "data/test"
    # Square resolution; validate_config warns when this is not 512 (SD1.5).
    image_size: int = 512
    num_workers: int = 4
    pin_memory: bool = True

    # Augmentation
    random_flip: bool = True
    random_rotation: float = 5.0  # degrees
    color_jitter: float = 0.1

    # Procedure filtering: which surgical procedures to include.
    procedures: list[str] = field(default_factory=lambda: [
        "rhinoplasty", "blepharoplasty", "rhytidectomy", "orthognathic",
    ])
    # (min, max) intensity window for filtering — presumably a percentage
    # scale given the 30..100 defaults; confirm against the dataset loader.
    intensity_range: tuple[float, float] = (30.0, 100.0)

    # Data-driven displacement
    displacement_model_path: str | None = None
    noise_scale: float = 0.1
101
+
102
+
103
@dataclass
class InferenceConfig:
    """Inference / generation configuration."""
    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    scheduler: str = "dpmsolver++"  # "ddpm", "ddim", "dpmsolver++"
    controlnet_conditioning_scale: float = 1.0

    # Post-processing
    use_neural_postprocess: bool = False
    restore_mode: str = "codeformer"
    # CodeFormer fidelity knob — assumed to be the usual 0..1 quality/fidelity
    # trade-off; TODO confirm against the post-processing code.
    codeformer_fidelity: float = 0.7
    use_realesrgan: bool = True
    use_laplacian_blend: bool = True
    sharpen_strength: float = 0.25

    # Identity verification
    verify_identity: bool = True
    # Same default as SafetyConfig.identity_threshold.
    identity_threshold: float = 0.6
122
+
123
+
124
@dataclass
class EvaluationConfig:
    """Evaluation configuration."""
    # Per-metric toggles.
    compute_fid: bool = True
    compute_lpips: bool = True
    compute_nme: bool = True
    compute_identity: bool = True
    compute_ssim: bool = True
    # Break results down by Fitzpatrick skin type / by procedure.
    stratify_fitzpatrick: bool = True
    stratify_procedure: bool = True
    max_eval_samples: int = 0  # 0 = all
135
+
136
+
137
@dataclass
class WandbConfig:
    """Weights & Biases logging configuration."""
    enabled: bool = True
    project: str = "landmarkdiff"
    entity: str | None = None  # W&B team/user; None uses the account default
    run_name: str | None = None  # None lets W&B auto-generate a run name
    tags: list[str] = field(default_factory=list)
145
+
146
+
147
@dataclass
class SlurmConfig:
    """SLURM job submission parameters."""
    partition: str = "batch_gpu"
    account: str = "csb_gpu_acc"
    gpu_type: str = "nvidia_rtx_a6000"
    num_gpus: int = 1
    mem: str = "48G"
    cpus_per_task: int = 8
    time_limit: str = "48:00:00"  # wall-clock limit, HH:MM:SS
    # Presumably prepended to submitted job names — confirm in the submitter.
    job_prefix: str = "surgery_"
158
+
159
+
160
@dataclass
class SafetyConfig:
    """Clinical safety and responsible AI parameters."""
    # Minimum identity similarity for an output to pass (validate_config
    # warns below 0.3).
    identity_threshold: float = 0.6
    # Cap on landmark displacement — assumed to be a fraction of image size
    # given the name; TODO confirm in the displacement code.
    max_displacement_fraction: float = 0.05
    watermark_enabled: bool = True
    watermark_text: str = "AI-GENERATED PREDICTION"
    # Out-of-distribution input detection.
    ood_detection_enabled: bool = True
    ood_confidence_threshold: float = 0.3
    # Input gating: minimum face-detector confidence, maximum head yaw.
    min_face_confidence: float = 0.5
    max_yaw_degrees: float = 45.0
171
+
172
+
173
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration aggregating all sub-configs."""
    experiment_name: str = "default"
    description: str = ""
    version: str = "0.3.0"

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    slurm: SlurmConfig = field(default_factory=SlurmConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)

    # Output
    output_dir: str = "outputs"

    @classmethod
    def from_yaml(cls, path: str | Path) -> ExperimentConfig:
        """Build an ExperimentConfig from a YAML file.

        An empty file yields an all-defaults config.  Unknown keys inside
        each section are ignored by ``_from_dict``.
        """
        with open(Path(path)) as fh:
            raw = yaml.safe_load(fh)

        if raw is None:
            return cls()

        # Map each YAML section name to its dataclass and hydrate them all.
        section_types = {
            "model": ModelConfig,
            "training": TrainingConfig,
            "data": DataConfig,
            "inference": InferenceConfig,
            "evaluation": EvaluationConfig,
            "wandb": WandbConfig,
            "slurm": SlurmConfig,
            "safety": SafetyConfig,
        }
        sections = {
            name: _from_dict(typ, raw.get(name, {}))
            for name, typ in section_types.items()
        }
        return cls(
            experiment_name=raw.get("experiment_name", "default"),
            description=raw.get("description", ""),
            version=raw.get("version", "0.3.0"),
            output_dir=raw.get("output_dir", "outputs"),
            **sections,
        )

    def to_yaml(self, path: str | Path) -> None:
        """Serialize this config to YAML, creating parent directories."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        # Tuples are converted to lists so the YAML stays plain (no python tags).
        serializable = _convert_tuples(asdict(self))
        with open(target, "w") as fh:
            yaml.dump(serializable, fh, default_flow_style=False, sort_keys=False)

    def to_dict(self) -> dict:
        """Return a plain-dict view of the full config."""
        return asdict(self)
228
+
229
+
230
+ def _from_dict(cls, d: dict):
231
+ """Create a dataclass from a dict, ignoring unknown keys."""
232
+ import dataclasses
233
+ field_map = {f.name: f for f in dataclasses.fields(cls)}
234
+ filtered = {}
235
+ for k, v in d.items():
236
+ if k not in field_map:
237
+ continue
238
+ # Convert lists back to tuples where the field type is tuple
239
+ f = field_map[k]
240
+ if isinstance(v, list) and "tuple" in str(f.type):
241
+ v = tuple(v)
242
+ filtered[k] = v
243
+ return cls(**filtered)
244
+
245
+
246
+ def _convert_tuples(obj):
247
+ """Recursively convert tuples to lists for YAML serialization."""
248
+ if isinstance(obj, dict):
249
+ return {k: _convert_tuples(v) for k, v in obj.items()}
250
+ if isinstance(obj, (list, tuple)):
251
+ return [_convert_tuples(item) for item in obj]
252
+ return obj
253
+
254
+
255
def load_config(
    config_path: str | Path | None = None,
    overrides: dict[str, object] | None = None,
) -> ExperimentConfig:
    """Load config with optional dot-notation overrides.

    Args:
        config_path: Path to YAML config. None returns defaults.
        overrides: Dict of "section.key" -> value overrides.
            E.g., {"training.learning_rate": 5e-6}

    Returns:
        ExperimentConfig with overrides applied.

    Note:
        An override whose intermediate path fails to resolve is skipped
        entirely.  (Previously a partially resolved path could apply the
        last key to the wrong object: "safety.model.identity_threshold"
        would silently set ``safety.identity_threshold``.)
    """
    config = ExperimentConfig.from_yaml(config_path) if config_path else ExperimentConfig()

    for key, value in (overrides or {}).items():
        *section_path, attr = key.split(".")
        target = config
        resolved = True
        # Walk every intermediate component; abandon the override if any
        # hop is missing instead of applying `attr` to a wrong object.
        for part in section_path:
            if not hasattr(target, part):
                resolved = False
                break
            target = getattr(target, part)
        if resolved and hasattr(target, attr):
            setattr(target, attr, value)

    return config
287
+
288
+
289
def validate_config(config: ExperimentConfig) -> list[str]:
    """Sanity-check a config and return a list of warning strings (empty if clean)."""
    training = config.training
    eff_batch = training.batch_size * training.gradient_accumulation_steps

    # Each entry is (problem-detected, warning-message); messages are only
    # emitted for entries whose condition is true.
    checks = [
        (
            training.phase == "B" and not training.resume_from_checkpoint,
            "Phase B should resume from a Phase A checkpoint",
        ),
        (
            eff_batch < 8,
            f"Effective batch size {eff_batch} < 8 may cause instability",
        ),
        (
            training.learning_rate > 1e-4,
            "Learning rate > 1e-4 is unusually high for fine-tuning",
        ),
        (
            config.data.image_size != 512,
            f"Image size {config.data.image_size} != 512; SD1.5 expects 512",
        ),
        (
            config.safety.identity_threshold < 0.3,
            "Identity threshold < 0.3 may pass poor quality outputs",
        ),
    ]
    return [message for triggered, message in checks if triggered]