Building a Generative Adversarial Network from Scratch
In my previous article, we learned how data augmentation can be done using traditional techniques.
In this blog post, I demonstrate how to create new images that resemble a given distribution of images using a Generative Adversarial Network (GAN).
A GAN is an architecture that uses multiple neural networks competing against each other to make predictions.
Generator: The network responsible for generating new data that resembles the training data.
Discriminator: The network that distinguishes a generated (fake) image from an original image of the training set.
Combined together, these two networks form a GAN. Both networks learn from their previous predictions, competing with each other for a better outcome.
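To make the competition concrete, here is a minimal NumPy sketch of the two binary cross-entropy objectives that the networks in this post are trained with. The probabilities are made-up illustrative values, and the helper name `bce` is hypothetical (it is not a Keras function).

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy between labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred)
                          + (1 - y_true) * np.log(1 - y_pred)))

# Made-up discriminator outputs: the probability that an image is real.
d_on_real = np.array([0.9, 0.8, 0.95])  # scores for real images
d_on_fake = np.array([0.2, 0.1, 0.3])   # scores for generated images

# Discriminator objective: real images are labelled 1, generated images 0.
d_loss = bce(np.concatenate([np.ones(3), np.zeros(3)]),
             np.concatenate([d_on_real, d_on_fake]))

# Generator objective: the same generated images, but labelled 1,
# so the loss is small only when the discriminator is fooled.
g_loss = bce(np.ones(3), d_on_fake)

print("discriminator loss:", d_loss)
print("generator loss:", g_loss)
```

With these values the discriminator is doing well, so its loss is low while the generator's loss is high. As the generator improves, the balance shifts the other way.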
A GAN can be implemented with the following steps:
- Importing required libraries
- Building a simple generator network
- Building a simple discriminator
- Building a GAN by stacking the Generator and Discriminator
- Plotting the Generated images
- Training method for GAN
- Loading and processing MNIST data
- Training the GAN
Step 1: Importing the modules
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.models import Model, Sequential
from keras.datasets import mnist
from tqdm import tqdm
#LeakyReLU is imported from keras.layers, which works in both older and newer Keras releases
from keras.layers import Activation, Dense, Dropout, Input, LeakyReLU
The code blocks in Step 2 and Step 3 below define two different neural networks. The main difference between the generator and discriminator networks is their inputs and outputs:
- The generator network takes random noise as input and tries to recreate the images from the training set.
- The discriminator is a binary classifier that tries to distinguish the images generated by the generator network from the actual training images.
Step 2: Building a simple Generator Network
def build_generator():
    #initializing the neural network
    generator = Sequential()
    #adding an input layer to the network (100-dimensional noise vector)
    generator.add(Dense(units=256, input_dim=100))
    #activating the layer with LeakyReLU activation function
    generator.add(LeakyReLU(0.2))
    #adding the second layer
    generator.add(Dense(units=512))
    generator.add(LeakyReLU(0.2))
    #adding the third layer
    generator.add(Dense(units=1024))
    generator.add(LeakyReLU(0.2))
    #the output layer with 784 (28x28) nodes; tanh keeps outputs in [-1, 1]
    generator.add(Dense(units=784, activation='tanh'))
    #compiling the generator network with loss and optimizer functions
    generator.compile(loss='binary_crossentropy',
                      optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5))
    return generator
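As a quick sanity check on the size of this network, the number of trainable parameters in a stack of fully connected layers can be worked out by hand. The sketch below assumes the layer widths used in `build_generator` (100, 256, 512, 1024, 784); the helper name `dense_params` is hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Layer widths of the generator: noise input, three hidden layers, image output.
layer_sizes = [100, 256, 512, 1024, 784]

per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)

print(per_layer)  # [25856, 131584, 525312, 803600]
print(total)      # 1486352
```

The LeakyReLU activations add no trainable parameters, so the total should match what `generator.summary()` reports for the Dense layers.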
Step 3: Building a Simple Discriminator Network
def build_discriminator():
    #initializing a neural network
    discriminator = Sequential()
    #adding an input layer to the network (flattened 28x28 image)
    discriminator.add(Dense(units=1024, input_dim=784))
    #activating the layer with LeakyReLU activation function
    discriminator.add(LeakyReLU(0.2))
    #adding a dropout layer to reduce overfitting
    discriminator.add(Dropout(0.2))
    #adding a second layer
    discriminator.add(Dense(units=512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    #adding a third layer
    discriminator.add(Dense(units=256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    #adding a fourth layer
    discriminator.add(Dense(units=128))
    discriminator.add(LeakyReLU(0.2))
    #adding the output layer with sigmoid activation (probability that the image is real)
    discriminator.add(Dense(units=1, activation='sigmoid'))
    #compiling the discriminator network with loss and optimizer functions
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5))
    return discriminator
Step 4: Building the GAN Network
- This code creates a GAN by stacking the generator and discriminator networks.
- Setting the trainable parameter of the discriminator to False freezes the discriminator's weights inside the stacked model. When the stacked model is trained, only the generator's weights are updated, so the generator learns to fool a fixed discriminator.
- The input shape of the GAN network is the shape of the noise. The noise is fed to the generator, and the generator's output is fed to the discriminator, which classifies the image as original or generated.
#stacking the generator and discriminator networks to form a GAN.
def gan_net(generator, discriminator):
    #setting the trainable parameter of the discriminator to False
    discriminator.trainable = False
    #instantiates a Keras tensor of shape 100 (noise shape)
    inp = Input(shape=(100,))
    #feeds the noise to the generator and stores the generated image in X
    X = generator(inp)
    #feeds the output from the generator (X) to the discriminator and stores the result in out
    out = discriminator(X)
    #creates a model that includes all layers required to compute out from inp
    gan = Model(inputs=inp, outputs=out)
    #compiling the GAN network
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan
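The effect of freezing the discriminator can be illustrated without Keras. The sketch below is a toy model under strong simplifying assumptions: the "generator" and "discriminator" are each a single scalar weight (`w_g` and `w_d`, both hypothetical names with arbitrary starting values), and one manual gradient step is taken on the generator loss. It is not what Keras does internally, only an illustration of which weights move.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

# Toy scalar weights (illustrative values, not taken from the post).
w_g, w_d = 0.5, 1.5
z = np.array([1.0, 2.0, -1.0, 0.5])  # fixed "noise" batch

def generator_loss(w_g, w_d):
    """Cross-entropy of D(G(z)) against the label 1 (real)."""
    d_out = sigmoid(w_d * (w_g * z))
    return float(-np.mean(np.log(d_out)))

loss_before = generator_loss(w_g, w_d)
w_d_before = w_d

# The gradient flows *through* the discriminator (w_d appears in it),
# but only the generator weight w_g is updated: w_d stays frozen.
d_out = sigmoid(w_d * w_g * z)
grad_w_g = float(np.mean(-(1.0 - d_out) * w_d * z))
w_g = w_g - 0.1 * grad_w_g

loss_after = generator_loss(w_g, w_d)
print(loss_before, loss_after, w_d == w_d_before)
```

The generator loss goes down after the step while `w_d` is unchanged, which is the behaviour `discriminator.trainable = False` is meant to give the stacked model.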
Step 5: Plotting the generated images
The method below plots the images created by the generator from normally distributed noise input.
#method to plot the images
def plot_images(epoch, generator, dim=(10,10), figsize=(10,10)):
    #generate normally distributed noise of shape (100x100): 100 noise vectors of length 100
    noise = np.random.normal(loc=0, scale=1, size=[100,100])
    #generate an image for each noise vector
    generated_images = generator.predict(noise)
    #reshape the generated images from (100, 784) to (100, 28, 28)
    generated_images = generated_images.reshape(100,28,28)
    #create the figure
    plt.figure(figsize=figsize)
    #plot each generated image in its own subplot
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], cmap='gray', interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
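Drawing 100 separate subplots can be slow. As an optional alternative, the images can be tiled into one large array and shown with a single `plt.imshow` call. The sketch below is an assumption-laden illustration: `tile_images` is a hypothetical helper, and random values stand in for the generator output.

```python
import numpy as np

def tile_images(images, rows, cols):
    """Arrange (rows*cols, h, w) images into one (rows*h, cols*w) array."""
    n, h, w = images.shape
    assert n == rows * cols, "number of images must equal rows * cols"
    grid = images.reshape(rows, cols, h, w)  # (rows, cols, h, w)
    grid = grid.transpose(0, 2, 1, 3)        # (rows, h, cols, w)
    return grid.reshape(rows * h, cols * w)

# Random stand-in for generator.predict(noise).reshape(100, 28, 28).
fake = np.random.uniform(-1, 1, size=(100, 28, 28))
sheet = tile_images(fake, 10, 10)
print(sheet.shape)  # (280, 280)
```

The resulting `sheet` could then be displayed with `plt.imshow(sheet, cmap='gray')`, with image `i` appearing at row `i // 10`, column `i % 10` of the grid.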
Step 6: Method for training
#Training method with training set, default epochs and default batch_size as arguments.
def train(X_train, epochs=5, batch_size=128):
    #initializing the GAN
    generator = build_generator()
    discriminator = build_discriminator()
    gan = gan_net(generator, discriminator)
    #training the model for the specified number of epochs
    for epoch in range(1, epochs+1):
        print("###### @ Epoch", epoch)
        #tqdm module helps to generate a status bar for training
        for _ in tqdm(range(batch_size)):
            #random noise with size batch_size x 100
            noise = np.random.normal(0, 1, [batch_size,100])
            #generating images from noise
            generated_images = generator.predict(noise)
            #taking random images from the training set
            image_batch = X_train[np.random.randint(low=0, high=X_train.shape[0], size=batch_size)]
            #creating a new training batch with real and fake images
            X = np.concatenate([image_batch, generated_images])
            #labels for generated data (0) and real data
            y_dis = np.zeros(2*batch_size)
            #label for real images (1)
            y_dis[:batch_size] = 1.0
            #training the discriminator with real and generated images
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)
            #labelling the generated images as real images (1) to trick the discriminator
            noise = np.random.normal(0, 1, [batch_size,100])
            y_gen = np.ones(batch_size)
            #freezing the weights of the discriminator while training the generator
            discriminator.trainable = False
            #training the GAN network (only the generator's weights are updated)
            gan.train_on_batch(noise, y_gen)
        #plotting the images for the first epoch and every 10th epoch
        if epoch == 1 or epoch % 10 == 0:
            plot_images(epoch, generator, dim=(10,10), figsize=(15,15))
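One detail of the loop above is easy to miss: the inner loop runs `batch_size` update steps per epoch, so an "epoch" here is not a full pass over the dataset. The short calculation below compares this with the number of steps that would roughly cover the 60,000 MNIST training images once. Treating an epoch as a full pass is an assumption made only for this comparison, not something the training method does.

```python
n_samples = 60000   # size of the MNIST training set
batch_size = 128

# The training loop runs `batch_size` steps per epoch,
# each drawing `batch_size` random images (with replacement).
steps_in_loop = batch_size
images_drawn = steps_in_loop * batch_size

# Steps needed for the number of images drawn to roughly match the dataset size.
steps_for_full_pass = n_samples // batch_size

print(images_drawn)         # 16384
print(steps_for_full_pass)  # 468
```

With the default settings, each epoch therefore draws about a quarter as many real images as the training set contains. Replacing `range(batch_size)` with `range(X_train.shape[0] // batch_size)` would be one way to lengthen each epoch, at the cost of longer training time.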
Step 7: Loading and Processing MNIST Data
#unpacking the training data from the MNIST dataset
(X_train,_),(_,_)=mnist.load_data()
#converting to float type and normalizing the pixel values from [0, 255] to [-1, 1]
X_train=(X_train.astype(np.float32)-127.5)/127.5
#convert shape of X_train from (60000,28,28) to (60000,784): 784 columns per row
X_train=X_train.reshape(60000,784)
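The normalization maps pixel values from [0, 255] to [-1, 1], which matches the range of the generator's tanh output layer. The small sketch below checks the mapping on a few sample pixel values and shows the inverse transform, which is useful when saving generated images as 8-bit pixels. The sample values are chosen only for illustration.

```python
import numpy as np

# A few sample 8-bit pixel values, including both extremes.
pixels = np.array([0, 127, 128, 255], dtype=np.uint8)

# Same normalization as applied to X_train.
scaled = (pixels.astype(np.float32) - 127.5) / 127.5

# Inverse transform: back from [-1, 1] to [0, 255].
restored = np.rint(scaled * 127.5 + 127.5).astype(np.uint8)

print(scaled.min(), scaled.max())  # -1.0 1.0
print(restored)                    # [  0 127 128 255]
```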
Step 8: Training the GAN
train(X_train,epochs=5,batch_size=128)