A generative adversarial network (GAN) uses two neural networks, called a generator and discriminator, to generate synthetic data that can convincingly mimic real data. For example, GAN architectures can generate fake, photorealistic pictures of animals or people.
PyTorch is a leading open source deep learning framework. While PyTorch does not provide a built-in implementation of a GAN, it provides the primitives you need to build one, including fully connected neural network layers, convolutional layers, loss functions, and optimizers.
This is part of our series of articles on deep learning for computer vision.
In this article, you will learn:
A generative adversarial network (GAN) uses two neural networks, one known as a “discriminator” and the other known as the “generator”, pitting one against the other.
The discriminator is a classifier that analyzes data provided by the generator, and tries to identify if it is fake generated data or real data. Training is performed using real data instances, used as positive examples, and fake data instances from the generator, which are used as negative examples.
The discriminator uses a loss function that penalizes a misclassification of a real data instance as fake, or a fake instance as a real one. With every training cycle, the discriminator updates its neural network weights using backpropagation, based on the discriminator loss function, and gets better and better at identifying the fake data instances.
The generator learns to create fake data with feedback from the discriminator. Its goal is to cause the discriminator to classify its output as real.
To train the generator, you’ll need to tightly integrate it with the discriminator. Training involves taking random input, transforming it into a data instance, feeding it to the discriminator and receiving a classification, and computing the generator loss, which penalizes the generator whenever the discriminator correctly identifies its output as fake.
Experiments show that the random noise initially fed to the generator can have any distribution—to make things easy, you can use a uniform distribution.
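As a minimal sketch of what sampling that random input might look like in PyTorch, the helper below draws either uniform or normal noise. The name `sample_noise` and its parameters are illustrative assumptions, not part of any library.

```python
import torch

def sample_noise(batch_size: int, noise_dim: int, kind: str = "uniform") -> torch.Tensor:
    """Return a (batch_size, noise_dim) tensor of random generator input (hypothetical helper)."""
    if kind == "uniform":
        return torch.rand(batch_size, noise_dim)   # uniform values in [0, 1)
    if kind == "normal":
        return torch.randn(batch_size, noise_dim)  # standard normal values
    raise ValueError(f"unknown noise kind: {kind}")

noise = sample_noise(16, 7)  # 16 noise vectors of length 7
```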
Using the Discriminator to Train the Generator
The process used to train a regular neural network is to modify weights in the backpropagation process, in an attempt to minimize the loss function. However, in a GAN, the generator feeds into the discriminator, and the generator loss measures its failure to fool the discriminator.
The generator loss therefore has to be backpropagated through both networks: gradients start at the discriminator's output and flow back through the discriminator into the generator. It is important to keep the discriminator's weights fixed during generator training, so that only the generator is updated.
This will ensure that with every training cycle, the generator will get a bit better at creating outputs that will fool the current generation of the discriminator.
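One way this could look in PyTorch is sketched below. The layer sizes, learning rate, and variable names are assumptions made for illustration. The discriminator stays static simply because the optimizer that is stepped only holds the generator's parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
noise_dim, data_dim, batch_size = 7, 7, 16  # hypothetical sizes

generator = nn.Sequential(nn.Linear(noise_dim, data_dim), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(data_dim, 1), nn.Sigmoid())

# Only the generator's parameters are registered with this optimizer,
# so stepping it cannot change the discriminator.
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()

weights_before = discriminator[0].weight.detach().clone()

noise = torch.rand(batch_size, noise_dim)
fake = generator(noise)
prediction = discriminator(fake)            # forward pass through both networks
target_real = torch.ones(batch_size, 1)     # the generator wants a "real" verdict
generator_loss = loss_fn(prediction, target_real)

generator_optimizer.zero_grad()
generator_loss.backward()                   # gradients flow back through the discriminator
generator_optimizer.step()                  # only generator weights are updated

discriminator_unchanged = torch.equal(weights_before, discriminator[0].weight)
generator_received_gradient = generator[0].weight.grad is not None
```

Note that `backward()` still accumulates gradients on the discriminator's parameters, so they should be zeroed before the discriminator's own update.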
This brief tutorial is based on the GAN tutorial and code by Nicolas Bertagnolli. We will create a simple generator and discriminator that can generate numbers with 7 binary digits. The real data in this example is valid, even numbers, such as “1,110,010”.
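To keep things simple, we’ll build a generator that maps binary digits into seven positions (creating an output like “0100111”). It is sufficient to use one linear layer with a sigmoid activation function.
import torch.nn as nn
class Generator(nn.Module):
    def __init__(self, input_length: int):
        super().__init__()  # register the layers with nn.Module
        self.dense_layer = nn.Linear(int(input_length), int(input_length))
        self.activation = nn.Sigmoid()
    def forward(self, x):
        return self.activation(self.dense_layer(x))  # one value in (0, 1) per digit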
The discriminator needs to accept the 7-digit input and decide if it belongs to the real data distribution—a valid, even number. We’ll use a logistic regression with a sigmoid activation.
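A minimal sketch of such a discriminator, assuming a single linear layer that maps the 7 digits to one probability (the attribute names are illustrative):

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, input_length: int):
        super().__init__()
        # Logistic regression: one linear unit followed by a sigmoid.
        self.dense = nn.Linear(int(input_length), 1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        # Output is the estimated probability that x is real data.
        return self.activation(self.dense(x))

discriminator = Discriminator(7)
sample = torch.tensor([[1., 1., 1., 0., 0., 1., 0.]])  # the digits of "1110010"
probability_real = discriminator(sample)
```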
We’ll start training by passing two batches to the model:
The training function looks like this:
Defining optimizers and loss functions:
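As a rough sketch, one optimizer per network and a binary cross-entropy loss could be set up as follows. The choice of Adam, the learning rate of 0.001, and the inline `nn.Sequential` stand-ins for the two models are assumptions for illustration.

```python
import torch
import torch.nn as nn

# Stand-ins for the generator and discriminator (hypothetical, 7-digit data).
generator = nn.Sequential(nn.Linear(7, 7), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(7, 1), nn.Sigmoid())

# One optimizer per network, so each can be updated independently.
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)

# Binary cross-entropy matches the discriminator's sigmoid output.
loss = nn.BCELoss()
```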
Now, for each training step, we zero the gradients and create noisy data and true data labels:
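The data for one step might be produced along these lines. The helper names `int_to_bits` and `make_even_batch` are hypothetical, and drawing the noise as random 0/1 digits is one possible choice among many.

```python
from typing import List, Tuple
import torch

def int_to_bits(number: int, width: int) -> List[float]:
    """Encode an integer as a fixed-width list of binary digits (hypothetical helper)."""
    return [float(bit) for bit in format(number, f"0{width}b")]

def make_even_batch(batch_size: int, width: int = 7) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample even integers below 2**width and return (data, labels)."""
    halves = torch.randint(0, 2 ** (width - 1), (batch_size,))
    numbers = (halves * 2).tolist()              # doubling guarantees an even number
    data = torch.tensor([int_to_bits(n, width) for n in numbers])
    labels = torch.ones(batch_size, 1)           # real data is labeled "one"
    return data, labels

batch_size, width = 16, 7
noise = torch.randint(0, 2, (batch_size, width)).float()  # random binary input
data, labels = make_even_batch(batch_size, width)
```

Every row of `data` ends in 0, since the last binary digit of an even number is always zero.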
We now train the generator. This involves creating random noise, generating “fake” data, getting the discriminator to predict the label of the “fake” data, and calculating the generator loss from that prediction, using labels as if the data were real.
Backpropagation is performed just for the generator, keeping the discriminator static.
We now update the weights to train the discriminator. This involves passing a batch of true data with “one” labels, then passing the data from the generator, detached from the computation graph so that no gradients reach the generator, with “zero” labels.
Finally, we average the loss functions from two stages, and backpropagate using only the discriminator.
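Putting the two stages together, a complete training loop could be sketched as follows. This is an illustrative reconstruction under stated assumptions (inline `nn.Sequential` models, Adam with a learning rate of 0.001, random binary noise, and a hypothetical `make_even_batch` helper), not the original tutorial code.

```python
import torch
import torch.nn as nn

def make_even_batch(batch_size: int, width: int = 7) -> torch.Tensor:
    """Hypothetical helper: a batch of even numbers encoded as binary digits."""
    evens = (torch.randint(0, 2 ** (width - 1), (batch_size,)) * 2).tolist()
    return torch.tensor([[float(b) for b in format(n, f"0{width}b")] for n in evens])

def train(steps: int = 200, batch_size: int = 16, width: int = 7):
    generator = nn.Sequential(nn.Linear(width, width), nn.Sigmoid())
    discriminator = nn.Sequential(nn.Linear(width, 1), nn.Sigmoid())
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)
    loss_fn = nn.BCELoss()
    ones = torch.ones(batch_size, 1)
    zeros = torch.zeros(batch_size, 1)

    for _ in range(steps):
        # Stage 1: generator update, scored against "real" labels.
        generator_optimizer.zero_grad()
        noise = torch.randint(0, 2, (batch_size, width)).float()
        fake = generator(noise)
        generator_loss = loss_fn(discriminator(fake), ones)
        generator_loss.backward()
        generator_optimizer.step()

        # Stage 2: discriminator update on real and detached fake data.
        discriminator_optimizer.zero_grad()  # clears gradients left over from stage 1
        real_loss = loss_fn(discriminator(make_even_batch(batch_size, width)), ones)
        fake_loss = loss_fn(discriminator(fake.detach()), zeros)
        discriminator_loss = (real_loss + fake_loss) / 2
        discriminator_loss.backward()
        discriminator_optimizer.step()

    return generator, discriminator, generator_loss.item(), discriminator_loss.item()

generator, discriminator, g_loss, d_loss = train(steps=50)
samples = generator(torch.randint(0, 2, (4, 7)).float()).round()  # threshold to 0/1 digits
```

Calling `detach()` on the fake batch is what keeps the discriminator update from pushing gradients into the generator.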
That’s it! After training, the generator should produce 7-digit binary numbers that the discriminator accepts as valid, even numbers.
If you want to go beyond this “toy” implementation, and build a full-scale DCGAN with convolutional and convolutional-transpose layers, which can take in images and generate fake, photorealistic images, see the detailed DCGAN tutorial in the PyTorch documentation.
GAN architectures attempt to replicate probability distributions. They use loss functions to measure how far the data distribution generated by the GAN is from the actual distribution the GAN is attempting to mimic.
A GAN typically has two loss functions: one for generator training and one for discriminator training.
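For reference, the classic minimax formulation from the original GAN paper expresses both objectives through a single value function. Here $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the generator's output for noise $z$, and $p_{\text{data}}$ and $p_z$ denote the real-data and noise distributions. Other GAN variants use different losses, so treat this as one common choice rather than the only one.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator is trained to maximize $V$, while the generator is trained to minimize it. Binary cross-entropy with "one" labels for real data and "zero" labels for generated data corresponds to the discriminator's side of this objective.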
Conditional GANs train on a labeled dataset and let you assign a label to each generated instance. For example, an unconditional GAN trained on the MNIST dataset generates random digits, but a conditional MNIST GAN allows you to specify which digit the GAN will generate.
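A common way to condition a generator is to concatenate a one-hot encoding of the requested label with the noise vector. The sketch below illustrates that idea only. The class name, layer sizes, and the untrained network are assumptions, so the output here is not a recognizable digit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditionalGenerator(nn.Module):
    """Hypothetical generator that receives noise plus a class label."""

    def __init__(self, noise_dim: int = 64, num_classes: int = 10, out_dim: int = 28 * 28):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
            nn.Tanh(),  # outputs in [-1, 1], a common range for image pixels
        )

    def forward(self, noise: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        one_hot = F.one_hot(labels, self.num_classes).float()
        return self.net(torch.cat([noise, one_hot], dim=1))

generator = ConditionalGenerator()
labels = torch.full((4,), 7, dtype=torch.long)    # request four samples of the digit 7
images = generator(torch.randn(4, 64), labels)    # shape: (4, 784)
```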
In a progressive GAN, the first layer of the generator produces a very low resolution image, and the subsequent layers add detail. This technique trains faster than a comparable non-progressive GAN and can produce high-resolution images.
GAN is a computationally intensive neural network architecture. Run:AI automates resource management and workload orchestration for machine learning infrastructure. With Run:AI, you can automatically run as many compute intensive experiments as needed in PyTorch and other deep learning frameworks.
Here are some of the capabilities you gain when using Run:AI:
Run:AI simplifies machine learning infrastructure pipelines, helping data scientists accelerate their productivity and the quality of their models.
Learn more about the Run:AI GPU virtualization platform.