# dreamlessx's picture
# Update landmarkdiff/fid.py to v0.3.2
# 28dc803 verified
"""Self-contained FID computation using InceptionV3 feature extraction.
Avoids dependency on torch-fidelity by implementing FID directly.
Supports GPU acceleration, batched processing, and caching.
Usage:
    from landmarkdiff.fid import compute_fid_from_dirs, compute_fid_from_arrays

    # From directories
    fid = compute_fid_from_dirs("path/to/real", "path/to/generated")

    # From numpy arrays
    fid = compute_fid_from_arrays(real_images, generated_images)
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
try:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
Dataset = object # type: ignore[misc,assignment]
def _load_inception_v3() -> Any:
    """Build the ImageNet-pretrained InceptionV3 used as FID feature extractor.

    The classification head (``fc``) is swapped for an identity layer so a
    forward pass yields the pooled 2048-dim features instead of class logits,
    and the network is switched to evaluation mode.

    Returns:
        The modified torchvision InceptionV3 module, in eval mode.
    """
    # Imported lazily so torchvision is only needed once FID is computed.
    from torchvision.models import Inception_V3_Weights, inception_v3

    network = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
    # Dropping the final FC layer exposes the avg-pool features (2048-dim).
    network.fc = nn.Identity()
    # ``eval()`` returns the module itself.
    return network.eval()
class ImageFolderDataset(Dataset):
    """Dataset serving the image files found directly inside a directory.

    Files are selected by extension (case-insensitive: jpg, jpeg, png, webp,
    bmp) and served in sorted path order. Sub-directories are not traversed.

    Args:
        directory: Folder containing the images.
        image_size: Side length, in pixels, every image is resized to.
    """

    def __init__(self, directory: str | Path, image_size: int = 299):
        self.directory = Path(directory)
        exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        # Sorted so the sample order is deterministic across runs.
        self.files = sorted(
            f for f in self.directory.iterdir()
            if f.suffix.lower() in exts and f.is_file()
        )
        self.image_size = image_size

    def __len__(self) -> int:
        """Return the number of image files found."""
        return len(self.files)

    def __getitem__(self, idx: int) -> Any:
        """Load image ``idx`` as a (3, image_size, image_size) float tensor.

        The image is read with OpenCV, resized to a square, converted from
        BGR to RGB, scaled to [0, 1] and ImageNet-normalized.

        If the file cannot be decoded, an all-zero tensor is returned so the
        batch can still be assembled, and a ``RuntimeWarning`` naming the
        file is emitted because such placeholders distort the FID score.
        """
        import warnings

        import cv2

        path = self.files[idx]
        img = cv2.imread(str(path))
        if img is None:
            # Best-effort fallback kept from the original behaviour, but no
            # longer silent: the caller should know the score is affected.
            warnings.warn(
                f"Could not read image {path}; substituting an all-zero "
                "tensor, which will bias the FID score",
                RuntimeWarning,
                stacklevel=2,
            )
            return torch.zeros(3, self.image_size, self.image_size)
        img = cv2.resize(img, (self.image_size, self.image_size))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Scale to [0, 1], move channels first, then ImageNet-normalize.
        t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
        return _imagenet_normalize(t)
class NumpyArrayDataset(Dataset):
    """Dataset view over an in-memory list of numpy images.

    Args:
        images: The images to serve; the list is stored as-is, not copied.
        image_size: Side length, in pixels, every image is resized to.
    """

    def __init__(self, images: list[np.ndarray], image_size: int = 299):
        self.image_size = image_size
        self.images = images

    def __len__(self) -> int:
        """Return the number of wrapped images."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Any:
        """Return image ``idx`` as an ImageNet-normalized channel-first tensor.

        The image is resized when its first two dimensions differ from the
        target size, converted to RGB (from grayscale, BGRA or BGR depending
        on its channel layout), scaled by 1/255 and normalized.
        """
        import cv2

        target = (self.image_size, self.image_size)
        frame = self.images[idx]
        if frame.shape[:2] != target:
            frame = cv2.resize(frame, target)

        # Pick the colour conversion from the channel layout; layouts other
        # than 2-D, 3-channel or 4-channel are passed through unchanged.
        if frame.ndim == 2:
            conversion = cv2.COLOR_GRAY2RGB
        else:
            conversion = {
                4: cv2.COLOR_BGRA2RGB,
                3: cv2.COLOR_BGR2RGB,
            }.get(frame.shape[2])
        if conversion is not None:
            frame = cv2.cvtColor(frame, conversion)

        scaled = frame.astype(np.float32) / 255.0
        chw = torch.from_numpy(scaled).permute(2, 0, 1)
        return _imagenet_normalize(chw)
def _imagenet_normalize(t: torch.Tensor) -> torch.Tensor:
    """Standardize a channel-first RGB tensor with the ImageNet statistics.

    Args:
        t: Tensor whose channel axis broadcasts against shape (3, 1, 1).

    Returns:
        ``(t - mean) / std`` using the per-channel ImageNet mean and std.
    """
    # Per-channel constants shaped (3, 1, 1) so they broadcast over H and W.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    centered = t - channel_mean
    return centered / channel_std
def _extract_features(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> np.ndarray:
    """Run every batch through ``model`` and stack the outputs.

    Args:
        model: Feature extractor; called once per batch.
        dataloader: Iterable yielding batches of input tensors.
        device: Device each batch is moved to before the forward pass.

    Returns:
        All per-batch outputs concatenated along axis 0 as a numpy array.
    """
    chunks = []
    # Inference only: gradient tracking is disabled for the whole pass.
    with torch.no_grad():
        for images in dataloader:
            output = model(images.to(device))
            # When the model returns a tuple, only its first element is used.
            pooled = output[0] if isinstance(output, tuple) else output
            chunks.append(pooled.cpu().numpy())
    return np.concatenate(chunks, axis=0)
def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Compute the mean vector and covariance matrix of feature vectors.

    The features are promoted to float64 first so that the mean and the
    covariance are accumulated at the same precision (``np.cov`` already
    works in float64; the network emits float32).

    Args:
        features: Array with one row per sample and one column per feature.

    Returns:
        ``(mu, sigma)``: per-feature mean and the unbiased sample covariance
        (columns treated as variables), both float64.

    Raises:
        ValueError: If fewer than two samples are supplied, since the sample
            covariance is undefined in that case.
    """
    samples = np.asarray(features, dtype=np.float64)
    if samples.shape[0] < 2:
        raise ValueError(
            f"FID requires at least 2 images, got {samples.shape[0]}"
        )
    mu = samples.mean(axis=0)
    sigma = np.cov(samples, rowvar=False)
    return mu, sigma
def _calculate_fid(
    mu1: np.ndarray, sigma1: np.ndarray,
    mu2: np.ndarray, sigma2: np.ndarray,
    eps: float = 1e-6,
) -> float:
    """Calculate the Frechet distance between two Gaussians.

    FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1*sigma2))

    Args:
        mu1: Mean vector of the first distribution.
        sigma1: Covariance matrix of the first distribution.
        mu2: Mean vector of the second distribution.
        sigma2: Covariance matrix of the second distribution.
        eps: Value added to the covariance diagonals when the matrix square
            root of ``sigma1 @ sigma2`` comes back non-finite.

    Returns:
        The FID score, clamped below at 0.0. The result can still be NaN if
        the square root stays non-finite after the ``eps`` retry.
    """
    from scipy.linalg import sqrtm

    # Promote scalar statistics (single-feature input, where np.cov returns
    # a 0-d array) so the matrix operations below are always well-defined.
    mu1 = np.atleast_1d(np.asarray(mu1, dtype=np.float64))
    mu2 = np.atleast_1d(np.asarray(mu2, dtype=np.float64))
    sigma1 = np.atleast_2d(np.asarray(sigma1, dtype=np.float64))
    sigma2 = np.atleast_2d(np.asarray(sigma2, dtype=np.float64))

    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)

    # A (near-)singular product can make sqrtm return inf/NaN; nudging the
    # diagonals regularizes it instead of propagating NaN to the caller.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error can leave a small imaginary component; discard it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = (
        diff @ diff
        + np.trace(sigma1)
        + np.trace(sigma2)
        - 2.0 * np.trace(covmean)
    )
    # Round-off can push a true zero slightly negative.
    return float(max(fid, 0.0))
def compute_fid_from_dirs(
    real_dir: str | Path,
    generated_dir: str | Path,
    batch_size: int = 32,
    num_workers: int = 4,
    device: str | None = None,
) -> float:
    """Compute FID between two directories of images.

    Args:
        real_dir: Path to real images.
        generated_dir: Path to generated images.
        batch_size: Batch size for feature extraction.
        num_workers: DataLoader workers.
        device: "cuda" or "cpu". Auto-detects if None.

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either directory holds fewer than 2 images, the
            minimum needed to estimate a covariance matrix.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")

    # Validate the inputs first: listing directories is cheap, whereas
    # loading InceptionV3 is not.
    real_ds = ImageFolderDataset(real_dir)
    gen_ds = ImageFolderDataset(generated_dir)
    if len(real_ds) < 2 or len(gen_ds) < 2:
        raise ValueError(
            "Need at least 2 images in each directory, got "
            f"{len(real_ds)} real and {len(gen_ds)} generated"
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)
    model = _load_inception_v3().to(dev)

    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        # Pinned host memory only helps host-to-GPU copies; requesting it
        # on a CPU-only run just makes the DataLoader emit a warning.
        "pin_memory": dev.type == "cuda",
    }
    real_loader = DataLoader(real_ds, **loader_kwargs)
    gen_loader = DataLoader(gen_ds, **loader_kwargs)

    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)

    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)
    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
def compute_fid_from_arrays(
    real_images: list[np.ndarray],
    generated_images: list[np.ndarray],
    batch_size: int = 32,
    device: str | None = None,
) -> float:
    """Compute FID from lists of numpy arrays.

    Args:
        real_images: List of (H, W, 3) BGR uint8 images.
        generated_images: List of (H, W, 3) BGR uint8 images.
        batch_size: Batch size for feature extraction.
        device: "cuda" or "cpu". Auto-detects if None.

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either list holds fewer than 2 images, the minimum
            needed to estimate a covariance matrix.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")

    # Fail fast with a clear message before the InceptionV3 model is loaded;
    # otherwise empty input only surfaces later as an opaque concatenate error.
    n_real, n_gen = len(real_images), len(generated_images)
    if n_real < 2 or n_gen < 2:
        raise ValueError(
            "FID requires at least 2 images in each set, got "
            f"{n_real} real and {n_gen} generated"
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)
    model = _load_inception_v3().to(dev)

    # The data is already in memory, so it is loaded in the main process.
    real_loader = DataLoader(
        NumpyArrayDataset(real_images), batch_size=batch_size, num_workers=0
    )
    gen_loader = DataLoader(
        NumpyArrayDataset(generated_images), batch_size=batch_size, num_workers=0
    )

    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)

    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)
    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)