"""Self-contained FID computation using InceptionV3 feature extraction.
Avoids dependency on torch-fidelity by implementing FID directly.
Supports GPU acceleration, batched processing, and caching.
Usage:
from landmarkdiff.fid import compute_fid_from_dirs, compute_fid_from_arrays
# From directories
fid = compute_fid_from_dirs("path/to/real", "path/to/generated")
# From numpy arrays
fid = compute_fid_from_arrays(real_images, generated_images)
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
try:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
Dataset = object # type: ignore[misc,assignment]
def _load_inception_v3() -> Any:
    """Load InceptionV3 truncated at the 2048-dim average-pool features.

    Returns:
        An eval-mode torchvision InceptionV3 whose final FC layer is replaced
        by ``nn.Identity``, so a forward pass yields pool3 features.
    """
    from torchvision.models import Inception_V3_Weights, inception_v3

    model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
    # We want features from the avg pool layer (2048-dim), so replace the
    # final FC classifier with a pass-through.
    model.fc = nn.Identity()
    # Inference only: eval() freezes batch-norm stats / disables dropout.
    model.eval()
    # Freeze parameters so no autograd graph is built even if a caller
    # forgets to wrap the forward pass in torch.no_grad().
    for param in model.parameters():
        param.requires_grad_(False)
    return model
class ImageFolderDataset(Dataset):
    """Dataset yielding ImageNet-normalized tensors for images in a directory."""

    def __init__(self, directory: str | Path, image_size: int = 299):
        self.directory = Path(directory)
        allowed = (".jpg", ".jpeg", ".png", ".webp", ".bmp")
        candidates = [
            p for p in self.directory.iterdir()
            if p.is_file() and p.suffix.lower() in allowed
        ]
        self.files = sorted(candidates)
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> Any:
        import cv2

        img = cv2.imread(str(self.files[idx]))
        if img is None:
            # Unreadable file: fall back to an all-zero tensor.
            return torch.zeros(3, self.image_size, self.image_size)
        size = (self.image_size, self.image_size)
        img = cv2.cvtColor(cv2.resize(img, size), cv2.COLOR_BGR2RGB)
        # Scale to [0, 1], move channels first, then apply ImageNet stats.
        arr = img.astype(np.float32) / 255.0
        return _imagenet_normalize(torch.from_numpy(arr).permute(2, 0, 1))
class NumpyArrayDataset(Dataset):
    """Dataset yielding ImageNet-normalized tensors from in-memory arrays."""

    def __init__(self, images: list[np.ndarray], image_size: int = 299):
        self.images = images
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Any:
        import cv2

        frame = self.images[idx]
        target = (self.image_size, self.image_size)
        if frame.shape[:2] != target:
            frame = cv2.resize(frame, target)
        # Convert whatever layout we received into 3-channel RGB.
        if frame.ndim == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
        elif frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        scaled = torch.from_numpy(frame.astype(np.float32) / 255.0)
        return _imagenet_normalize(scaled.permute(2, 0, 1))
def _imagenet_normalize(t: torch.Tensor) -> torch.Tensor:
"""Apply ImageNet normalization."""
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
return (t - mean) / std
def _extract_features(
model: nn.Module,
dataloader: DataLoader,
device: torch.device,
) -> np.ndarray:
"""Extract InceptionV3 pool3 features from a dataloader."""
features = []
with torch.no_grad():
for batch in dataloader:
batch = batch.to(device)
feat = model(batch)
if isinstance(feat, tuple):
feat = feat[0]
features.append(feat.cpu().numpy())
return np.concatenate(features, axis=0)
def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Compute mean and covariance of feature vectors."""
if features.shape[0] < 2:
raise ValueError(
f"FID requires at least 2 images, got {features.shape[0]}"
)
mu = np.mean(features, axis=0)
sigma = np.cov(features, rowvar=False)
return mu, sigma
def _calculate_fid(
mu1: np.ndarray, sigma1: np.ndarray,
mu2: np.ndarray, sigma2: np.ndarray,
) -> float:
"""Calculate FID given two sets of statistics.
FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1*sigma2))
"""
from scipy.linalg import sqrtm
diff = mu1 - mu2
covmean = sqrtm(sigma1 @ sigma2)
# Handle numerical instability
if np.iscomplexobj(covmean):
covmean = covmean.real
fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
return float(max(fid, 0.0))
def compute_fid_from_dirs(
    real_dir: str | Path,
    generated_dir: str | Path,
    batch_size: int = 32,
    num_workers: int = 4,
    device: str | None = None,
) -> float:
    """Compute FID between two directories of images.

    Args:
        real_dir: Path to real images.
        generated_dir: Path to generated images.
        batch_size: Batch size for feature extraction.
        num_workers: DataLoader workers.
        device: "cuda" or "cpu". Auto-detects if None.

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either directory contains fewer than 2 images.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")
    # Validate inputs before the (slow) model load so bad paths fail fast.
    # Covariance estimation in _compute_statistics needs >= 2 samples per
    # side, so checking for merely 1 image would only defer the failure.
    real_ds = ImageFolderDataset(real_dir)
    gen_ds = ImageFolderDataset(generated_dir)
    if len(real_ds) < 2 or len(gen_ds) < 2:
        raise ValueError("FID requires at least 2 images in each directory")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)
    model = _load_inception_v3().to(dev)
    real_loader = DataLoader(real_ds, batch_size=batch_size,
                             num_workers=num_workers, pin_memory=True)
    gen_loader = DataLoader(gen_ds, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=True)
    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)
    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)
    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
def compute_fid_from_arrays(
    real_images: list[np.ndarray],
    generated_images: list[np.ndarray],
    batch_size: int = 32,
    device: str | None = None,
) -> float:
    """Compute FID from lists of numpy arrays.

    Args:
        real_images: List of (H, W, 3) BGR uint8 images.
        generated_images: List of (H, W, 3) BGR uint8 images.
        batch_size: Batch size for feature extraction.
        device: "cuda" or "cpu".

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either list holds fewer than 2 images.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")
    # Fail fast before the slow model load: covariance estimation needs
    # >= 2 samples per side, and an empty list would otherwise surface as
    # a cryptic np.concatenate error in _extract_features.
    if len(real_images) < 2 or len(generated_images) < 2:
        raise ValueError("FID requires at least 2 images in each list")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)
    model = _load_inception_v3().to(dev)
    real_ds = NumpyArrayDataset(real_images)
    gen_ds = NumpyArrayDataset(generated_images)
    real_loader = DataLoader(real_ds, batch_size=batch_size, num_workers=0)
    gen_loader = DataLoader(gen_ds, batch_size=batch_size, num_workers=0)
    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)
    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)
    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)