Spaces:
Running
Running
| """Curriculum learning support for progressive training difficulty. | |
| Implements a schedule that controls which training samples are used | |
| at different stages of training, starting with easy examples (small | |
| displacements) and gradually introducing harder ones. | |
| Usage in training loop:: | |
| curriculum = TrainingCurriculum( | |
| total_steps=100000, | |
| warmup_fraction=0.1, # first 10% easy only | |
| full_difficulty_at=0.5, # full dataset by 50% | |
| ) | |
| # In training loop: | |
| difficulty = curriculum.get_difficulty(global_step) | |
| # Use difficulty to filter/weight samples | |
| Or as a dataset wrapper:: | |
| dataset = CurriculumDataset( | |
| base_dataset=SyntheticPairDataset(data_dir), | |
| metadata_path=Path(data_dir) / "metadata.json", | |
| total_steps=100000, | |
| ) | |
| # Call dataset.set_step(global_step) each iteration | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import math | |
| from pathlib import Path | |
| import numpy as np | |
class TrainingCurriculum:
    """Maps a training step to a difficulty level in [0, 1].

    Difficulty 0 means only the easiest samples (smallest displacements,
    lowest intensity) are used; difficulty 1 means the full dataset.

    Schedule shape (cosine ramp):

    - ``step < warmup_steps``: difficulty is 0 (easy samples only)
    - warmup -> full_difficulty: cosine ramp from 0 to 1
    - ``step >= full_steps``: difficulty is 1 (full dataset)
    """

    def __init__(
        self,
        total_steps: int,
        warmup_fraction: float = 0.1,
        full_difficulty_at: float = 0.5,
    ):
        # Convert fractional milestones to absolute step counts once.
        self.total_steps = total_steps
        self.warmup_steps = int(total_steps * warmup_fraction)
        self.full_steps = int(total_steps * full_difficulty_at)

    def get_difficulty(self, step: int) -> float:
        """Return the difficulty level in [0, 1] for the given step."""
        if step < self.warmup_steps:
            return 0.0
        if step >= self.full_steps:
            return 1.0
        # Guard against a zero-length ramp when the milestones coincide.
        ramp_len = max(1, self.full_steps - self.warmup_steps)
        frac = (step - self.warmup_steps) / ramp_len
        return 0.5 * (1 - math.cos(math.pi * frac))

    def should_include(
        self,
        step: int,
        sample_difficulty: float,
        rng: np.random.Generator | None = None,
    ) -> bool:
        """Decide whether a sample of the given difficulty is used at `step`.

        Samples at or below the current difficulty threshold are always
        included; harder samples are admitted stochastically, with a
        probability that drops off linearly above the threshold.

        Args:
            step: Current training step.
            sample_difficulty: Difficulty of the sample [0, 1].
            rng: Random number generator for stochastic inclusion.

        Returns:
            True if the sample should be used.
        """
        threshold = self.get_difficulty(step)
        if sample_difficulty <= threshold:
            return True
        # Stochastic inclusion for samples slightly above the threshold.
        generator = rng if rng is not None else np.random.default_rng()
        excess = sample_difficulty - threshold
        # Probability reaches zero once the sample exceeds the threshold
        # by 0.2, so far-too-hard samples are always rejected.
        admit_prob = max(0, 1.0 - excess * 5)
        return generator.random() < admit_prob
class ProcedureCurriculum:
    """Procedure-aware curriculum that adjusts per-procedure weights.

    Some procedures are inherently harder (e.g., orthognathic with large
    deformations). This curriculum increases their sampling weight over
    training, delegating the step-to-difficulty schedule to
    ``TrainingCurriculum``.
    """

    # Difficulty ranking (0=easiest, 1=hardest)
    DEFAULT_PROCEDURE_DIFFICULTY = {
        "blepharoplasty": 0.3,  # small, localized changes
        "rhinoplasty": 0.5,  # moderate, central face
        "rhytidectomy": 0.7,  # large, affects face shape
        "orthognathic": 0.9,  # largest deformations
    }

    def __init__(
        self,
        total_steps: int,
        procedure_difficulty: dict[str, float] | None = None,
        warmup_fraction: float = 0.1,
    ):
        """Create a procedure curriculum.

        Args:
            total_steps: Total number of training steps.
            procedure_difficulty: Optional mapping of procedure name to
                difficulty [0, 1]; defaults to DEFAULT_PROCEDURE_DIFFICULTY.
            warmup_fraction: Fraction of training kept at difficulty 0.
        """
        self.curriculum = TrainingCurriculum(total_steps, warmup_fraction)
        # Explicit `is None` check: the previous `or` idiom silently
        # replaced an intentionally-empty mapping with the defaults.
        # Copy in both branches so mutating self.proc_difficulty can never
        # corrupt the shared class-level default or the caller's dict.
        if procedure_difficulty is None:
            self.proc_difficulty = dict(self.DEFAULT_PROCEDURE_DIFFICULTY)
        else:
            self.proc_difficulty = dict(procedure_difficulty)

    def get_weight(self, step: int, procedure: str) -> float:
        """Get sampling weight for a procedure at the given step.

        Returns a value in [0.1, 1.0] — never fully excludes any procedure.
        Unknown procedures are assumed to have difficulty 0.5.
        """
        difficulty = self.get_difficulty(step)
        proc_diff = self.proc_difficulty.get(procedure, 0.5)
        if proc_diff <= difficulty:
            return 1.0
        # Reduce weight for too-hard procedures, floored at 0.1.
        return max(0.1, 1.0 - (proc_diff - difficulty) * 2)

    def get_difficulty(self, step: int) -> float:
        """Get the global difficulty level [0, 1] at the given step."""
        return self.curriculum.get_difficulty(step)

    def get_procedure_weights(self, step: int) -> dict[str, float]:
        """Get all procedure weights at the given step."""
        return {
            proc: self.get_weight(step, proc)
            for proc in self.proc_difficulty
        }
def compute_sample_difficulty(
    metadata_path: str | Path,
    displacement_model_path: str | Path | None = None,
) -> dict[str, float]:
    """Score every sample in the dataset with a difficulty in [0, 1].

    The score combines three ingredients read from the metadata file:

    1. Base difficulty of the procedure.
    2. A bonus for the data source (real data is hardest).
    3. A scaled contribution from the displacement intensity.

    Returns:
        Dict mapping sample prefix to difficulty score [0, 1].
    """
    base_by_procedure = {
        "blepharoplasty": 0.2,
        "rhinoplasty": 0.4,
        "rhytidectomy": 0.6,
        "orthognathic": 0.8,
        "unknown": 0.5,
    }
    bonus_by_source = {
        "synthetic": 0.0,
        "synthetic_v3": 0.1,  # realistic displacements slightly harder
        "real": 0.2,  # real data hardest
        "augmented": 0.0,
    }

    with open(metadata_path) as handle:
        metadata = json.load(handle)

    scores: dict[str, float] = {}
    for prefix, info in metadata.get("pairs", {}).items():
        base = base_by_procedure.get(info.get("procedure", "unknown"), 0.5)
        bonus = bonus_by_source.get(info.get("source", "synthetic"), 0.0)
        # Higher intensity means a harder sample; the contribution is
        # capped (at intensity 1.5) and scaled to at most 0.2.
        intensity_term = min(1.0, info.get("intensity", 1.0) / 1.5) * 0.2
        scores[prefix] = min(1.0, base + bonus + intensity_term)
    return scores