# Pokemon GAN — Spectral Norm & Hinge Loss
A Generative Adversarial Network (GAN) trained to synthesize 64x64 pixel-art style Pokemon sprites. This model was trained on the huggan/pokemon dataset using Optuna for hyperparameter optimization.
## Model Architecture
This model utilizes a custom DCGAN-style framework with several modern stability improvements:
| Component | Details |
|---|---|
| Discriminator | A convolutional neural network that processes 64x64x3 RGB images down to a 1x1 scalar. It applies Spectral Normalization to every Conv2d layer to enforce Lipschitz continuity and stabilize training, paired with LeakyReLU activations. |
| Generator | Takes a latent noise vector of shape (noise_dim, 1, 1) and projects it to a 4x4 feature map with a ConvTranspose2d layer. To prevent checkerboard artifacts, the upsampling path uses nearest-neighbor Upsample followed by standard Conv2d layers, BatchNorm2d, and ReLU. The final output uses a Tanh activation to scale pixels to [-1, 1]. |
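The discriminator described above can be sketched as follows. This is a minimal illustration, not the released architecture: the exact layer widths (`features_d`) and kernel configuration are assumptions.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm

class PokemonDiscriminator(nn.Module):
    """Sketch of a DCGAN-style critic with spectral norm on every Conv2d."""
    def __init__(self, features_d=64, channels=3):
        super().__init__()
        self.net = nn.Sequential(
            # 64x64 -> 32x32; spectral norm constrains each layer's
            # Lipschitz constant, which stabilizes adversarial training
            spectral_norm(nn.Conv2d(channels, features_d, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16
            spectral_norm(nn.Conv2d(features_d, features_d * 2, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            spectral_norm(nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            spectral_norm(nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1 raw score (no sigmoid: hinge loss uses unbounded logits)
            spectral_norm(nn.Conv2d(features_d * 8, 1, 4, 1, 0)),
        )

    def forward(self, x):
        return self.net(x).view(-1)

D = PokemonDiscriminator()
scores = D(torch.randn(2, 3, 64, 64))
print(scores.shape)  # torch.Size([2])
```

Note that the final layer emits a raw scalar score rather than a probability, as required by the hinge loss used for training.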
## Training Details
- Dataset: huggan/pokemon
- Resolution: 64x64 RGB
- Data Augmentations: Random Affine Translation (10%), Random Horizontal Flip (p=0.5).
- Loss Function: Hinge Loss
  - Discriminator: `ReLU(1.0 - D(real)) + ReLU(1.0 + D(fake))`
  - Generator: `-D(fake)`
- Optimizer: Adam
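The hinge losses listed above are straightforward to compute from raw discriminator scores. A short sketch with toy score values:

```python
import torch
import torch.nn.functional as F

# Toy discriminator outputs (raw scores, not probabilities)
d_real = torch.tensor([0.8, 1.5])   # D(real)
d_fake = torch.tensor([-0.2, 0.5])  # D(fake)

# Discriminator hinge loss: ReLU(1.0 - D(real)) + ReLU(1.0 + D(fake))
loss_d = F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()

# Generator hinge loss: -D(fake)
loss_g = -d_fake.mean()

print(loss_d.item())  # 1.25
print(loss_g.item())  # -0.15
```

The discriminator only accumulates gradient while real scores are below +1 or fake scores are above -1, which keeps it from overpowering the generator.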
## Usage
```python
import torch
from torch import nn
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

class PokemonGenerator(nn.Module):
    def __init__(self, noise_dim=100, features_g=64, channels=3):
        super().__init__()
        # Project the (noise_dim, 1, 1) latent to a 4x4 feature map
        self.initial_block = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, features_g * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True)
        )
        self.upsample_blocks = nn.Sequential(
            # 4x4 -> 8x8
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 8, features_g * 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 4, features_g * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 2, features_g, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            # 32x32 -> 64x64 RGB
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g, channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.initial_block(x)
        return self.upsample_blocks(out)

noise_dim = 128
model = PokemonGenerator(noise_dim=noise_dim)
weights_path = hf_hub_download(repo_id="VioletaR/pokemon-gan", filename="pokemon_generator.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# Note: the generator expects a 4D tensor (Batch, Channels, Height, Width)
eval_noise = torch.randn(1, noise_dim, 1, 1)
with torch.no_grad():
    generated_img = model(eval_noise)

generated_img = (generated_img + 1) / 2.0  # Scale from [-1, 1] to [0, 1]
img_numpy = generated_img.squeeze().permute(1, 2, 0).cpu().numpy()

plt.imshow(img_numpy)
plt.title("Generated Pokemon")
plt.axis('off')
plt.show()
```