Generative Adversarial Networks are a powerful class of neural networks with remarkable applications. They essentially consist of a system of two neural networks — the Generator and the Discriminator — dueling each other.
GANs in action. (Source)
Given a set of target samples, the Generator tries to produce samples that can fool the Discriminator into believing they are real. The Discriminator tries to resolve real (target) samples from fake (generated) samples. Using this iterative training approach, we eventually end up with a Generator that is really good at generating samples similar to the target samples.
GANs have a plethora of applications, as they can learn to mimic data distributions of almost any kind. Popularly, GANs are used for removing artefacts, super resolution, pose transfer, and literally any kind of image translation, as shown below:
Image translation using GANs. (Source)
However, they are excruciatingly difficult to work with, owing to their unstable training. Many researchers have proposed solutions that mitigate some of the problems involved in training GANs, but research in this area has evolved so fast that it has become hard to keep track of the interesting ideas. This blog lists some popular techniques that are commonly used to make GAN training stable.
Drawbacks of using GANs — An Overview
GANs are difficult to work with for a bunch of reasons. Some of them are listed below in this section.
1. Mode collapse
Natural data distributions are highly complex and multimodal. That is, the data distribution has many “peaks” or “modes”. Each mode represents a concentration of similar data samples and is distinct from the other modes.
During mode collapse, the generator produces samples that belong to a limited set of modes. This happens when the generator believes that it can fool the discriminator by locking on to a single mode. That is, the generator produces samples exclusively from this mode.
The image at the top represents the output of a GAN without mode collapse. The image at the bottom represents the output of a GAN with mode collapse. (Source)
The discriminator eventually figures out that samples from this mode are fake. As a result, the generator simply locks on to another mode. This cycle repeats indefinitely, and this essentially limits the diversity of the generated samples. For a more detailed explanation, you can check out this blog.
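On a toy problem, one rough way to spot mode collapse is to count how many known modes the generated samples land near. The sketch below assumes a 1-D target whose mode centres are known in advance; `mode_coverage`, the `radius` tolerance, and the sample values are hypothetical and chosen only for illustration.

```python
def mode_coverage(samples, mode_centres, radius=0.5):
    """Return the fraction of modes with at least one sample within `radius`."""
    covered = set()
    for s in samples:
        for i, centre in enumerate(mode_centres):
            if abs(s - centre) <= radius:
                covered.add(i)
    return len(covered) / len(mode_centres)


# Hypothetical target with three modes.
modes = [-4.0, 0.0, 4.0]

# Samples from a generator that visits every mode.
healthy = [-4.1, -3.9, 0.2, -0.1, 4.3, 3.8]
# Samples from a generator that has locked on to the middle mode.
collapsed = [0.1, -0.2, 0.05, 0.0, 0.15, -0.1]

coverage_healthy = mode_coverage(healthy, modes)
coverage_collapsed = mode_coverage(collapsed, modes)
```

For real image data the modes are not known, so this kind of check only works on synthetic benchmarks such as mixtures of Gaussians.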
2. Convergence
A common question in GAN training is “when do we stop training?”. Since the Generator loss improves when the Discriminator loss degrades (and vice versa), we cannot judge convergence based on the value of the loss function. This is illustrated by the image below:
Plot of a typical GAN loss function. Note how convergence cannot be interpreted from this plot. (Source)
3. Quality
As with the previous problem, it is difficult to quantitatively tell when the generator produces high quality samples. Additional perceptual regularization added to the loss function can help mitigate the situation to some extent.
4. Metrics
The GAN objective function explains how well the Generator or the Discriminator is performing with respect to its opponent. It does not, however, represent the quality or the diversity of the output. Hence, we need distinct metrics that can measure these.
Before we dive deep into techniques that can aid performance, let us review some terminology. This will simplify explanations of the techniques presented in the next section.
1. Infimum and Supremum
Put simply, the infimum is the greatest lower bound of a set, and the supremum is the least upper bound of a set. They differ from the minimum and maximum in that the infimum and supremum need not belong to the set. For example, the open interval (0, 1) has infimum 0 and supremum 1, but no minimum or maximum.
2. Divergence Measures
Divergence measures represent the distance between two distributions. Conventional GANs essentially minimize the Jensen-Shannon divergence between the real data distribution and the generated data distribution. GAN loss functions can be modified to minimize other divergence measures such as the Kullback-Leibler divergence or the Total Variation distance. Notably, the Wasserstein GAN minimizes the Earth Mover distance.
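To make these measures concrete, here is a minimal sketch that computes the Kullback-Leibler divergence, the Jensen-Shannon divergence, and the Total Variation distance for two small discrete distributions. The function names and the two example histograms are my own; the KL helper assumes `q` is non-zero wherever `p` is non-zero.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0.

    Assumes q_i > 0 wherever p_i > 0 (otherwise the divergence is infinite).
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture m."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def total_variation(p, q):
    """Total Variation distance: half the L1 distance between p and q."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


# Hypothetical "real" and "generated" histograms over four bins.
real = [0.1, 0.4, 0.4, 0.1]
fake = [0.4, 0.1, 0.1, 0.4]

kl_rf = kl_divergence(real, fake)
js_rf = js_divergence(real, fake)
tv_rf = total_variation(real, fake)
```

Note that KL is asymmetric in its arguments, while JS and Total Variation are symmetric, and JS is bounded above by log 2.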
3. Kantorovich-Rubinstein Duality
Some divergence measures are intractable to optimize in their naive form. However, their dual form (replacing infimum with supremum or vice-versa) may be tractable to optimize. The duality principle lays a framework for transforming one form to another. For a very detailed explanation about the same, you can check out this blog post.
4. Lipschitz continuity
A Lipschitz continuous function is limited in how fast it can change. For a function to be Lipschitz continuous, the absolute value of the slope of the function’s graph (for any pair of points) cannot be more than a real value K. Such functions are also known as K-Lipschitz continuous.
Lipschitz continuity is desired in GANs as it bounds the gradients of the discriminator, essentially preventing the exploding gradient problem. Moreover, the Kantorovich-Rubinstein duality requires it for a Wasserstein GAN, as mentioned in this excellent blog post.
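The slope condition can be checked numerically on a grid of points. The sketch below is only an illustration: `empirical_lipschitz` is a hypothetical helper, and the value it returns is a lower bound on the true Lipschitz constant, since it only looks at the pairs it is given.

```python
import math


def empirical_lipschitz(f, points):
    """Largest |f(a) - f(b)| / |a - b| over all pairs of distinct points."""
    best = 0.0
    for i, a in enumerate(points):
        for b in points[i + 1:]:
            best = max(best, abs(f(a) - f(b)) / abs(a - b))
    return best


# Evenly spaced grid on [-5, 5].
grid = [i / 10 for i in range(-50, 51)]

# sin is 1-Lipschitz: its slope never exceeds 1 in absolute value.
k_sin = empirical_lipschitz(math.sin, grid)

# x^2 is not Lipschitz on the whole real line; on this grid the steepest
# pair is (4.9, 5.0), and the bound keeps growing if the grid is widened.
k_square = empirical_lipschitz(lambda x: x * x, grid)
```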
Techniques for Improving Performance
There are a plethora of tricks and techniques that can be used for making GANs more stable and powerful. To keep this blog concise I’ve only explained techniques that are either relatively new or complex. I’ve listed out other miscellaneous tricks and techniques at the end of this section.
1. Alternative Loss Functions
One of the most popular fixes to the shortcomings of GANs is the Wasserstein GAN. It essentially replaces the Jensen-Shannon divergence of conventional GANs with the Earth Mover distance (Wasserstein-1 distance or EM distance). The original form of the EM distance is intractable, and hence we use its dual form (calculated by the Kantorovich-Rubinstein duality). This requires the discriminator to be 1-Lipschitz, which is maintained by clipping the weights of the discriminator.
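As a rough sketch of what this looks like in practice, the snippet below writes the critic (discriminator) loss and the weight clipping step in plain Python. The function names, the flat list of weights, and the scores are hypothetical; the clipping value of 0.01 is the default I recall from the WGAN paper, so treat it as an assumption rather than a recommendation.

```python
def critic_loss(real_scores, fake_scores):
    """WGAN critic loss; minimising it maximises E[f(real)] - E[f(fake)]."""
    mean_real = sum(real_scores) / len(real_scores)
    mean_fake = sum(fake_scores) / len(fake_scores)
    return mean_fake - mean_real


def clip_weights(weights, c=0.01):
    """Clamp every weight to [-c, c]; applied after each critic update."""
    return [max(-c, min(c, w)) for w in weights]


# Hypothetical critic weights after a gradient step.
weights = [0.5, -0.003, -0.2, 0.007]
clipped = clip_weights(weights)

# Hypothetical critic outputs on a batch of real and generated samples.
loss = critic_loss(real_scores=[0.8, 1.1], fake_scores=[-0.4, 0.1])
```

In a real training loop the same clamp would be applied to every parameter tensor of the critic, several times per generator update.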
The advantage of using the Earth Mover distance is that it is continuous even when the real and generated data distributions are disjoint, unlike the JS or KL divergence. Also, there is a correlation between the generated image quality and the loss value (Source). The disadvantage is that we need to perform several discriminator updates per generator update (as per the original implementation). Moreover, the authors claim that weight clipping is a terrible way to enforce the 1-Lipschitz constraint.
The earth mover distance (left) is continuous, even if the distributions are not continuous, unlike the Jensen Shannon divergence (right). Refer to this paper for a detailed explanation.
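A tiny experiment shows why this matters. In the sketch below, the real and generated distributions are point masses on a 1-D grid of unit-spaced bins, so they never overlap. The helper names are mine, and the EM distance is computed with the 1-D shortcut of summing the absolute difference between the two cumulative distributions.

```python
import math


def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions."""
    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def em_distance_1d(p, q):
    """Wasserstein-1 distance on unit-spaced bins: sum of |CDF_p - CDF_q|."""
    total, cdf_gap = 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_gap += pi - qi
        total += abs(cdf_gap)
    return total


def point_mass(position, size=8):
    """All probability mass on a single bin."""
    return [1.0 if i == position else 0.0 for i in range(size)]


real = point_mass(0)
results = {}
for shift in (1, 3, 6):
    fake = point_mass(shift)
    results[shift] = (js_divergence(real, fake), em_distance_1d(real, fake))
```

For every shift the JS divergence stays at log 2, so it gives the generator no signal about whether it is getting closer. The EM distance equals the shift itself, so it shrinks smoothly as the generated mass moves toward the real mass.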
Another interesting solution is to use mean squared loss instead of log loss. The authors of the LSGAN argue that the conventional GAN loss function does not provide much incentive to “pull” the generated data distribution close to the real data distribution.
The log loss in the original GAN loss function does not take into account the distance of the generated data from the decision boundary (the decision boundary separates real and fake data). LSGAN, on the other hand, penalizes generated samples that are far away from the decision boundary, essentially “pulling” the generated data distribution closer to the real data distribution. It does this by replacing the log loss with mean squared loss. For a detailed explanation of the same, check out this blog.
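The difference is easy to see on a single generated sample. The sketch below compares the two generator losses as a function of the discriminator's raw score, assuming a target of 1 for "real" and the non-saturating form of the log loss; the function names and the example scores are hypothetical.

```python
import math


def log_loss_generator(logit):
    """Non-saturating log loss -log(sigmoid(logit)) for one generated sample."""
    return math.log(1.0 + math.exp(-logit))


def least_squares_generator(score, target=1.0):
    """LSGAN-style loss 0.5 * (score - target)^2 for one generated sample."""
    return 0.5 * (score - target) ** 2


# Scores on the "real" side of the decision boundary, increasingly far away.
comparison = {
    score: (log_loss_generator(score), least_squares_generator(score))
    for score in (1.0, 3.0, 6.0)
}
```

As the score moves further onto the "real" side, the log loss shrinks toward zero and stops providing gradient, whereas the least squares loss grows and keeps pulling the sample back toward the target.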
2. Two Timescale Update Rule (TTUR)
In this method, we use a different learning rate for the discriminator and the generator (Source). Typically, a slower update rule is used for the generator and a faster update rule is used for the discriminator. Using this method, we can perform generator and discriminator updates in 1:1 ratio, and just tinker with the learning rates. Notably, the SAGAN implementation uses this method.
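In code, TTUR amounts to nothing more than two optimizers with different step sizes. Here is a minimal sketch with a hand-written SGD step; the learning rates (4e-4 for the discriminator, 1e-4 for the generator) are illustrative values in the spirit of SAGAN-style settings, and the parameters and gradients are placeholders.

```python
def sgd_step(params, grads, lr):
    """One plain gradient descent update."""
    return [p - lr * g for p, g in zip(params, grads)]


LR_DISCRIMINATOR = 4e-4  # faster timescale (assumed value)
LR_GENERATOR = 1e-4      # slower timescale (assumed value)

# Placeholder parameters and gradients, identical for both networks so
# that the only difference in the update comes from the learning rate.
d_params, g_params = [0.5, -0.5], [0.5, -0.5]
d_grads, g_grads = [1.0, -1.0], [1.0, -1.0]

# One discriminator update and one generator update per iteration (1:1).
d_params = sgd_step(d_params, d_grads, LR_DISCRIMINATOR)
g_params = sgd_step(g_params, g_grads, LR_GENERATOR)
```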
3. Gradient Penalty
In the paper Improved Training of WGANs, the authors claim that weight clipping (as originally performed in WGANs) leads to optimization issues. They claim that weight clipping forces the neural network to learn “simpler approximations” to the optimal data distribution, leading to lower quality results. They also claim that weight clipping leads to the exploding or vanishing gradient problem if the clipping hyperparameter is not set properly. The authors introduce a simple gradient penalty which is added to the loss function such that the above problems are mitigated. Moreover, 1-Lipschitz continuity is maintained, as in the original WGAN implementation.
Gradient penalty added as regularizer, as in the original WGAN-GP paper. (Source)
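The penalty can be sketched without a deep learning framework by using finite differences in place of automatic differentiation. Everything below is an illustration under that assumption: real implementations compute the gradient with autograd, and the helper names, the toy linear critic, and the batches are hypothetical. The penalty weight of 10 is the value I recall from the WGAN-GP paper.

```python
import math
import random


def numerical_gradient(f, x, eps=1e-5):
    """Central finite-difference gradient of scalar f at the point x (a list)."""
    grad = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return grad


def gradient_penalty(critic, real_batch, fake_batch, lam=10.0, rng=random):
    """lam * mean((||grad critic(x_hat)|| - 1)^2) at random interpolates x_hat."""
    penalties = []
    for real, fake in zip(real_batch, fake_batch):
        t = rng.random()  # random point on the line between real and fake
        x_hat = [t * r + (1 - t) * f for r, f in zip(real, fake)]
        grad = numerical_gradient(critic, x_hat)
        norm = math.sqrt(sum(g * g for g in grad))
        penalties.append((norm - 1.0) ** 2)
    return lam * sum(penalties) / len(penalties)


def linear_critic(weights):
    """Toy critic f(x) = w . x, whose gradient with respect to x is w."""
    return lambda x: sum(w * xi for w, xi in zip(weights, x))


real_batch = [[1.0, 2.0], [0.5, 1.5]]
fake_batch = [[-1.0, 0.0], [0.0, -0.5]]

# Gradient norm 1: no penalty. Gradient norm 5: penalty of 10 * (5 - 1)^2.
gp_unit = gradient_penalty(linear_critic([0.6, 0.8]), real_batch, fake_batch)
gp_steep = gradient_penalty(linear_critic([3.0, 4.0]), real_batch, fake_batch)
```

The penalty is added to the critic loss, so the critic is nudged toward unit gradient norm instead of having its weights clamped.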
The authors of DRAGAN claim that mode collapse occurs when the game played by the GAN (i.e. discriminator and generator going against each other) reaches a “local equilibrium state”. They also claim that the gradients contributed by the discriminator around such states are “sharp”. Naturally, using a gradient penalty will help us circumvent these states, greatly enhancing stability and reducing mode collapse.
4. Spectral Normalization
Spectral normalization is a weight normalization technique that is typically used on the Discriminator to enhance the training process. This essentially ensures that the Discriminator is K-Lipschitz continuous.
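The core operation is to divide a layer's weight matrix by its largest singular value, which is usually estimated by power iteration. Below is a minimal sketch on a plain list-of-lists matrix; the helper names are mine, and running many iterations from scratch is a simplification (practical implementations typically run a single iteration per training step and carry the estimate over).

```python
import math


def matvec(m, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def spectral_norm(w, iterations=50):
    """Estimate the largest singular value of w by power iteration on w^T w."""
    wt = transpose(w)
    v = [1.0] * len(w[0])
    for _ in range(iterations):
        v = matvec(wt, matvec(w, v))
        norm = math.sqrt(sum(x * x for x in v))
        v = [x / norm for x in v]
    u = matvec(w, v)
    return math.sqrt(sum(x * x for x in u))


def spectral_normalize(w):
    """Rescale w so that its largest singular value is 1."""
    sigma = spectral_norm(w)
    return [[x / sigma for x in row] for row in w]


w = [[3.0, 0.0], [0.0, 1.0]]
sigma = spectral_norm(w)
w_sn = spectral_normalize(w)
sigma_after = spectral_norm(w_sn)
```

A linear layer whose weight matrix has spectral norm 1 cannot stretch any input by more than a factor of 1, which is where the Lipschitz bound comes from.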
5. Unrolling and Packing
As stated in this excellent blog, one way to prevent mode hopping is to peek into the future and anticipate counterplay when updating parameters. Unrolled GANs enable the Generator to fool the Discriminator after the Discriminator has had a chance to respond (taking counterplay into account).
Another way of preventing mode collapse is to “pack” several samples from the same source (all real or all generated) together before passing them to the Discriminator. This method is used in PacGAN, whose authors report a decent reduction in mode collapse.
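Packing itself is a small data manipulation step. The sketch below concatenates every `m` consecutive samples from one source into a single discriminator input; the `pack` helper and the flat-list representation of a sample are hypothetical simplifications.

```python
def pack(samples, m):
    """Concatenate every m consecutive samples into one packed input.

    Leftover samples that do not fill a complete pack are dropped.
    """
    usable = len(samples) - len(samples) % m
    return [sum(samples[i:i + m], []) for i in range(0, usable, m)]


# Four hypothetical real samples with two features each.
real = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
packed_real = pack(real, m=2)
```

The Discriminator then sees several samples at once, which makes a lack of diversity within a pack much easier to detect.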
6. Stacking GANs
A single GAN may not be powerful enough to handle a task effectively. We could instead use multiple GANs placed consecutively, where each GAN solves an easier version of the problem. For instance, FashionGAN used two GANs to perform localized image translation.
FashionGAN used two GANs to perform localized image translation. (Source)
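Taking this concept to the extreme, we can gradually increase the difficulty of the problem presented to our GANs. For instance, Progressive GANs (ProGANs) start by generating low resolution images and progressively add layers that handle higher resolutions, which lets them generate high quality images at high resolution.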
7. Relativistic GANs
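Conventional GANs measure the probability of the generated data being real. Relativistic GANs instead measure the probability of one kind of data being “more realistic” than the other: the discriminator estimates how likely real data is to be more realistic than generated data, and the generator is trained to reverse that ordering. We can measure this “relative realism” using an appropriate distance measure, as mentioned in the RGAN paper.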
Output of the discriminator when using the standard GAN loss (image B). Image C represents what the output curve should actually look like. Image A represents the optimal solution to the JS divergence. (Source)
The authors also mention that the discriminator output should converge to 0.5 when it has reached the optimal state. However, conventional GAN training algorithms force the discriminator to output “real” (i.e. 1) for any image. This, in a way, prevents the discriminator from reaching its optimal value. The relativistic method solves this issue as well, and has pretty remarkable results, as shown below.
Output of a standard GAN (left) and a relativistic GAN (right) after 5000 iterations. (Source)
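To give a feel for the idea, here is a minimal sketch of the relativistic standard GAN losses as I understand them from the RGAN paper: each loss depends only on the difference between the critic's raw scores for a real sample and a generated sample. The function names and the scores are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def rsgan_discriminator_loss(real_scores, fake_scores):
    """Mean of -log(sigmoid(C(real) - C(fake))) over paired samples."""
    pairs = list(zip(real_scores, fake_scores))
    return -sum(math.log(sigmoid(r - f)) for r, f in pairs) / len(pairs)


def rsgan_generator_loss(real_scores, fake_scores):
    """Mean of -log(sigmoid(C(fake) - C(real))) over paired samples."""
    pairs = list(zip(real_scores, fake_scores))
    return -sum(math.log(sigmoid(f - r)) for r, f in pairs) / len(pairs)


# Hypothetical raw critic scores (before any sigmoid).
real_scores = [2.0, 3.0]
fake_scores = [-1.0, 0.0]

d_loss = rsgan_discriminator_loss(real_scores, fake_scores)
g_loss = rsgan_generator_loss(real_scores, fake_scores)

# When real and generated samples are scored identically, both losses
# equal log 2, i.e. the discriminator is at chance level.
d_loss_tied = rsgan_discriminator_loss([1.0, 1.0], [1.0, 1.0])
```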
8. Self Attention Mechanism
The authors of Self Attention GANs claim that convolutions used for generating images only look at information that is spread locally. That is, they miss out on relationships that span globally due to their restrictive receptive field.
Adding the attention map (calculated in the yellow box) to the standard convolution operation. (Source)
The Self-Attention Generative Adversarial Network allows attention-driven, long-range dependency modeling for image generation tasks. The self-attention mechanism is complementary to the normal convolution operation. The global information (long range dependencies) aids in generating images of higher quality. The network can choose to ignore the attention mechanism, or consider it along with normal convolutions. For a detailed explanation, you can check out their paper.
Visualization of the attention map for the location marked by the red dot. (Source)
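The mechanism can be sketched on a handful of feature vectors, one per spatial location. This is a deliberately simplified illustration: the query, key, and value projections are taken to be the identity (the actual network learns them), and `self_attention` and `gamma` are names I chose. The scale `gamma` weights the attention output before it is added back to the input, which is how the network can ignore attention (`gamma` = 0) or mix it in.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(features, gamma):
    """Return gamma * attention(features) + features, location by location."""
    out = []
    for query in features:
        # Similarity of this location to every location (including itself).
        scores = [sum(q * k for q, k in zip(query, key)) for key in features]
        weights = softmax(scores)
        # Weighted average of all locations' features.
        attended = [
            sum(w * value[d] for w, value in zip(weights, features))
            for d in range(len(query))
        ]
        out.append([gamma * a + x for a, x in zip(attended, query)])
    return out


# Three locations with two channels each.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

unchanged = self_attention(features, gamma=0.0)  # attention ignored
mixed = self_attention(features, gamma=1.0)      # attention added in
```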
9. Miscellaneous Techniques
Here is a list of some additional techniques (not exhaustive!) that are used to improve GAN training:
- Feature Matching
- Mini Batch Discrimination
- Historical Averaging
- One-sided Label Smoothing
- Virtual Batch Normalization
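As one example from the list above, one-sided label smoothing replaces the discriminator's target of 1 for real samples with a slightly smaller value, while leaving the target for generated samples at 0. The sketch below uses 0.9 as the smoothed target, which is a commonly used value but still an assumption here; the helper names are hypothetical.

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for a single prediction in (0, 1)."""
    return -(target * math.log(prediction)
             + (1 - target) * math.log(1 - prediction))


def discriminator_targets(n_real, n_fake, smooth=0.9):
    """Smoothed targets for real samples only; fake targets stay at 0."""
    return [smooth] * n_real + [0.0] * n_fake


targets = discriminator_targets(n_real=2, n_fake=2)

# With a smoothed target, an over-confident prediction on a real sample
# is penalised more than a prediction that matches the target.
loss_matched = bce(0.9, targets[0])
loss_overconfident = bce(0.99, targets[0])
```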
Now that we have established methods to improve training, we need to quantitatively prove it. The following metrics are often used to measure the performance of a GAN:
1. Inception Score
The inception score measures how “real” the generated data is.
The Inception Score. (Source)
The equation has two components: p(y|x) and p(y). Here, x is the image that is produced by the Generator, and p(y|x) is the probability distribution obtained when you pass image x through a pre-trained Inception Network (pretrained on the ImageNet dataset, as in the original implementation). Also, p(y) is the marginal probability distribution, which can be calculated by averaging p(y|x) over a few distinct samples of generated images (x). These two terms represent two different qualities that are desirable in real images:
- The generated image must have objects that are “meaningful” (objects are clear, and not blurry). This means that p(y|x) should have “low entropy”. In other words, our Inception Network must be strongly confident that the generated image belongs to a particular class.
- The generated images should be “diverse”. This means that p(y) should have “high entropy”. In other words, the generator should produce images such that each image represents a different class label (ideally).
Ideal plots of p(y|x) and p(y). Such a pair would have a really large KL divergence. (Source)
If a random variable is highly predictable, it has low entropy (i.e. p(y|x) must be a distribution with a sharp peak). On the contrary, if it is unpredictable, it has high entropy (i.e. p(y) must be a uniform distribution). If both these traits are satisfied, we should expect a large KL divergence between p(y|x) and p(y). Naturally, a large Inception Score (IS) is better. For a deeper analysis of the Inception Score, you can check out this paper.
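The score is usually written as the exponential of the average KL divergence between p(y|x) and p(y), taken over the generated images. Here is a minimal sketch that computes it from a table of class probabilities; in practice each row would come from the Inception Network, whereas the rows below are made up to show the two extremes.

```python
import math


def inception_score(p_yx):
    """exp(mean over images of KL(p(y|x) || p(y))).

    p_yx is a list of rows, one per generated image, each row being the
    class probabilities p(y|x) for that image.
    """
    n, k = len(p_yx), len(p_yx[0])
    # Marginal p(y): average of p(y|x) over the generated images.
    p_y = [sum(row[j] for row in p_yx) / n for j in range(k)]
    kls = [
        sum(p * math.log(p / p_y[j]) for j, p in enumerate(row) if p > 0)
        for row in p_yx
    ]
    return math.exp(sum(kls) / n)


# Confident predictions spread evenly over three classes.
sharp_and_diverse = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Completely unsure predictions for every image.
blurry = [[1 / 3, 1 / 3, 1 / 3] for _ in range(3)]

is_best = inception_score(sharp_and_diverse)
is_worst = inception_score(blurry)
```

With three classes the score ranges from 1 (the `blurry` case) to 3 (the `sharp_and_diverse` case); more generally, the upper limit is the number of classes.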
2. Fréchet Inception Distance (FID)
A drawback of the Inception Score is that statistics of the real data are not compared with the statistics of the generated data (Source). Fréchet distance resolves the drawback by comparing the mean and covariance of the real and generated images. Fréchet Inception Distance (FID) performs the same analysis, but on the feature maps produced by passing the real and generated images through a pre-trained Inception-v3 Network (Source). The equation is described as follows:
FID compares the mean and covariance of the real and generated data distributions. Tr stands for Trace. (Source)
A lower FID score is better, as it indicates that the statistics of the generated images are very similar to those of the real images.
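In its usual form, the distance is the squared difference of the means plus Tr(C_r + C_g - 2 (C_r C_g)^(1/2)), where C_r and C_g are the covariances of the real and generated features. The sketch below makes a strong simplifying assumption to avoid a matrix square root: it treats both covariances as diagonal, so only per-feature variances are used. The helper names and the toy feature vectors are hypothetical; a full implementation would use Inception features and a proper matrix square root.

```python
import math


def mean_and_variance(features):
    """Per-dimension mean and unbiased variance of a list of feature vectors."""
    n, d = len(features), len(features[0])
    mu = [sum(f[j] for f in features) / n for j in range(d)]
    var = [sum((f[j] - mu[j]) ** 2 for f in features) / (n - 1)
           for j in range(d)]
    return mu, var


def fid_diagonal(real_features, fake_features):
    """Fréchet distance between two Gaussians with diagonal covariances."""
    mu_r, var_r = mean_and_variance(real_features)
    mu_g, var_g = mean_and_variance(fake_features)
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))
    # For diagonal covariances the trace term reduces to this sum.
    trace_term = sum(vr + vg - 2.0 * math.sqrt(vr * vg)
                     for vr, vg in zip(var_r, var_g))
    return mean_term + trace_term


real = [[0.0, 0.0], [2.0, 2.0]]
shifted = [[1.0, 1.0], [3.0, 3.0]]  # same spread, mean moved by (1, 1)

fid_same = fid_diagonal(real, real)
fid_shifted = fid_diagonal(real, shifted)
```

Identical feature sets give a distance of 0, and shifting every feature by 1 in two dimensions gives a distance of 2, coming entirely from the mean term.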
The research community has produced numerous solutions and hacks to overcome the shortcomings of GAN training. However, it is difficult to keep track of significant contributions due to the sheer volume of new research. For the same reason, the details shared in this blog are not exhaustive and may become outdated in the near future. Nevertheless, I hope this blog serves as a guideline for people looking for methods to improve the performance of their GANs.
This article was originally published on Medium and re-published to TOPBOTS with permission from the author.