Once again, online tutorials describe the statistical interpretation of Variational Autoencoders (VAEs) in depth; however, I find that the actual implementation of the algorithm is quite different from those explanations and rather similar to that of regular neural networks.
The typical VAE diagram found online looks like this:
As an enthusiast, I find this kind of explanation very confusing, especially in introductory posts on the topic.
Anyway, first let me try to explain how I understand backpropagation in a regular feed-forward neural network.
For example, the chain rule for the derivative of E (total error) with respect to weight w1 is the following:
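Sketching it for a simple network where $w_1$ feeds a hidden node $H_1$, whose activation $H_{A1}$ feeds the output $\bar{x}$ (the labels are just my shorthand for the diagram):

$$\frac{\partial E}{\partial w_1} = \frac{\partial E}{\partial \bar{x}} \cdot \frac{\partial \bar{x}}{\partial H_{A1}} \cdot \frac{\partial H_{A1}}{\partial H_1} \cdot \frac{\partial H_1}{\partial w_1}$$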
Now let’s see the VAE equivalent and calculate the chain rule for the derivative of E (total error) with respect to weight w16 (just an arbitrary weight on the encoder side – the derivation is the same for all of them).
Notice that the gradient of each weight on the encoder side, including w16, depends on all the connections on the decoder side; hence the highlighted connections. The chain rule looks as follows:
Note that the part in red is the reparameterization trick, which I am not going to cover in detail here.
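For reference, the trick is usually written by pushing the randomness into an auxiliary noise variable, so that the sampled $z$ becomes differentiable with respect to $\mu$ and $\sigma$:

$$z = \mu + \sigma \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, 1) \quad\Rightarrow\quad \frac{\partial z}{\partial \mu} = 1, \qquad \frac{\partial z}{\partial \sigma} = \varepsilon$$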
But wait, that’s not all. Assume for the regular neural network that the batch size is equal to one; the algorithm then goes like this:
- Pass the inputs and perform the feed-forward pass.
- Calculate the total error and take the derivative for each weight in the network.
- Update the network's weights and repeat…
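In code, one such step might look roughly like this (a minimal PyTorch sketch; the layer sizes, loss, and learning rate are arbitrary placeholders):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))  # toy sizes
opt = torch.optim.SGD(net.parameters(), lr=0.01)

x = torch.randn(1, 4)                # a single input (batch size one)
x_hat = net(x)                       # 1) feed-forward pass
loss = ((x - x_hat) ** 2).sum()      # 2) total error
loss.backward()                      # 2) derivative for every weight
opt.step()                           # 3) update the weights
opt.zero_grad()                      # ...and repeat
```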
However, in VAEs the algorithm is a little different:
- Pass the inputs and perform the feed-forward pass for the encoder only, then stop.
- Sample the latent space (Z), say n times, and perform the decoder feed-forward step with the sampled random variates n times.
- Calculate the total error over all outputs and samples, and take the derivative for each weight in the network.
- Update the network's weights and repeat…
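Again only as a rough PyTorch sketch (one-dimensional latent space, reconstruction error only, to match the setup below; the KL term of the full VAE loss is left out):

```python
import torch
import torch.nn as nn

m, d = 4, 1                                   # m inputs/outputs, 1-D latent space
enc = nn.Linear(m, 2 * d)                     # produces (mu, log variance)
dec = nn.Linear(d, m)
opt = torch.optim.SGD(list(enc.parameters()) + list(dec.parameters()), lr=0.01)

x = torch.randn(1, m)
n = 5                                         # number of latent samples
mu, logvar = enc(x).chunk(2, dim=1)           # 1) encoder feed-forward, then stop
eps = torch.randn(n, d)                       # 2) sample Z n times...
z = mu + torch.exp(0.5 * logvar) * eps        #    ...via the reparameterization trick
x_hat = dec(z)                                #    decoder feed-forward for all n samples
loss = ((x_hat - x) ** 2).sum() / n           # 3) total error over all samples and outputs
loss.backward()                               # 3) derivative for every weight
opt.step()                                    # 4) update the weights and repeat
opt.zero_grad()
```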
Okay, okay – so what is my question?
Is my description of the VAE correct?
I will try to walk step by step through the sampling of the latent space (Z) and the backprop symbolically.
Let us assume that the VAE input is a one-dimensional array (so even if it's an image, it has been flattened). Also, the latent space (Z) is one-dimensional; hence, it contains a single value for the mean (μ) and a single value for the standard deviation (σ), assuming a normal distribution.
- For simplicity, let the error for a single input $x_i$ be $e_i = (x_i - \bar{x}_i)$, where $\bar{x}_i$ is the corresponding VAE output.
- Also, let us assume that there are $m$ inputs and outputs in this VAE.
- Lastly, let us assume that the mini-batch size is one, so we update the weights after each backprop; therefore, there is no mini-batch index $b$ in the gradient formulas.
In a regular feed-forward neural net, given the above setup, the total error would look as follows:
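With the per-output error $e_i$ defined above, something like:

$$E = \sum_{i=1}^{m} e_i = \sum_{i=1}^{m} (x_i - \bar{x}_i)$$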
Therefore, from the example above, $\frac{\partial E}{\partial w_1}$ follows the same chain sketched earlier, and we can easily update the weight with gradient descent. Very straightforward. Note that we have a single value for each partial derivative, e.g. $\frac{\partial H_{A1}}{\partial H_1}$ – this is an important distinction.
Now, for the VAE, as explained in the online posts, we have to sample n times from the latent space in order to get a good approximation of the expectation.
So given the example and assumptions above, the total error for n samples and m outputs is:
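Indexing the samples by $j$, and writing $\bar{x}_{ij}$ for output $i$ produced from latent sample $z_j$ (my notation), something like:

$$E = \frac{1}{n} \sum_{j=1}^{n} \sum_{i=1}^{m} e_{ij} = \frac{1}{n} \sum_{j=1}^{n} \sum_{i=1}^{m} (x_i - \bar{x}_{ij})$$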
If I understand correctly, we must have all n samples before we can take the derivative $\frac{\partial E}{\partial w_{16}}$. Taking the derivative (backprop) on just one sample does not make sense.
So, in the VAE (call this Option 1), the derivative would look as follows:
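Since $E$ is an average over the $n$ samples, the gradient becomes a sum over all of them (writing $E_j = \sum_{i=1}^{m} e_{ij}$ for the error of sample $j$):

$$\frac{\partial E}{\partial w_{16}} = \frac{1}{n} \sum_{j=1}^{n} \frac{\partial E_j}{\partial w_{16}}$$

where each term chains back through the decoder, the sampled $z_j$, and then $\mu$ and $\sigma$ down to $w_{16}$.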
This means that in the derivative chain we would have to calculate and add up the derivative of a given variable or function n times, i.e.:
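For example, the derivative through the latent variable shows up once per sample:

$$\frac{\partial E}{\partial \mu} = \frac{1}{n} \sum_{j=1}^{n} \frac{\partial E_j}{\partial z_j} \cdot \frac{\partial z_j}{\partial \mu}$$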
And finally, we update the weight with gradient descent:
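That is, with some learning rate $\eta$:

$$w_{16} \leftarrow w_{16} - \eta \, \frac{\partial E}{\partial w_{16}}$$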
Alternatively (Option 2), we keep the total error formula the same as in the regular neural network, except that now we index it by sample, because we will end up with n of them:
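That is, one error per latent sample:

$$E_j = \sum_{i=1}^{m} e_{ij} = \sum_{i=1}^{m} (x_i - \bar{x}_{ij}), \qquad j = 1, \dots, n$$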
and do backprop after each sample of the latent space Z, but do not update the weights yet:
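So after sample $j$ we compute (assuming, for the sketch, that $w_{16}$ sits on the path to $\mu$):

$$\frac{\partial E_j}{\partial w_{16}} = \frac{\partial E_j}{\partial z_j} \cdot \frac{\partial z_j}{\partial \mu} \cdot \frac{\partial \mu}{\partial w_{16}}$$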
where we now have only one z-derivative in the chain, unlike the n of them in Option 1,
and finally update the weights by averaging the gradient:
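i.e., with the same learning rate $\eta$ as before:

$$w_{16} \leftarrow w_{16} - \eta \cdot \frac{1}{n} \sum_{j=1}^{n} \frac{\partial E_j}{\partial w_{16}}$$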
So, Question 2: is Option 1 or Option 2 correct? Am I missing anything?
Thank you so much!
Q1: Your description seems to be pretty much correct.
Q2: The two options are equal:
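Differentiation is linear, so averaging the per-sample gradients (Option 2) gives exactly the gradient of the averaged error (Option 1):

$$\frac{1}{n} \sum_{j=1}^{n} \frac{\partial E_j}{\partial w_{16}} = \frac{\partial}{\partial w_{16}} \left( \frac{1}{n} \sum_{j=1}^{n} E_j \right) = \frac{\partial E}{\partial w_{16}}$$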
Also, note that n=1 is a valid choice:
“In our experiments we found that the number of samples L per datapoint can be set to 1 as long as the minibatch size M was large enough, e.g. M=100.”
Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).