What’s wrong with you, GANs?
Or how to stay out of trouble when working with GANs
Well my fellow reader, if you are reading this post, it means you are already familiar with the concept of a GAN. And that’s good, because I am not going to explain the details here, like the architecture or the math behind it. Instead, I want to share my experience with the thorny training process of these neural networks.
So, where should I begin? … Oh yes, maybe describing my project would be a good start. Currently I am working on my diploma thesis, which focuses on various use cases where GANs and 5G networks can cooperate and (theoretically) make the world a better place. I don’t want to go into the details of the thesis here, but I can tell you that I was using the pix2pix GAN architecture (for more info check this link), which belongs to the category of conditional GANs.
The first step, of course, was to code the overall architecture (I used PyTorch for this) and prepare the data for training. Everything was going pretty well until I started running experiments, which also means training the networks.
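So we share a mental model before diving into the problems, here is roughly what a single pix2pix training step looks like. This is a hedged sketch, not my thesis code: the names (G, D, opt_g, opt_d) and a discriminator called as D(x, y), i.e. one that concatenates the input and target internally, are my illustrative assumptions; the adversarial-plus-L1 objective with λ = 100 is the standard one from the pix2pix paper.

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()  # adversarial loss; assumes D returns raw logits
l1 = nn.L1Loss()              # pixel-wise loss, weighted by lambda
LAMBDA = 100                  # the weight used in the pix2pix paper

def train_step(G, D, opt_g, opt_d, x, y):
    # --- discriminator step: real pairs -> 1, fake pairs -> 0 ---
    fake = G(x)
    d_real = D(x, y)
    d_fake = D(x, fake.detach())  # detach so G gets no gradient from this step
    loss_d = bce(d_real, torch.ones_like(d_real)) + \
             bce(d_fake, torch.zeros_like(d_fake))
    opt_d.zero_grad(); loss_d.backward(); opt_d.step()

    # --- generator step: fool D and stay close to the target in L1 ---
    d_fake = D(x, fake)
    loss_g = bce(d_fake, torch.ones_like(d_fake)) + LAMBDA * l1(fake, y)
    opt_g.zero_grad(); loss_g.backward(); opt_g.step()

    return loss_d.item(), loss_g.item()
```

Keep this loop in mind, because every problem below is really a story about one of its two halves misbehaving.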
And here are the 3 problems I was struggling with:
- Mode collapse
- Deceptive loss function
- Slow convergence
I guess you are already pretty curious about these 3 guys (that’s why you are here, huh?), so let’s look at each of them in detail.
Mode collapse
Probably the most common issue in GAN training is mode collapse, and trust me, it can be really hard to resolve. You can observe it when your generator produces a small number of very similar outputs that look OK to the discriminator but are, in reality, trash. The core of this issue can be explained with exhaustive math, but to keep this post simple I will use my own words.
The main point is that the generator finds a local optimum in the sample space that, at a given point in time, looks good to the discriminator. This often happens when the discriminator has not yet been trained well enough to recognize what is correct and what is incorrect (the whole architecture collapses to a point from which there is no way back).
As a result of this collapse, your generator keeps producing a very similar, small subset of samples to decrease its loss, and the discriminator does not improve at all, so there is no way to punish the generator and force it to produce something new and better.
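One practical way to notice this early is to watch how diverse each generated batch actually is. This is my own rough heuristic, not an official metric, and the numbers it produces only make sense relative to earlier epochs of the same run:

```python
import torch

def batch_diversity(fake_batch: torch.Tensor) -> float:
    """Mean pairwise L1 distance between generated samples in a batch.

    Assumes batch size > 1. A value that keeps shrinking toward zero
    during training hints that the generator is collapsing onto a few
    near-identical outputs.
    """
    flat = fake_batch.flatten(start_dim=1)   # (N, num_pixels)
    dists = torch.cdist(flat, flat, p=1)     # (N, N) pairwise L1 distances
    n = flat.size(0)
    # average over the off-diagonal entries only (the diagonal is zero)
    return (dists.sum() / (n * (n - 1))).item()
```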
There is no obvious, perfect solution to this issue, and what works can vary between projects, but here are some of the most common tricks (for me, the last one did the job):
1. Add layers to your generator / remove layers from the discriminator
2. Try to train your GAN for longer
3. Perform some hyperparameter tuning on both networks
4. Add some dropout and batch normalization layers to the generator (a sketch follows this list)
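Since the last trick is the one that actually saved me, here is a hedged sketch of what such a decoder block could look like in PyTorch. The channel counts, kernel size, and 0.5 dropout rate are common pix2pix-style defaults, not the exact values from my thesis:

```python
import torch.nn as nn

def up_block(in_ch: int, out_ch: int, dropout: float = 0.5) -> nn.Sequential:
    """One decoder block of a pix2pix-style generator, with batch
    normalization and dropout added after the transposed convolution."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2,
                           padding=1, bias=False),
        nn.BatchNorm2d(out_ch),   # stabilizes training
        nn.Dropout(dropout),      # injects noise, which works against collapse
        nn.ReLU(inplace=True),
    )
```

Note that the original pix2pix generator applies dropout only in the first few decoder blocks, so you may not want it everywhere.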
Deceptive loss function
Despite solving the previous problem, I still had the feeling that something was wrong. We all know that a loss function should be very useful when training neural networks, but in my case, and in the case of GANs generally, you must be careful about what the loss function is actually telling you. Let me explain.
After some epochs, the program showed me the loss curves for the generator and the discriminator. At first glance everything looked pretty solid: both curves descended and there was no sign of one network dominating the other. Do you still remember what I said about loss functions? Good, because here comes the plot twist.
As the next step I checked the generated images, and after approximately 10 000 epochs the results were really bad and far from my expectations. I realized that the loss function may not be the best metric for measuring training progress, and after some Google magic I found that it is recommended to use additional (even custom) metrics to keep track of GAN performance. Blessed with this knowledge, I used a metric well known in computer vision tasks (SSIM), plus a custom metric of my own. And guess what, it really helped me detect errors in the early stages of training.
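Just to illustrate the idea, here is a minimal sketch of such a check (not my thesis code, and I am not reproducing my custom metric here). It assumes batches of (H, W, C) numpy images scaled to [0, 1] and scikit-image ≥ 0.19 for the channel_axis argument:

```python
import numpy as np
from skimage.metrics import structural_similarity

def mean_ssim(fakes: np.ndarray, targets: np.ndarray) -> float:
    """Average SSIM over a batch of (H, W, C) images in [0, 1].

    Track this every few hundred epochs: if it stalls or drops while
    the loss curves still look fine, something is wrong.
    """
    scores = [
        structural_similarity(f, t, channel_axis=-1, data_range=1.0)
        for f, t in zip(fakes, targets)
    ]
    return float(np.mean(scores))
```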
Slow convergence
I can imagine that by now you must be like, “Wow, so much pain, there has to be a happy ending.” Well my friend, there is, but first let me introduce the term slow convergence.
In the previous section I mentioned that my program ran for 10 000 epochs. That took my GPU (an Nvidia RTX 2070 Super) approximately 3 hours, which is quite a lot of time. But after some experiments I found out that to generate images of the required quality, the program needed to run for almost 50 000 epochs. To reach that number, my GPU had to run for over 21 hours.
My recommendation after these observations is that even if your model performs poorly from the beginning, you should always try to run it for a little longer. Another solution to slow convergence and poor results can be removing some inner layers from your generator; you can also try the same on the discriminator. And my second recommendation: if you really want to play with GANs, buy a GPU with a decent number of CUDA cores, or it will take an eternity to get any results.
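One practical note on those 21-hour runs: checkpoint them, so a crash does not cost you a whole day. A minimal sketch with plain torch.save / torch.load (the file name is arbitrary, and the G, D, opt_g, opt_d names are the same illustrative ones as before):

```python
import torch

CKPT = "pix2pix_ckpt.pt"  # arbitrary path

def save_checkpoint(epoch, G, D, opt_g, opt_d):
    torch.save({
        "epoch": epoch,
        "G": G.state_dict(), "D": D.state_dict(),
        "opt_g": opt_g.state_dict(), "opt_d": opt_d.state_dict(),
    }, CKPT)

def load_checkpoint(G, D, opt_g, opt_d):
    """Returns the epoch to resume from (0 if no checkpoint exists)."""
    try:
        state = torch.load(CKPT)
    except FileNotFoundError:
        return 0
    G.load_state_dict(state["G"]); D.load_state_dict(state["D"])
    opt_g.load_state_dict(state["opt_g"]); opt_d.load_state_dict(state["opt_d"])
    return state["epoch"] + 1
```

Saving the optimizer states alongside the models matters: restarting Adam from scratch halfway through a run can noticeably disturb training.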
AND THAT’S IT, you’ve read it all and now you are ready to fight these obstacles in your own projects. For more posts like this, don’t forget to subscribe, share, or hit the clap button. See you again.