Deep Learning February 02, 2025

Step-by-Step Implementation of Generative Adversarial Networks (GANs) in Python

Objective:

We will implement a Generative Adversarial Network (GAN) using TensorFlow and Keras to generate handwritten digits similar to the MNIST dataset.

Step 1: Install Dependencies

Ensure you have the required libraries installed. If not, install them using:

pip install tensorflow numpy matplotlib

Step 2: Import Libraries

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

Explanation:

  • tensorflow and keras are used for building the GAN architecture.
  • Dense is used to define fully connected layers.
  • LeakyReLU helps prevent dying neurons in GANs.
  • BatchNormalization stabilizes training.
  • Reshape and Flatten help convert data between image and vector form.
  • The mnist dataset contains images of handwritten digits (0-9).

Step 3: Load and Preprocess the Data

# Load MNIST dataset
(X_train, _), (_, _) = mnist.load_data()

# Normalize data to [-1, 1] for better GAN performance
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=-1)  # Add a channel dimension

# Define input shape
img_shape = (28, 28, 1)
latent_dim = 100  # Size of random noise vector

Explanation:

  • The MNIST dataset consists of grayscale images of digits (28×28 pixels).
  • Pixel values are normalized to [-1, 1] for better convergence.
  • latent_dim=100 means we will use a 100-dimensional noise vector to generate images.
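
As a quick sanity check, you can confirm the shape and value range after preprocessing (a minimal sketch, assuming the code above has already run):

# Verify the preprocessed training data
print(X_train.shape)                  # (60000, 28, 28, 1)
print(X_train.min(), X_train.max())   # -1.0 1.0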

Step 4: Build the Generator

def build_generator():
    model = Sequential([
        Dense(256, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        BatchNormalization(),
        
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(),
        
        Dense(1024),
        LeakyReLU(alpha=0.2),
        BatchNormalization(),
        
        Dense(28 * 28 * 1, activation='tanh'),
        Reshape((28, 28, 1))
    ])
    return model

generator = build_generator()
generator.summary()

Explanation:

  • The generator takes random noise (100 values) as input.
  • Dense(256) → Dense(512) → Dense(1024): Expands the latent space.
  • LeakyReLU(alpha=0.2): Prevents neurons from becoming inactive.
  • BatchNormalization(): Stabilizes training.
  • Dense(28×28×1) → Reshape((28,28,1)): Converts the output into a 28×28 grayscale image.
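
Before training, you can push random noise through the untrained generator to confirm the output shape (a minimal check, not part of the training loop):

# Generate one image from random noise with the untrained generator
noise = np.random.normal(0, 1, (1, latent_dim))
sample = generator.predict(noise, verbose=0)
print(sample.shape)   # (1, 28, 28, 1); values lie in [-1, 1] because of tanh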

Step 5: Build the Discriminator

def build_discriminator():
    model = Sequential([
        Flatten(input_shape=img_shape),
        
        Dense(512),
        LeakyReLU(alpha=0.2),
        
        Dense(256),
        LeakyReLU(alpha=0.2),
        
        Dense(1, activation='sigmoid')  # Output probability (real or fake)
    ])
    return model

discriminator = build_discriminator()
discriminator.summary()

Explanation:

  • The discriminator takes an image as input and classifies it as real or fake.
  • Flatten(): Converts the 28×28 image into a 1D vector.
  • Dense(512) → Dense(256): Extracts features.
  • LeakyReLU(alpha=0.2): Prevents neuron death.
  • Dense(1, activation='sigmoid'): Outputs a probability between 0 (fake) and 1 (real).
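
Similarly, you can pass a real MNIST image through the still-untrained discriminator to see the raw probability it produces (a minimal check; the score is meaningless until training starts):

# Score one real image with the untrained discriminator
score = discriminator.predict(X_train[:1], verbose=0)
print(score)   # a single probability between 0 and 1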

Step 6: Compile the Discriminator

discriminator.compile(optimizer=keras.optimizers.Adam(0.0002, 0.5), loss='binary_crossentropy', metrics=['accuracy'])

Explanation:

  • The discriminator uses binary cross-entropy loss for real vs. fake classification.
  • The Adam optimizer with a small learning rate (0.0002) and beta_1 = 0.5 is a common, stable choice for GAN training.
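
The positional arguments 0.0002 and 0.5 are the learning rate and the first-moment decay beta_1. The same compile call can be written with explicit keyword arguments (an equivalent form in TensorFlow 2.x):

discriminator.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    loss='binary_crossentropy',
    metrics=['accuracy']
)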

Step 7: Build and Compile the GAN

# Freeze discriminator during GAN training
discriminator.trainable = False

# Build GAN by stacking generator and discriminator
gan = Sequential([generator, discriminator])
gan.compile(optimizer=keras.optimizers.Adam(0.0002, 0.5), loss='binary_crossentropy')

Explanation:

  • The GAN consists of both generator and discriminator.
  • The discriminator is frozen inside the combined model, so training the GAN updates only the generator's weights. Because Keras captures the trainable flag at compile time, the discriminator compiled in Step 6 still learns when it is trained directly in Step 8.
  • The GAN trains the generator to fool the discriminator.
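
You can confirm the freeze by comparing trainable weight counts (a quick sketch, assuming the models above have been built):

# While the discriminator is frozen, the combined model exposes
# only the generator's weights for training.
print(len(gan.trainable_weights))        # same count as the generator's
print(len(generator.trainable_weights))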

Step 8: Train the GAN

import os

def train(epochs=10000, batch_size=128, save_interval=1000):
    real = np.ones((batch_size, 1))  # Real labels (1s)
    fake = np.zeros((batch_size, 1))  # Fake labels (0s)

    for epoch in range(epochs):
        # Train Discriminator
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]  # Select random real images

        noise = np.random.normal(0, 1, (batch_size, latent_dim))  # Generate noise
        fake_imgs = generator.predict(noise, verbose=0)  # Generate fake images (verbose=0 hides per-step progress bars)

        d_loss_real = discriminator.train_on_batch(real_imgs, real)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # Average loss

        # Train Generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real)  # Trick the discriminator

        # Print progress
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, Acc: {100 * d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
            save_generated_images(epoch)

def save_generated_images(epoch):
    noise = np.random.normal(0, 1, (25, latent_dim))
    generated_images = generator.predict(noise, verbose=0)

    fig, axs = plt.subplots(5, 5, figsize=(5, 5))
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(generated_images[i * 5 + j].reshape(28, 28), cmap='gray')
            axs[i, j].axis('off')

    # Save the 5x5 grid to disk (this uses the os import above), then display it
    os.makedirs('gan_images', exist_ok=True)
    fig.savefig(f'gan_images/epoch_{epoch}.png')
    plt.show()
    plt.close(fig)

# Train GAN
train(epochs=10000, batch_size=128, save_interval=1000)

Explanation:

  • Real (1s) and fake (0s) label arrays are created once and reused.
  • At each step, the discriminator trains on one batch of real images and one batch of generated images.
  • The generator trains through the combined GAN model, learning to fool the discriminator.
  • Every save_interval steps, progress is printed and a 5×5 grid of generated images is saved and displayed. (Each "epoch" here is a single training step on one batch, as is common in simple GAN examples.)
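
Once training finishes, the generator can be used on its own to produce new digits and saved for later reuse (a minimal sketch; the file name is just an example):

# Generate and display five new digits with the trained generator
noise = np.random.normal(0, 1, (5, latent_dim))
new_digits = generator.predict(noise, verbose=0)

fig, axs = plt.subplots(1, 5, figsize=(10, 2))
for i in range(5):
    axs[i].imshow(new_digits[i].reshape(28, 28), cmap='gray')
    axs[i].axis('off')
plt.show()

# Save the generator so it can be reloaded later without retraining
generator.save('mnist_generator.h5')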

Key Takeaways

  1. GANs consist of two neural networks:
    • A Generator (creates fake data).
    • A Discriminator (detects fake data).
  2. GANs use adversarial training, where the generator tries to fool the discriminator.
  3. Loss function matters:
    • The discriminator minimizes binary cross-entropy.
    • The generator tries to maximize the discriminator’s mistakes (see the objective below).
  4. Batch Normalization stabilizes GAN training.
  5. Leaky ReLU prevents neuron death.
  6. GANs can generate realistic-looking images after training.
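
For reference, the adversarial setup summarized above corresponds to the minimax objective from the original GAN paper (Goodfellow et al., 2014):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]

In practice, the training loop above uses the non-saturating variant for the generator: labeling generated samples as real means gan.train_on_batch minimizes -log D(G(z)), which gives stronger gradients early in training.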

 

Purnima