Pokemon GAN — Spectral Norm & Hinge Loss

A Generative Adversarial Network (GAN) trained to synthesize 64x64 pixel-art style Pokemon sprites. This model was trained on the huggan/pokemon dataset using Optuna for hyperparameter optimization.

Model Architecture

This model utilizes a custom DCGAN-style framework with several modern stability improvements:

  • Discriminator: A convolutional neural network that maps a 64x64x3 RGB image down to a single scalar score. Spectral Normalization is applied to every Conv2d layer to enforce Lipschitz continuity and stabilize training, paired with LeakyReLU activations.
  • Generator: Takes a latent noise vector of shape (noise_dim, 1, 1) and projects it to a 4x4 feature map with a ConvTranspose2d layer. To prevent checkerboard artifacts, the upsampling path uses nearest-neighbor Upsample followed by standard Conv2d layers, BatchNorm2d, and ReLU. The final Tanh activation scales pixels to [-1, 1].
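The model card does not include the discriminator's code, only the generator's. As an illustration of the design described above, here is a minimal sketch of what such a spectrally normalized discriminator could look like; the class name, `features_d` parameter, and exact layer widths are assumptions, not the trained model's definition.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm


class PokemonDiscriminator(nn.Module):
    """Hypothetical sketch: spectral_norm wraps every Conv2d, as the card describes."""

    def __init__(self, features_d=64, channels=3):
        super().__init__()
        self.net = nn.Sequential(
            # 64x64 -> 32x32
            spectral_norm(nn.Conv2d(channels, features_d, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16
            spectral_norm(nn.Conv2d(features_d, features_d * 2, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            spectral_norm(nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            spectral_norm(nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1 scalar score (no sigmoid: hinge loss works on raw logits)
            spectral_norm(nn.Conv2d(features_d * 8, 1, 4, 1, 0)),
        )

    def forward(self, x):
        return self.net(x).view(-1)
```

Note the absence of a final sigmoid: hinge loss operates directly on the unbounded scalar output.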

Training Details

  • Dataset: huggan/pokemon
  • Resolution: 64x64 RGB
  • Data Augmentations: Random Affine Translation (10%), Random Horizontal Flip (p=0.5).
  • Loss Function: Hinge Loss
    • Discriminator: ReLU(1.0 - D(real)) + ReLU(1.0 + D(fake))
    • Generator: -D(fake)
  • Optimizer: Adam
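The hinge-loss formulas above translate directly into PyTorch. The following is a small sketch (the function names are illustrative); scores are the discriminator's raw outputs, and each term is averaged over the batch.

```python
import torch
import torch.nn.functional as F


def d_hinge_loss(real_scores, fake_scores):
    # ReLU(1.0 - D(real)) + ReLU(1.0 + D(fake)), averaged over the batch.
    # Zero once real scores exceed +1 and fake scores fall below -1 (the margin).
    return F.relu(1.0 - real_scores).mean() + F.relu(1.0 + fake_scores).mean()


def g_hinge_loss(fake_scores):
    # -D(fake): the generator tries to push its samples' scores up.
    return -fake_scores.mean()
```

For example, a discriminator that already scores a real image at +2 and a fake at -2 is outside the margin on both, so its loss is zero.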

Usage

import torch
from torch import nn
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

class PokemonGenerator(nn.Module):
    def __init__(self, noise_dim=100, features_g=64, channels=3):
        super().__init__()
        
        self.initial_block = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, features_g * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True)
        )
        
        self.upsample_blocks = nn.Sequential(
            # 4x4 -> 8x8
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 8, features_g * 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            
            # 8x8 -> 16x16
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 4, features_g * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            
            # 16x16 -> 32x32
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 2, features_g, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            
            # 32x32 -> 64x64 RGB
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g, channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.initial_block(x)
        return self.upsample_blocks(out)

noise_dim = 128
model = PokemonGenerator(noise_dim=noise_dim)

weights_path = hf_hub_download(repo_id="VioletaR/pokemon-gan", filename="pokemon_generator.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# Note: The generator expects a 4D tensor (batch, noise_dim, 1, 1)
eval_noise = torch.randn(1, noise_dim, 1, 1)

with torch.no_grad():
    generated_img = model(eval_noise)

generated_img = (generated_img + 1) / 2.0  # Scale from [-1, 1] to [0, 1]
img_numpy = generated_img.squeeze().permute(1, 2, 0).cpu().numpy()

plt.imshow(img_numpy)
plt.title("Generated Pokemon")
plt.axis('off')
plt.show()
