Building a Generative Adversarial Network from Scratch


In my previous article, we learned how data augmentation can be done using traditional techniques.

In this blog post, I demonstrate how a Generative Adversarial Network (GAN) can create new images that resemble a given distribution of training images.

A GAN is an architecture that uses two neural networks competing against each other to improve their predictions.

Fig 1: GAN architecture

Generator: The network responsible for generating new data from the training data.

Discriminator: The network that identifies and distinguishes generated (fake) images from original images in the training set. Together, these two networks form a GAN; both learn from their previous predictions, competing with each other for a better outcome.
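For the curious, this competition is usually written as a minimax game. In the original GAN paper (Goodfellow et al., 2014), the discriminator D tries to maximize the value below while the generator G tries to minimize it:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

Here x is a real training image and z is the random noise fed to the generator. The code in this post optimizes a practical variant of this objective using binary cross-entropy losses.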

GAN can be implemented with the following steps:

  1. Importing required libraries
  2. Building a simple generator network
  3. Building a simple discriminator
  4. Building a GAN by stacking the Generator and Discriminator
  5. Plotting the Generated images
  6. Training method for GAN
  7. Loading and processing MNIST data
  8. Training the GAN

Step 1: Importing the required libraries

import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.models import Model, Sequential
from keras.datasets import mnist
from tqdm import tqdm
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Activation, Dense, Dropout, Input

The code blocks in Step 2 and Step 3 below define two different neural networks. The only major difference between the generator and the discriminator network is their inputs and outputs:

  • The generator network takes random noise as input and tries to recreate images from the training set.

  • The discriminator is a binary classifier that tries to distinguish the images generated by the generator network from the actual training images.

Step 2: Building a Simple Generator Network

def build_generator():
    #initializing the neural network
    generator = Sequential()
    #adding an input layer to the network
    generator.add(Dense(units=256, input_dim=100))
    #activating the layer with the LeakyReLU activation function
    generator.add(LeakyReLU(0.2))
    #adding the second layer
    generator.add(Dense(units=512))
    generator.add(LeakyReLU(0.2))
    #adding the third layer
    generator.add(Dense(units=1024))
    generator.add(LeakyReLU(0.2))
    #the output layer with 784 (28x28) nodes and tanh activation
    generator.add(Dense(units=784, activation='tanh'))
    #compiling the generator network with a loss and an optimizer
    generator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(lr=0.0002, beta_1=0.5))
    return generator
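Before wiring things together, a quick sanity check (not part of the original walkthrough) confirms that the generator maps a batch of 100-dimensional noise vectors to flattened 28x28 images:

#quick check: 5 noise vectors of length 100 -> 5 flattened images of 784 pixels
g = build_generator()
noise = np.random.normal(0, 1, size=(5, 100))
print(g.predict(noise).shape)   #expected: (5, 784)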

Step 3: Building a Simple Discriminator Network

def build_discriminator():
    #initializing a neural network
    discriminator = Sequential()
    #adding an input layer to the network
    discriminator.add(Dense(units=1024, input_dim=784))
    #activating the layer with the LeakyReLU activation function
    discriminator.add(LeakyReLU(0.2))
    #adding a dropout layer to reduce overfitting
    discriminator.add(Dropout(0.2))

    #adding a second layer
    discriminator.add(Dense(units=512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    #adding a third layer
    discriminator.add(Dense(units=256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    #adding a fourth layer
    discriminator.add(Dense(units=128))
    discriminator.add(LeakyReLU(0.2))

    #adding the output layer with sigmoid activation
    discriminator.add(Dense(units=1, activation='sigmoid'))

    #compiling the discriminator network with a loss and an optimizer
    discriminator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(lr=0.0002, beta_1=0.5))

    return discriminator
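Similarly, you can check (again, an optional sanity check, not in the original post) that the discriminator maps any batch of 784-pixel vectors to one realness score each:

#quick check: one probability per image, each in (0, 1); reuses g from the check above
d = build_discriminator()
fake = g.predict(np.random.normal(0, 1, size=(5, 100)))
print(d.predict(fake).shape)   #expected: (5, 1)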
    

Step 4: Building the GAN Network

  • This code creates a GAN by stacking the generator and discriminator networks.

  • Setting the trainable parameter of the discriminator network to False freezes its weights while the generator network is trained. This prevents the discriminator network from being updated while the generator learns to generate new images from noise.

  • The input shape of the GAN network is the shape of the noise. The noise is fed to the generator, and the generator's output is fed to the discriminator, which classifies the image as original or generated.

#stacking the generator and discriminator networks to form a GAN

def gan_net(generator, discriminator):
    #setting the trainable parameter of the discriminator to False
    discriminator.trainable = False
    #instantiating a Keras tensor of shape 100 (the noise shape)
    inp = Input(shape=(100,))
    #feeding the noise to the generator and storing its output in X
    X = generator(inp)
    #feeding the output of the generator (X) to the discriminator and storing the result in out
    out = discriminator(X)
    #creating a model that includes all layers required to compute out given inp
    gan = Model(inputs=inp, outputs=out)
    #compiling the GAN network
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan

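With all three builders defined, assembling the GAN takes a couple of lines, and summary() is a handy way to verify the stacked noise -> generator -> discriminator flow (an optional check, assuming the builders above):

generator = build_generator()
discriminator = build_discriminator()
gan = gan_net(generator, discriminator)
gan.summary()   #shows the 100-dim noise input, the generator, and the frozen discriminator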
Step 5: Plotting the generated images

The method below plots the images created by the generator from normally distributed noise input.

#method to plot the generated images

def plot_images(epoch, generator, dim=(10,10), figsize=(10,10)):
    #generate normally distributed noise of shape (100, 100): 100 samples, 100 dimensions each
    noise = np.random.normal(loc=0, scale=1, size=[100, 100])

    #generate images for the input noise
    generated_images = generator.predict(noise)

    #reshape the generated images to 28x28 pixels
    generated_images = generated_images.reshape(100, 28, 28)

    #plot the images in a 10x10 grid
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], cmap='gray', interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
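Note that plot_images receives the epoch number but never uses it. If you also want to keep the grids on disk, one small optional tweak (my addition, with a hypothetical filename pattern) is to save and show the figure at the end of the method:

#optional last lines inside plot_images, after the loop:
plt.savefig('gan_generated_epoch_{}.png'.format(epoch))   #hypothetical filename pattern
plt.show()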

Step 6: Method for training

#training method with the training set, number of epochs and batch size as arguments

def train(X_train, epochs=5, batch_size=128):

    #initializing the GAN
    generator = build_generator()
    discriminator = build_discriminator()
    gan = gan_net(generator, discriminator)

    #number of mini-batches per epoch
    batch_count = X_train.shape[0] // batch_size

    #training the model for the specified number of epochs
    for epoch in range(1, epochs+1):
        print("###### @ Epoch", epoch)

        #the tqdm module generates a status bar for training
        for _ in tqdm(range(batch_count)):

            #random noise of size batch_size x 100
            noise = np.random.normal(0, 1, [batch_size, 100])

            #generating images from the noise
            generated_images = generator.predict(noise)

            #taking random images from the training set
            image_batch = X_train[np.random.randint(low=0, high=X_train.shape[0], size=batch_size)]

            #creating a new training set with real and fake images
            X = np.concatenate([image_batch, generated_images])

            #labels for generated and real data
            y_dis = np.zeros(2*batch_size)
            #labels for the real images
            y_dis[:batch_size] = 1.0

            #training the discriminator with real and generated images
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)

            #labelling the generated images as real images (1) to trick the discriminator
            noise = np.random.normal(0, 1, [batch_size, 100])
            y_gen = np.ones(batch_size)

            #freezing the weights of the discriminator while training the generator
            discriminator.trainable = False

            #training the GAN network
            gan.train_on_batch(noise, y_gen)

        #plotting the images at the first epoch and every 10th epoch
        if epoch == 1 or epoch % 10 == 0:
            plot_images(epoch, generator, dim=(10,10), figsize=(15,15))

Step 7: Loading and Processing MNIST Data

#Unpacking the training data from mnist dataset
(X_train,_),(_,_)=mnist.load_data()

#converting to float type and normalizing the data

X_train=(X_train.astype(np.float32)-127.5)/127.5

#convert the shape of X_train from (60000, 28, 28) to (60000, 784): 784 columns per row

X_train=X_train.reshape(60000,784)
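Since the generator's output layer uses tanh, it produces pixel values in [-1, 1]; the scaling above maps the raw 0-255 MNIST pixels onto exactly that range, so the real and generated images the discriminator sees are comparable. A quick check:

print(X_train.shape)                  #(60000, 784)
print(X_train.min(), X_train.max())   #-1.0 1.0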

Step 8: Training the GAN

train(X_train,epochs=5,batch_size=128)
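train() builds its networks locally, so they are gone once the call returns. If you want to sample from the trained generator afterwards, one option (my suggestion, not part of the original code) is to return the networks from train() and reuse them:

#hypothetical tweak: add "return generator, discriminator, gan" as the last line of train(), then:
generator, discriminator, gan = train(X_train, epochs=5, batch_size=128)
noise = np.random.normal(0, 1, [25, 100])
samples = generator.predict(noise).reshape(25, 28, 28)   #25 new 28x28 images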