# Conditional GAN — MNIST Digit Generation
Conditional Generative Adversarial Network (CGAN) for handwritten digit synthesis, trained on the MNIST dataset.
## Architecture
A standard GAN framework conditioned on class labels, featuring a Generator and a Discriminator network.
| Component | Details |
|---|---|
| Discriminator | Takes a 1×28×28 grayscale image and a class label (0-9). Embeds the label, expands it to 1×28×28, concatenates with the image, and passes through two downsampling Conv2d blocks (with BatchNorm2d and LeakyReLU). Output is flattened and passed through a Sigmoid activation (1 output). |
| Generator | Takes a 100-dimensional noise vector and a target class label (0-9). Uses embedding layers for the label, concatenates it with noise, and passes through two ConvTranspose2d upsampling blocks (with BatchNorm2d and LeakyReLU). Output uses a Tanh activation function (1×28×28). |
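The Discriminator described above can be sketched as follows. This is a hedged reconstruction from the table, not the released implementation: the channel widths (64, 128), the label-embedding size, and the final `Linear` layer before the Sigmoid are assumptions.

```python
import torch
from torch import nn

class CGAN_Discriminator(nn.Module):
    """Sketch of the Discriminator from the table above (widths are assumed)."""

    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        self.img_size = img_size
        # Embed the class label and reshape it into a 1x28x28 plane.
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)
        self.conv_blocks = nn.Sequential(
            # 2x28x28 (image + label plane) -> 64x14x14
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x14x14 -> 128x7x7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * (img_size // 4) ** 2, 1),
            nn.Sigmoid(),  # single real/fake probability
        )

    def forward(self, img, labels):
        label_plane = self.label_embedding(labels).view(-1, 1, self.img_size, self.img_size)
        x = torch.cat((img, label_plane), dim=1)
        return self.classifier(self.conv_blocks(x))
```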
## Loss Function
The model is trained using the standard Conditional GAN Minimax objective, implemented with Binary Cross Entropy:
$$\min_G \max_D V(D, G) = \mathbb{E}_{x}\big[\log D(x \mid y)\big] + \mathbb{E}_{z}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]$$
| Loss | Role |
|---|---|
| Discriminator Loss | Penalises the Discriminator for incorrectly classifying real MNIST images as fake, or generated images as real. |
| Generator Loss | Penalises the Generator when the Discriminator successfully identifies its generated images as fake. |
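The two losses above can be written with `BCELoss` roughly as below. This is a minimal sketch of one training step's loss terms; the non-saturating generator loss (maximise `log D(G(z|y)|y)` rather than minimise `log(1 - D(...))`) is the common practical choice and is assumed here.

```python
import torch
from torch import nn

bce = nn.BCELoss()

def discriminator_loss(d_real, d_fake):
    # Real images should score 1, generated images 0.
    real_loss = bce(d_real, torch.ones_like(d_real))
    fake_loss = bce(d_fake, torch.zeros_like(d_fake))
    return real_loss + fake_loss

def generator_loss(d_fake):
    # Non-saturating form: G is rewarded when D scores its samples as real.
    return bce(d_fake, torch.ones_like(d_fake))
```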
## Training
- Dataset: MNIST (60,000 training images)
- Input size: 28×28, normalised to [-1, 1]
- Batch size: 32
- Epochs: 50
- Optimizer: Adam
- Hyperparameter search: Optuna over `learning_rate` ∈ [1e-5, 2e-3], `beta1` ∈ [0.0, 0.9], and `noise_dim` ∈ {50, 100, 128}
- Best weights: snapshot of the epoch with the lowest Generator loss and a stable Discriminator loss
- Best hyperparameters found: `learning_rate` = 0.000112, `beta1` = 0.037, `noise_dim` = 100
## Data Preprocessing (train only)
**Photometric:** Images are converted to PyTorch tensors and normalised with a mean of 0.5 and standard deviation of 0.5, scaling pixel values to the [-1, 1] range to match the Generator's Tanh activation.
## Usage
```python
import torch
from torch import nn
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt

class CGAN_Generator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=10, img_size=28):
        super().__init__()
        self.init_size = img_size // 4  # 7 for a 28x28 output
        self.embedding_dim = 20
        self.label_embedding = nn.Embedding(num_classes, self.embedding_dim)
        # Project noise + label embedding to a 128 x 7 x 7 feature map.
        self.projection = nn.Sequential(
            nn.Linear(noise_dim + self.embedding_dim, 128 * self.init_size * self.init_size)
        )
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14 -> 28
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        label_embed = self.label_embedding(labels)
        merged_input = torch.cat((noise, label_embed), dim=1)
        out = self.projection(merged_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

noise_dim = 100
model = CGAN_Generator(noise_dim=noise_dim)
weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

target_label = torch.tensor([7])
eval_noise = torch.randn(1, noise_dim)
with torch.no_grad():
    generated_img = model(eval_noise, target_label)

# Map the Tanh output from [-1, 1] back to [0, 1] for display.
generated_img = (generated_img + 1) / 2.0
img_numpy = generated_img.squeeze().cpu().numpy()

plt.imshow(img_numpy, cmap="gray")
plt.title(f"Generated Label: {target_label.item()}")
plt.axis("off")
plt.show()
```