File size: 7,691 Bytes
6421899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
6421899
 
 
 
 
 
 
 
 
 
28dc803
6421899
 
433e26f
6421899
433e26f
6421899
 
 
 
 
 
 
 
 
28dc803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6421899
28dc803
 
 
 
 
6421899
28dc803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
 
 
 
 
 
28dc803
6421899
 
 
 
28dc803
 
 
 
6421899
 
 
 
 
 
28dc803
 
6421899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28dc803
6421899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28dc803
 
 
 
6421899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
"""Self-contained FID computation using InceptionV3 feature extraction.

Avoids dependency on torch-fidelity by implementing FID directly.
Supports GPU acceleration and batched processing.

Usage:
    from landmarkdiff.fid import compute_fid_from_dirs, compute_fid_from_arrays

    # From directories
    fid = compute_fid_from_dirs("path/to/real", "path/to/generated")

    # From numpy arrays
    fid = compute_fid_from_arrays(real_images, generated_images)
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np

try:
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
    Dataset = object  # type: ignore[misc,assignment]


def _load_inception_v3() -> Any:
    """Return an eval-mode InceptionV3 that emits 2048-dim pool features.

    The final classifier (``fc``) is replaced with an identity module so a
    forward pass yields the average-pooled 2048-dim feature vector that the
    FID statistics are computed over.
    """
    from torchvision.models import Inception_V3_Weights, inception_v3

    net = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
    # Strip the classification head; keep the pooled features.
    net.fc = nn.Identity()
    net.eval()
    return net


class ImageFolderDataset(Dataset):
    """Dataset yielding normalized (3, S, S) tensors for each image in a directory."""

    def __init__(self, directory: str | Path, image_size: int = 299):
        # Only common raster formats are picked up; anything else is ignored.
        allowed = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        self.directory = Path(directory)
        self.files = sorted(
            p for p in self.directory.iterdir()
            if p.suffix.lower() in allowed and p.is_file()
        )
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> Any:
        import cv2

        loaded = cv2.imread(str(self.files[idx]))
        if loaded is None:
            # Unreadable file: substitute an all-zero tensor rather than crash.
            return torch.zeros(3, self.image_size, self.image_size)
        side = self.image_size
        loaded = cv2.resize(loaded, (side, side))
        loaded = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
        # Scale to [0, 1], move channels first, then apply ImageNet stats.
        tensor = torch.from_numpy(loaded.astype(np.float32) / 255.0).permute(2, 0, 1)
        return _imagenet_normalize(tensor)


class NumpyArrayDataset(Dataset):
    """Dataset wrapping in-memory images (BGR / grayscale / BGRA uint8 arrays)."""

    def __init__(self, images: list[np.ndarray], image_size: int = 299):
        self.images = images
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Any:
        import cv2

        arr = self.images[idx]
        target = (self.image_size, self.image_size)
        if arr.shape[:2] != target:
            arr = cv2.resize(arr, target)
        # Coerce whatever channel layout we got into 3-channel RGB.
        if arr.ndim == 2:
            arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
        elif arr.shape[2] == 4:
            arr = cv2.cvtColor(arr, cv2.COLOR_BGRA2RGB)
        elif arr.shape[2] == 3:
            arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(arr.astype(np.float32) / 255.0).permute(2, 0, 1)
        return _imagenet_normalize(tensor)


def _imagenet_normalize(t: torch.Tensor) -> torch.Tensor:
    """Apply ImageNet normalization."""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (t - mean) / std


def _extract_features(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> np.ndarray:
    """Extract InceptionV3 pool3 features from a dataloader."""
    features = []
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            feat = model(batch)
            if isinstance(feat, tuple):
                feat = feat[0]
            features.append(feat.cpu().numpy())
    return np.concatenate(features, axis=0)


def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Compute mean and covariance of feature vectors."""
    if features.shape[0] < 2:
        raise ValueError(
            f"FID requires at least 2 images, got {features.shape[0]}"
        )
    mu = np.mean(features, axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma


def _calculate_fid(
    mu1: np.ndarray, sigma1: np.ndarray,
    mu2: np.ndarray, sigma2: np.ndarray,
) -> float:
    """Calculate FID given two sets of statistics.

    FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1*sigma2))
    """
    from scipy.linalg import sqrtm

    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)

    # Handle numerical instability
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
    return float(max(fid, 0.0))


def compute_fid_from_dirs(
    real_dir: str | Path,
    generated_dir: str | Path,
    batch_size: int = 32,
    num_workers: int = 4,
    device: str | None = None,
) -> float:
    """Compute FID between two directories of images.

    Args:
        real_dir: Path to real images.
        generated_dir: Path to generated images.
        batch_size: Batch size for feature extraction.
        num_workers: DataLoader workers.
        device: "cuda" or "cpu". Auto-detects if None.

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either directory contains fewer than 2 images.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)

    real_ds = ImageFolderDataset(real_dir)
    gen_ds = ImageFolderDataset(generated_dir)

    # Covariance estimation (_compute_statistics) needs at least 2 samples
    # per set; fail early with a clear message before loading the model.
    if len(real_ds) < 2 or len(gen_ds) < 2:
        raise ValueError("Need at least 2 images in each directory")

    model = _load_inception_v3().to(dev)

    real_loader = DataLoader(real_ds, batch_size=batch_size,
                              num_workers=num_workers, pin_memory=True)
    gen_loader = DataLoader(gen_ds, batch_size=batch_size,
                             num_workers=num_workers, pin_memory=True)

    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)

    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)

    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)


def compute_fid_from_arrays(
    real_images: list[np.ndarray],
    generated_images: list[np.ndarray],
    batch_size: int = 32,
    device: str | None = None,
) -> float:
    """Compute FID from lists of numpy arrays.

    Args:
        real_images: List of (H, W, 3) BGR uint8 images.
        generated_images: List of (H, W, 3) BGR uint8 images.
        batch_size: Batch size for feature extraction.
        device: "cuda" or "cpu".

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either list contains fewer than 2 images.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")

    # Covariance estimation (_compute_statistics) needs at least 2 samples
    # per set; fail early, before the expensive model load.
    if len(real_images) < 2 or len(generated_images) < 2:
        raise ValueError("Need at least 2 images in each list")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)

    model = _load_inception_v3().to(dev)

    real_ds = NumpyArrayDataset(real_images)
    gen_ds = NumpyArrayDataset(generated_images)

    real_loader = DataLoader(real_ds, batch_size=batch_size, num_workers=0)
    gen_loader = DataLoader(gen_ds, batch_size=batch_size, num_workers=0)

    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)

    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)

    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)