Spaces:
Running
Running
"""Ensemble inference for improved output quality.

Generates multiple outputs with different random seeds and combines them
to reduce per-sample variance. Supported aggregation strategies:

- ``pixel_average``: pixel-space mean of all samples (fast, slight blur).
- ``weighted_average``: pixel-space mean weighted by each sample's SSIM
  to the input image.
- ``best_of_n``: selects the single sample with the highest identity
  similarity to the input image.
- ``median``: pixel-wise median (robust to outlier samples).

Usage:
    from landmarkdiff.ensemble import EnsembleInference

    ensemble = EnsembleInference(
        mode="controlnet",
        controlnet_checkpoint="checkpoints/final/controlnet_ema",
        n_samples=5,
        strategy="best_of_n",
    )
    ensemble.load()
    result = ensemble.generate(image, procedure="rhinoplasty", intensity=65)
"""
| from __future__ import annotations | |
| import cv2 | |
| import numpy as np | |
class EnsembleInference:
    """Multi-sample ensemble inference for LandmarkDiff.

    Generates ``n_samples`` outputs from a ``LandmarkDiffPipeline`` using
    consecutive seeds and combines them with the configured aggregation
    strategy (one of ``STRATEGIES``).
    """

    # Aggregation strategies accepted by ``generate``.
    STRATEGIES = ("pixel_average", "weighted_average", "best_of_n", "median")

    def __init__(
        self,
        mode: str = "controlnet",
        controlnet_checkpoint: str | None = None,
        displacement_model_path: str | None = None,
        n_samples: int = 5,
        strategy: str = "best_of_n",
        base_seed: int = 42,
        **pipeline_kwargs,
    ):
        """Initialize ensemble inference.

        No model is loaded here; call ``load()`` before ``generate()``.

        Args:
            mode: Pipeline mode (controlnet, img2img, tps).
            controlnet_checkpoint: Path to fine-tuned ControlNet.
            displacement_model_path: Path to displacement model.
            n_samples: Number of ensemble members (must be >= 1 when
                ``generate`` is called).
            strategy: Aggregation strategy, one of ``STRATEGIES``:
                - "pixel_average": Average in pixel space.
                - "weighted_average": Average weighted by SSIM to the input.
                - "best_of_n": Select best by identity similarity.
                - "median": Pixel-wise median (robust to outliers).
            base_seed: Base random seed (sample i uses base_seed + i).
            **pipeline_kwargs: Additional kwargs for LandmarkDiffPipeline.
        """
        self.mode = mode
        self.controlnet_checkpoint = controlnet_checkpoint
        self.displacement_model_path = displacement_model_path
        self.n_samples = n_samples
        self.strategy = strategy
        self.base_seed = base_seed
        self.pipeline_kwargs = pipeline_kwargs
        self._pipeline = None

    def load(self) -> None:
        """Construct the LandmarkDiffPipeline and call its ``load()``."""
        from landmarkdiff.inference import LandmarkDiffPipeline

        self._pipeline = LandmarkDiffPipeline(
            mode=self.mode,
            controlnet_checkpoint=self.controlnet_checkpoint,
            displacement_model_path=self.displacement_model_path,
            **self.pipeline_kwargs,
        )
        self._pipeline.load()

    def is_loaded(self) -> bool:
        """Return True once ``load()`` has run and the pipeline reports loaded."""
        if self._pipeline is None:
            return False
        loaded = self._pipeline.is_loaded
        # NOTE(review): the pipeline's ``is_loaded`` is read as an attribute;
        # if it is a plain method rather than a property, call it so that a
        # bound method (always truthy) is not mistaken for "loaded".
        if callable(loaded):
            loaded = loaded()
        return bool(loaded)

    def generate(
        self,
        image: np.ndarray,
        procedure: str = "rhinoplasty",
        intensity: float = 50.0,
        num_inference_steps: int = 30,
        guidance_scale: float = 9.0,
        controlnet_conditioning_scale: float = 0.9,
        strength: float = 0.5,
        seed: int | None = None,
        **kwargs,
    ) -> dict:
        """Generate ensemble output.

        Runs the pipeline ``n_samples`` times with seeds ``base + i``, where
        ``base`` is ``seed`` if given, else ``self.base_seed``. All other
        arguments are forwarded unchanged to the pipeline's ``generate``.

        Returns:
            The result dict of the selected sample (sample 0 for the
            averaging strategies), updated with:
            - output: Final ensembled image (np.ndarray, uint8)
            - outputs: List of all individual outputs
            - scores: Per-sample scores (uniform 1/N for pixel_average and
              median, SSIM for weighted_average, identity similarity for
              best_of_n)
            - selected_idx: Index chosen by best_of_n, otherwise -1
            - strategy: Aggregation strategy used
            - n_samples: Number of ensemble members

        Raises:
            RuntimeError: If ``load()`` has not been called.
            ValueError: If ``strategy`` is unknown or ``n_samples`` < 1.
                Both are checked before any (expensive) sample is generated.
        """
        # Call the method: a bare ``self.is_loaded`` is a bound method and
        # therefore always truthy, which would skip this guard entirely.
        if not self.is_loaded():
            raise RuntimeError("Pipeline not loaded. Call load() first.")
        if self.strategy not in self.STRATEGIES:
            raise ValueError(f"Unknown strategy: {self.strategy}")
        if self.n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {self.n_samples}")

        base = seed if seed is not None else self.base_seed
        outputs = []
        results = []

        # Generate N samples, one per consecutive seed.
        for i in range(self.n_samples):
            result = self._pipeline.generate(
                image,
                procedure=procedure,
                intensity=intensity,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                strength=strength,
                seed=base + i,
                **kwargs,
            )
            outputs.append(result["output"])
            results.append(result)

        # Aggregate.
        uniform = [1.0 / self.n_samples] * self.n_samples
        selected_idx = -1
        if self.strategy == "pixel_average":
            final, scores = self._pixel_average(outputs), uniform
        elif self.strategy == "weighted_average":
            final, scores = self._weighted_average(outputs, image)
        elif self.strategy == "best_of_n":
            final, scores, selected_idx = self._best_of_n(outputs, image)
        else:  # "median" -- strategy was validated above
            final, scores = self._pixel_median(outputs), uniform

        # Copy metadata from the selected result (sample 0 when averaging).
        ensemble_result = dict(results[max(selected_idx, 0)])
        ensemble_result.update({
            "output": final,
            "outputs": outputs,
            "scores": scores,
            "selected_idx": selected_idx,
            "strategy": self.strategy,
            "n_samples": self.n_samples,
        })
        return ensemble_result

    @staticmethod
    def _to_uint8(arr: np.ndarray) -> np.ndarray:
        """Round a float image to the nearest integer and clip into uint8."""
        # Rounding (not truncation) avoids a systematic downward bias.
        return np.clip(np.rint(arr), 0, 255).astype(np.uint8)

    def _pixel_average(self, outputs: list[np.ndarray]) -> np.ndarray:
        """Simple pixel-space averaging."""
        stacked = np.stack(outputs, axis=0).astype(np.float32)
        return self._to_uint8(stacked.mean(axis=0))

    def _pixel_median(self, outputs: list[np.ndarray]) -> np.ndarray:
        """Pixel-wise median (robust to outliers)."""
        # For an even number of samples the median is a half-integer mean of
        # two values, so round rather than truncate.
        return self._to_uint8(np.median(np.stack(outputs, axis=0), axis=0))

    def _weighted_mean(
        self,
        outputs: list[np.ndarray],
        scores: list[float],
    ) -> np.ndarray:
        """Average ``outputs`` weighted by their non-negative ``scores``.

        Negative and NaN scores count as zero weight. If no sample has a
        positive score, fall back to an unweighted mean instead of returning
        an all-black image.
        """
        # ``s > 0.0`` is False for NaN, so NaN scores get zero weight.
        clamped = [s if s > 0.0 else 0.0 for s in scores]
        total = sum(clamped)
        if total > 0.0:
            weights = [s / total for s in clamped]
        else:
            weights = [1.0 / len(outputs)] * len(outputs)

        result = np.zeros_like(outputs[0], dtype=np.float32)
        for output, weight in zip(outputs, weights):
            result += output.astype(np.float32) * weight
        return self._to_uint8(result)

    def _weighted_average(
        self,
        outputs: list[np.ndarray],
        reference: np.ndarray,
    ) -> tuple[np.ndarray, list[float]]:
        """Quality-weighted averaging using SSIM to ``reference`` as weight.

        Returns the blended image and the raw (unclamped) SSIM scores.
        """
        from landmarkdiff.evaluation import compute_ssim

        scores = [float(compute_ssim(output, reference)) for output in outputs]
        return self._weighted_mean(outputs, scores), scores

    @staticmethod
    def _argmax_ignoring_nan(scores: list[float]) -> int:
        """Index of the largest score, treating NaN as the worst possible."""
        # Plain np.argmax returns the index of the first NaN, which would
        # "select" a sample whose score could not be computed.
        arr = np.asarray(scores, dtype=np.float64)
        arr[np.isnan(arr)] = -np.inf
        return int(np.argmax(arr))

    def _best_of_n(
        self,
        outputs: list[np.ndarray],
        reference: np.ndarray,
    ) -> tuple[np.ndarray, list[float], int]:
        """Select the output with highest identity similarity to reference."""
        from landmarkdiff.evaluation import compute_identity_similarity

        scores = [
            float(compute_identity_similarity(output, reference))
            for output in outputs
        ]
        best_idx = self._argmax_ignoring_nan(scores)
        return outputs[best_idx], scores, best_idx
def ensemble_inference(
    image_path: str,
    procedure: str = "rhinoplasty",
    intensity: float = 65.0,
    output_dir: str = "ensemble_output",
    n_samples: int = 5,
    strategy: str = "best_of_n",
    mode: str = "tps",
    controlnet_checkpoint: str | None = None,
    displacement_model_path: str | None = None,
    seed: int = 42,
) -> None:
    """CLI entry point for ensemble inference.

    Reads ``image_path``, resizes it to 512x512, runs an ``EnsembleInference``
    and writes these files into ``output_dir``:

    - ``ensemble_output.png``: the aggregated result
    - ``original.png``: the resized input
    - ``sample_XX.png``: each individual ensemble member
    - ``comparison_grid.png``: input, all samples and the result side by
      side, each resized to 256x256

    Per-sample scores are printed, and the selected sample is marked when
    the strategy selects one. If the input image cannot be read, an error is
    printed to stderr and the function returns without creating
    ``output_dir``. Failed image writes are reported as warnings on stderr.
    """
    import sys
    from pathlib import Path

    image = cv2.imread(image_path)
    if image is None:
        print(f"ERROR: Cannot read image: {image_path}", file=sys.stderr)
        return

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    def _save(name: str, img: np.ndarray) -> None:
        # Write one image into ``out``; cv2.imwrite reports failure through
        # its boolean return value, so surface that instead of ignoring it.
        target = out / name
        if not cv2.imwrite(str(target), img):
            print(f"WARNING: Failed to write {target}", file=sys.stderr)

    image = cv2.resize(image, (512, 512))

    ensemble = EnsembleInference(
        mode=mode,
        controlnet_checkpoint=controlnet_checkpoint,
        displacement_model_path=displacement_model_path,
        n_samples=n_samples,
        strategy=strategy,
        base_seed=seed,
    )
    ensemble.load()

    print(f"Generating ensemble ({n_samples} samples, strategy={strategy})...")
    result = ensemble.generate(
        image,
        procedure=procedure,
        intensity=intensity,
        seed=seed,
    )

    # Save outputs
    _save("ensemble_output.png", result["output"])
    _save("original.png", image)

    # Save individual samples
    for i, output in enumerate(result["outputs"]):
        _save(f"sample_{i:02d}.png", output)
        score = result["scores"][i]
        print(f"  Sample {i}: score={score:.4f}"
              + (" <-- selected" if i == result.get("selected_idx") else ""))

    # Comparison grid: input | each sample | ensembled result, at 256x256
    # each to keep the strip compact.
    panels = [image] + result["outputs"] + [result["output"]]
    panels_small = [cv2.resize(p, (256, 256)) for p in panels]
    _save("comparison_grid.png", np.hstack(panels_small))

    print(f"\nEnsemble output saved: {out / 'ensemble_output.png'}")
    if result.get("selected_idx", -1) >= 0:
        print(f"Selected sample: {result['selected_idx']} "
              f"(score={result['scores'][result['selected_idx']]:.4f})")
if __name__ == "__main__":
    # Command-line wrapper: parse flags and forward them to
    # ensemble_inference() by keyword so the mapping is explicit.
    import argparse

    cli = argparse.ArgumentParser(description="Ensemble inference")
    cli.add_argument("image", help="Input face image")
    cli.add_argument("--procedure", default="rhinoplasty")
    cli.add_argument("--intensity", type=float, default=65.0)
    cli.add_argument("--output", default="ensemble_output")
    cli.add_argument("--n_samples", type=int, default=5)
    cli.add_argument(
        "--strategy",
        default="best_of_n",
        choices=["pixel_average", "weighted_average", "best_of_n", "median"],
    )
    cli.add_argument(
        "--mode",
        default="tps",
        choices=["controlnet", "img2img", "tps"],
    )
    cli.add_argument("--checkpoint", default=None)
    cli.add_argument("--displacement-model", default=None)
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    ensemble_inference(
        image_path=opts.image,
        procedure=opts.procedure,
        intensity=opts.intensity,
        output_dir=opts.output,
        n_samples=opts.n_samples,
        strategy=opts.strategy,
        mode=opts.mode,
        controlnet_checkpoint=opts.checkpoint,
        displacement_model_path=opts.displacement_model,
        seed=opts.seed,
    )