Conditional GAN — MNIST Digit Generation

Conditional Generative Adversarial Network (CGAN) for handwritten digit synthesis, trained on the MNIST dataset.

Architecture

A standard GAN framework conditioned on class labels, featuring a Generator and a Discriminator network.

Component Details
  • Discriminator: takes a 1×28×28 grayscale image and a class label (0-9). The label is embedded, expanded to 1×28×28, and concatenated with the image; the result passes through two downsampling Conv2d blocks (with BatchNorm2d and LeakyReLU), is flattened, and ends in a single Sigmoid output.
  • Generator: takes a 100-dimensional noise vector and a target class label (0-9). The label is embedded and concatenated with the noise; the result passes through two ConvTranspose2d upsampling blocks (with BatchNorm2d and LeakyReLU) and a final Tanh activation, yielding a 1×28×28 image.
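Only the Generator is reproduced in the Usage section below, so here is a sketch of a Discriminator matching the description above. The filter widths (32/64) and the per-pixel label embedding are assumptions, not the card's exact weights:

```python
import torch
from torch import nn

class CGAN_Discriminator(nn.Module):
    """Sketch of the Discriminator described above (layer widths assumed)."""
    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        self.img_size = img_size
        # Embed the class label and reshape it into a 1x28x28 "label channel".
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)

        self.conv_blocks = nn.Sequential(
            # Block 1: 2 -> 32 channels, 28x28 -> 14x14
            nn.Conv2d(2, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 2: 32 -> 64 channels, 14x14 -> 7x7
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * (img_size // 4) ** 2, 1),
            nn.Sigmoid(),  # probability that the (image, label) pair is real
        )

    def forward(self, img, labels):
        # Expand the embedded label to image shape and stack it as a 2nd channel.
        label_map = self.label_embedding(labels).view(-1, 1, self.img_size, self.img_size)
        x = torch.cat((img, label_map), dim=1)
        return self.classifier(self.conv_blocks(x))
```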

Loss Function

The model is trained using the standard Conditional GAN Minimax objective, implemented with Binary Cross Entropy:

min_G max_D V(D, G) = E_x[log D(x|y)] + E_z[log(1 - D(G(z|y)|y))]
Loss roles:

  • Discriminator loss: penalises the Discriminator for classifying real MNIST images as fake, or generated images as real.
  • Generator loss: penalises the Generator when the Discriminator correctly identifies its generated images as fake.
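The two roles above can be sketched as a single BCE training step. This is a sketch, not the card's exact loop: `train_step`, the optimizer plumbing, and the use of the common non-saturating generator loss (maximising log D(G(z|y)|y) instead of minimising log(1 - D(G(z|y)|y))) are assumptions:

```python
import torch
from torch import nn

bce = nn.BCELoss()

def train_step(generator, discriminator, opt_g, opt_d, real_imgs, labels, noise_dim=100):
    """One conditional-GAN update for a batch of real (image, label) pairs."""
    batch = real_imgs.size(0)
    real_target = torch.ones(batch, 1)
    fake_target = torch.zeros(batch, 1)

    # --- Discriminator: push real pairs toward 1, generated pairs toward 0 ---
    opt_d.zero_grad()
    noise = torch.randn(batch, noise_dim)
    fake_imgs = generator(noise, labels)
    d_loss = bce(discriminator(real_imgs, labels), real_target) + \
             bce(discriminator(fake_imgs.detach(), labels), fake_target)
    d_loss.backward()
    opt_d.step()

    # --- Generator: fool the Discriminator (non-saturating trick) ---
    opt_g.zero_grad()
    g_loss = bce(discriminator(fake_imgs, labels), real_target)
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```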

Training

  • Dataset: MNIST (60,000 training images)
  • Input size: 28×28, normalised to [-1, 1]
  • Batch size: 32
  • Epochs: 50
  • Optimizer: Adam
  • Hyperparameter search: Optuna over learning_rate ∈ [1e-5, 2e-3], beta1 ∈ [0.0, 0.9], and noise_dim ∈ {50, 100, 128}
  • Best weights: snapshot of the epoch with the lowest Generator loss and a stable Discriminator loss
  • Best hyperparameters found: learning_rate = 0.000112, beta1 = 0.037, noise_dim = 100
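The best hyperparameters above translate into the Adam setup as sketched below; the placeholder module and Adam's default beta2 = 0.999 are assumptions:

```python
import torch
from torch import nn, optim

# Best hyperparameters reported by the Optuna search.
best = {"learning_rate": 0.000112, "beta1": 0.037, "noise_dim": 100}

# Placeholder module standing in for the Generator described in this card.
generator = nn.Linear(best["noise_dim"], 28 * 28)

opt_g = optim.Adam(generator.parameters(),
                   lr=best["learning_rate"],
                   betas=(best["beta1"], 0.999))  # beta2 = 0.999 is Adam's default
```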

Data Preprocessing (train only)

Photometric: Images are converted to PyTorch tensors and normalised with a mean of 0.5 and standard deviation of 0.5, scaling the pixel values to the [-1, 1] range to match the Tanh activation of the Generator.

Usage

import torch
from torch import nn
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt

class CGAN_Generator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=10, img_size=28):
        super().__init__()
        self.init_size = img_size // 4
        self.embedding_dim = 20
        self.label_embedding = nn.Embedding(num_classes, self.embedding_dim)
        
        self.projection = nn.Sequential(
            nn.Linear(noise_dim + self.embedding_dim, 128 * self.init_size * self.init_size)
        )
        
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        label_embed = self.label_embedding(labels)
        merged_input = torch.cat((noise, label_embed), dim=1)
        out = self.projection(merged_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

model = CGAN_Generator(noise_dim=100)

weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

target_label = torch.tensor([7])   # digit to generate
eval_noise = torch.randn(1, 100)   # matches noise_dim=100 used above

with torch.no_grad():
    generated_img = model(eval_noise, target_label)

generated_img = (generated_img + 1) / 2.0
img_numpy = generated_img.squeeze().cpu().numpy()

plt.imshow(img_numpy, cmap='gray')
plt.title(f"Generated Label: {target_label.item()}")
plt.axis('off')
plt.show()
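Building on the snippet above, a hedged sketch that draws one sample of each digit 0-9 in a single row; the helper name, grid layout, and return value are illustrative:

```python
import torch
import matplotlib.pyplot as plt

def plot_all_digits(model, noise_dim=100):
    """Generate and plot one sample per class; returns the [0, 1] image batch."""
    labels = torch.arange(10)              # one of each digit 0-9
    noise = torch.randn(10, noise_dim)
    with torch.no_grad():
        imgs = model(noise, labels)
    imgs = (imgs + 1) / 2.0                # Tanh output -> [0, 1]
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for ax, img, lbl in zip(axes, imgs, labels):
        ax.imshow(img.squeeze().cpu().numpy(), cmap='gray')
        ax.set_title(str(lbl.item()))
        ax.axis('off')
    plt.show()
    return imgs
```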