Step-by-Step Implementation of Generative Adversarial Networks (GANs) in Python
Objective:
We will implement a Generative Adversarial Network (GAN) using TensorFlow and Keras to generate handwritten digits similar to those in the MNIST dataset.
Step 1: Install Dependencies
Ensure you have the required libraries installed. If not, install them using:
pip install tensorflow numpy matplotlib
Step 2: Import Libraries
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
Explanation:
- tensorflow and keras are used for building the GAN architecture.
- Dense is used to define fully connected layers.
- LeakyReLU helps prevent dying neurons in GANs.
- BatchNormalization stabilizes training.
- Reshape and Flatten help convert data between image and vector form.
- The mnist dataset contains images of handwritten digits (0-9).
Step 3: Load and Preprocess the Data
# Load MNIST dataset
(X_train, _), (_, _) = mnist.load_data()
# Normalize data to [-1, 1] for better GAN performance
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=-1) # Add a channel dimension
# Define input shape
img_shape = (28, 28, 1)
latent_dim = 100 # Size of random noise vector
Explanation:
- The MNIST dataset consists of grayscale images of digits (28×28 pixels).
- Pixel values are normalized to [-1, 1] to match the generator's tanh output range and improve convergence.
- latent_dim=100 means we will use a 100-dimensional noise vector to generate images.
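As a quick sanity check (a minimal sketch, assuming the preprocessing above has been run), you can confirm the pixel range and sample a noise vector of the expected shape:
# Pixel values should now lie in [-1, 1]
print(X_train.shape, X_train.min(), X_train.max())  # (60000, 28, 28, 1) -1.0 1.0
# A single latent vector of the size the generator will consume
sample_noise = np.random.normal(0, 1, (1, latent_dim))
print(sample_noise.shape)  # (1, 100)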
Step 4: Build the Generator
def build_generator():
    model = Sequential([
        Dense(256, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        BatchNormalization(),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        BatchNormalization(),
        Dense(28 * 28 * 1, activation='tanh'),
        Reshape((28, 28, 1))
    ])
    return model
generator = build_generator()
generator.summary()
Explanation:
- The generator takes random noise (100 values) as input.
- Dense(256) → Dense(512) → Dense(1024): Expands the latent space.
- LeakyReLU(alpha=0.2): Prevents neurons from becoming inactive.
- BatchNormalization(): Stabilizes training.
- Dense(28×28×1) → Reshape((28,28,1)): Converts the output into a 28×28 grayscale image.
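As a quick check (a sketch, assuming the generator above has been built), you can push one noise vector through the untrained generator and confirm the output shape and value range:
noise = np.random.normal(0, 1, (1, latent_dim))
generated = generator.predict(noise, verbose=0)
print(generated.shape)                   # (1, 28, 28, 1)
print(generated.min(), generated.max())  # values in (-1, 1) because of the tanh activation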
Step 5: Build the Discriminator
def build_discriminator():
    model = Sequential([
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid')  # Output probability that the image is real
    ])
    return model
discriminator = build_discriminator()
discriminator.summary()
Explanation:
- The discriminator takes an image as input and classifies it as real or fake.
- Flatten(): Converts the 28×28 image into a 1D vector.
- Dense(512) → Dense(256): Extracts features.
- LeakyReLU(alpha=0.2): Prevents neuron death.
- Dense(1, activation='sigmoid'): Outputs a probability between 0 (fake) and 1 (real).
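Before any training, the discriminator's predictions are essentially guesses. A minimal sketch (assuming the models above are in memory) that scores one real and one generated image; the exact numbers vary run to run but hover roughly around 0.5:
real_sample = X_train[:1]
fake_sample = generator.predict(np.random.normal(0, 1, (1, latent_dim)), verbose=0)
print(discriminator.predict(real_sample, verbose=0))  # probability the real image is real
print(discriminator.predict(fake_sample, verbose=0))  # probability the fake image is real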
Step 6: Compile the Discriminator
discriminator.compile(optimizer=keras.optimizers.Adam(0.0002, 0.5), loss='binary_crossentropy', metrics=['accuracy'])
Explanation:
- The discriminator uses binary cross-entropy loss for real vs. fake classification.
- The Adam optimizer with a learning rate of 0.0002 and beta_1 = 0.5 is a common choice for stable GAN training.
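The two positional arguments passed to Adam are the learning rate and beta_1. Written with keyword arguments, the same compile call reads:
discriminator.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    loss='binary_crossentropy',
    metrics=['accuracy']
)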
Step 7: Build and Compile the GAN
# Freeze discriminator during GAN training
discriminator.trainable = False
# Build GAN by stacking generator and discriminator
gan = Sequential([generator, discriminator])
gan.compile(optimizer=keras.optimizers.Adam(0.0002, 0.5), loss='binary_crossentropy')
Explanation:
- The combined GAN model stacks the generator and the discriminator.
- Setting discriminator.trainable = False before compiling the GAN freezes the discriminator only inside the combined model; because the discriminator was already compiled in Step 6, it is still updated when we call discriminator.train_on_batch in Step 8.
- Training the GAN therefore updates only the generator, pushing it to fool the discriminator.
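A quick check of this setup (a sketch, assuming the models above exist): the combined model exposes only the generator's weights as trainable, while the frozen discriminator reports none.
print(len(gan.trainable_weights))            # only the generator's weights
print(len(generator.trainable_weights))      # same count as the line above
print(len(discriminator.trainable_weights))  # 0, because trainable was set to False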
Step 8: Train the GAN
import os
def train(epochs=10000, batch_size=128, save_interval=1000):
    real = np.ones((batch_size, 1))   # Labels for real images (1s)
    fake = np.zeros((batch_size, 1))  # Labels for fake images (0s)
    for epoch in range(epochs):
        # --- Train the discriminator ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]                                   # Select random real images
        noise = np.random.normal(0, 1, (batch_size, latent_dim))   # Sample random noise
        fake_imgs = generator.predict(noise, verbose=0)            # Generate fake images
        d_loss_real = discriminator.train_on_batch(real_imgs, real)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)            # Average loss and accuracy
        # --- Train the generator (through the combined GAN) ---
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real)                   # Generator tries to make the discriminator output "real"
        # Print progress and save sample images periodically
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, Acc: {100 * d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
            save_generated_images(epoch)

def save_generated_images(epoch, output_dir="gan_images"):
    os.makedirs(output_dir, exist_ok=True)
    noise = np.random.normal(0, 1, (25, latent_dim))
    generated_images = generator.predict(noise, verbose=0)
    fig, axs = plt.subplots(5, 5, figsize=(5, 5))
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(generated_images[i * 5 + j].reshape(28, 28), cmap='gray')
            axs[i, j].axis('off')
    fig.savefig(os.path.join(output_dir, f"epoch_{epoch}.png"))
    plt.close(fig)
# Train GAN
train(epochs=10000, batch_size=128, save_interval=1000)
Explanation:
- Real (1s) and fake (0s) label arrays are created as targets for the discriminator.
- Each iteration, the discriminator trains on a batch of real images and a batch of generated images.
- The generator then trains through the combined GAN so that the discriminator labels its fakes as real.
- Every save_interval iterations (1000 by default), progress is printed and a 5×5 grid of generated images is saved. Note that each "epoch" here is a single batch update, not a full pass over the dataset.
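After training finishes, you will usually want to keep the generator for later use. A minimal sketch (the file name is only an example):
# Save the trained generator to disk (example file name)
generator.save("mnist_generator.h5")
# Later, reload it without retraining
loaded_generator = keras.models.load_model("mnist_generator.h5")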
Key Takeaways
- GANs consist of two neural networks:
  - A Generator (creates fake data).
  - A Discriminator (detects fake data).
- GANs use adversarial training, where the generator tries to fool the discriminator.
- Loss function matters:
  - The discriminator minimizes binary cross-entropy.
  - The generator tries to maximize the discriminator's mistakes.
- Batch Normalization stabilizes GAN training.
- Leaky ReLU prevents neuron death.
- GANs can generate realistic-looking images after training.
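Once training is done, producing new digits is just a forward pass through the generator. A short sketch (assuming the trained generator from above is in memory) that rescales the tanh output from [-1, 1] back to [0, 1] for display:
noise = np.random.normal(0, 1, (16, latent_dim))
images = generator.predict(noise, verbose=0)
images = 0.5 * images + 0.5  # rescale from [-1, 1] to [0, 1]
fig, axs = plt.subplots(4, 4, figsize=(4, 4))
for i, ax in enumerate(axs.flat):
    ax.imshow(images[i].reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.show()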