Backpropagation on Variational Autoencoders

Once again, online tutorials describe the statistical interpretation of Variational Autoencoders (VAEs) in depth; however, I find that the actual implementation of this algorithm is quite different from those descriptions, and in fact similar to that of regular NNs.

The typical VAE diagram online looks like this:

[Figure: a typical VAE diagram with an encoder, a latent space Z parameterized by μ and σ, and a decoder]

As an enthusiast, I find this kind of explanation very confusing, especially in introductory online posts.

Anyway, first let me try to explain how I understand backpropagation on a regular feed-forward neural network.

[Figure: a small feed-forward network with weights $w$, hidden pre-activations $H$, hidden activations $HA$, output pre-activations $O$, and output activations $OA$]

For example, the chain rule for the derivative of $E$ (total error) with respect to weight $w_1$ is the following:

$$\frac{\partial E}{\partial w_1}=\frac{\partial E}{\partial HA_1}\cdots\frac{\partial HA_1}{\partial H_1}\frac{\partial H_1}{\partial w_1}$$
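
To make this concrete, here is a tiny numerical sketch (my own toy example, not from any post): a 1-1-1 network with sigmoid activations and squared error, with the chain rule above written out term by term.

```python
import numpy as np

# Toy 1-1-1 network (my own example): x -> H1 = w1*x -> HA1 = sigmoid(H1)
# -> O1 = w2*HA1 -> OA1 = sigmoid(O1), with error E = 0.5*(OA1 - y)^2.

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

x, y = 0.5, 0.8          # single input and target
w1, w2 = 0.3, -0.2       # arbitrary weights

# feed-forward pass
H1  = w1 * x
HA1 = sigmoid(H1)
O1  = w2 * HA1
OA1 = sigmoid(O1)
E   = 0.5 * (OA1 - y) ** 2

# chain rule: dE/dw1 = dE/dOA1 * dOA1/dO1 * dO1/dHA1 * dHA1/dH1 * dH1/dw1
dE_dOA1  = OA1 - y
dOA1_dO1 = OA1 * (1 - OA1)
dO1_dHA1 = w2
dHA1_dH1 = HA1 * (1 - HA1)
dH1_dw1  = x
dE_dw1   = dE_dOA1 * dOA1_dO1 * dO1_dHA1 * dHA1_dH1 * dH1_dw1

print(E, dE_dw1)
```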

Now let’s see the VAE equivalent and calculate the chain rule for the derivative of $E$ (total error) with respect to weight $w_{16}$ (just an arbitrary weight on the encoder side – the treatment is the same for all of them).

[Figure: a VAE network with encoder weights (including $w_{16}$), latent variables $\mu$, $\sigma$, $Z$, decoder units $H_4$/$HA_4$, and outputs $O$/$OA$; the decoder-side connections are highlighted]

Notice that the gradient of each weight on the encoder side, including $w_{16}$, depends on all the connections on the decoder side; hence the highlighted connections. The chain rule looks as follows:

$$\frac{\partial E}{\partial w_{16}}=\frac{\partial E}{\partial OA_1}\frac{\partial OA_1}{\partial O_1}\frac{\partial O_1}{\partial HA_4}\frac{\partial HA_4}{\partial H_4}\frac{\partial H_4}{\partial Z}\frac{\partial Z}{\partial \mu}\frac{\partial \mu}{\partial w_{16}}+\frac{\partial E}{\partial OA_2}\cdots+\frac{\partial E}{\partial OA_3}\cdots+\frac{\partial E}{\partial OA_4}\cdots$$

Note that the $\frac{\partial Z}{\partial \mu}$ term (the part shown in red in the figure) is where the reparameterization trick comes in, which I am not going to cover here.
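
For completeness, here is what the reparameterization trick boils down to in the 1-D case (a hedged sketch with made-up values: $Z=\mu+\sigma\epsilon$ with $\epsilon\sim\mathcal{N}(0,1)$):

```python
import numpy as np

# Reparameterization trick, 1-D sketch (values are arbitrary): instead of
# sampling Z ~ N(mu, sigma^2) directly, sample eps ~ N(0, 1) and set
# Z = mu + sigma*eps, so Z is a deterministic function of mu and sigma.
rng = np.random.default_rng(0)

mu, sigma = 0.4, 1.2
eps = rng.standard_normal()

Z = mu + sigma * eps

dZ_dmu    = 1.0    # d(mu + sigma*eps)/dmu
dZ_dsigma = eps    # d(mu + sigma*eps)/dsigma

# These derivatives are what allow the chain rule above to flow from the
# decoder back through Z into encoder weights such as w16.
print(Z, dZ_dmu, dZ_dsigma)
```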

But wait, that’s not all. Assume for the regular neural network that the batch size is equal to one; then the algorithm goes like this (a minimal sketch in code follows the list):

  1. Pass the inputs and perform the feed-forward pass.
  2. Calculate the total error and take the derivative for each weight in the network.
  3. Update the network’s weights and repeat…
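
In code, the three steps might look like this (a schematic only, reusing my toy 1-1-1 network from the earlier sketch; the data and learning rate are made up):

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

w1, w2, eta = 0.3, -0.2, 0.1
data = [(0.5, 0.8), (0.1, 0.3)]                 # (input, target) pairs, batch = 1

for epoch in range(100):
    for x, y in data:
        # 1. feed-forward pass
        HA1 = sigmoid(w1 * x)
        OA1 = sigmoid(w2 * HA1)
        # 2. total error E = 0.5*(OA1 - y)^2 and its derivative per weight
        delta  = (OA1 - y) * OA1 * (1 - OA1)
        dE_dw2 = delta * HA1
        dE_dw1 = delta * w2 * HA1 * (1 - HA1) * x
        # 3. update the weights and repeat
        w1 -= eta * dE_dw1
        w2 -= eta * dE_dw2
```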

However, in VAEs the algorithm is a little different (again, a sketch follows the list):

  1. Pass the inputs and perform the feed-forward pass for the encoder, then stop.
  2. Sample the latent space ($Z$), say $n$ times, and perform the feed-forward step through the decoder with the sampled random variates $n$ times.
  3. Calculate the total error, over all outputs and samples, and take the derivative for each weight in the network.
  4. Update the network’s weights and repeat…
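
As a sketch, the same loop for a deliberately tiny VAE (my own toy setup: 1-D input, linear encoder and decoder, reconstruction error only, KL term left out) could look like this:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, c = 0.5, -1.0, 0.8      # encoder weights (for mu and log-sigma) and decoder weight
eta, n = 0.05, 10             # learning rate and number of latent samples

for step in range(200):
    x = rng.uniform(-1.0, 1.0)               # single input (batch = 1)
    # 1. feed-forward for the encoder and stop
    mu, sigma = a * x, np.exp(b * x)
    # 2. sample the latent space n times and decode each sample
    eps  = rng.standard_normal(n)
    z    = mu + sigma * eps                  # reparameterization
    xhat = c * z
    # 3. total error over all samples and its derivative per weight
    d     = xhat - x                         # de_i/dxhat_i with e_i = 0.5*(xhat_i - x)^2
    dE_dc = np.mean(d * z)
    dE_da = np.mean(d * c * 1.0 * x)         # ... * dZ/dmu * dmu/da
    dE_db = np.mean(d * c * eps * sigma * x) # ... * dZ/dsigma * dsigma/db
    # 4. update the weights and repeat
    a -= eta * dE_da
    b -= eta * dE_db
    c -= eta * dE_dc
```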

Okay, okay – so what is my question?

Question 1

Is my description of the VAE correct?

Question 2

I will try to walk step by step through the sampling of the latent space ($Z$) and the backprop, symbolically.

Let us assume that the VAE input is a one-dimensional array (so even if it’s an image, it has been flattened). Also, the latent space ($Z$) is one-dimensional; hence, it contains a single value for the mean ($\mu$) and the standard deviation ($\sigma$), assuming normal distributions.

  • For simplicity, let the error for a single input $x_i$ be
    $e_i=(x_i-\bar{x}_i)$, where $\bar{x}_i$ is the corresponding VAE output.
  • Also, let us assume that there are $m$ inputs and outputs in this VAE
    example.
  • Lastly, let us assume that the mini-batch is one, so we update the
    weights after each backprop; therefore, we will not see the
    mini-batch index $b$ in the gradient formula.

In a regular feed-forward neural net, given the above setup, the total error would look as follows:

$$E=\frac{1}{m}\sum_{i=1}^{m}e_i$$

Therefore, from the example above,

$$\frac{\partial E}{\partial w_1}=\frac{\partial\left(\frac{1}{m}\sum_{i=1}^{m}e_i\right)}{\partial w_1}$$

and we can easily update the weight with gradient descent. Very straightforward. Note that we have a single value for each partial derivative, e.g. $\frac{\partial HA_1}{\partial H_1}$ – this is an important distinction.
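
A small numerical illustration of this (my own toy example: $m$ outputs of a linear model sharing one weight) shows that differentiating the averaged error is the same as averaging the per-output derivatives:

```python
import numpy as np

w1 = 0.7
x = np.array([0.2, -0.5, 1.0])     # m = 3 inputs
t = np.array([0.1,  0.4, 0.9])     # m = 3 targets

o = w1 * x                         # outputs o_i = w1 * x_i
e = 0.5 * (o - t) ** 2             # per-output errors
E = e.mean()                       # E = (1/m) * sum_i e_i

dE_dw1 = np.mean((o - t) * x)      # (1/m) * sum_i de_i/dw1

# finite-difference check of the same gradient
h = 1e-6
E_plus = (0.5 * ((w1 + h) * x - t) ** 2).mean()
print(dE_dw1, (E_plus - E) / h)    # the two values agree
```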

Option 1

Now, for the VAE, as explained in the online posts, we have to sample $n$ times from the latent space in order to get a good estimate of the expectation.

So, given the example and assumptions above, the total error for $n$ samples and $m$ outputs is:

$$E=\frac{1}{n}\frac{1}{m}\sum_{i=1}^{n}\sum_{j=1}^{m}e_{ij}$$

If I understand correctly, we must have at least $n$ samples in order to take the derivative $\frac{\partial E}{\partial w_{16}}$; taking the derivative (backprop) on a single sample does not make sense.

So, in the VAE the derivative would look as such:

$$\frac{\partial E}{\partial w_{16}}=\frac{\partial\left(\frac{1}{n}\frac{1}{m}\sum_{i=1}^{n}\sum_{j=1}^{m}e_{ij}\right)}{\partial w_{16}}$$

This means that in the derivative chain we would have to calculate and add the derivatives of a variable or function $n$ times, i.e.:

$$\cdots\frac{\partial Z_1}{\partial \mu}+\cdots+\frac{\partial Z_2}{\partial \mu}+\cdots\frac{\partial Z_n}{\partial \mu}$$

And finally, we update the weight with gradient descent:

$$w_{16}^{k+1}=w_{16}^{k}-\eta\frac{\partial E}{\partial w_{16}}$$
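
As a sketch, Option 1 for my toy 1-D linear VAE from before (one backward pass through the error averaged over all $n$ samples, so the chain contains one $\frac{\partial Z_i}{\partial \mu}$ term per sample):

```python
import numpy as np

rng = np.random.default_rng(1)
a, b, c, eta, n = 0.5, -1.0, 0.8, 0.05, 10

x = rng.uniform(-1.0, 1.0)
mu, sigma = a * x, np.exp(b * x)
eps  = rng.standard_normal(n)     # n latent samples
z    = mu + sigma * eps
xhat = c * z

# dE/da = (1/n) * sum_i  de_i/dxhat_i * dxhat_i/dz_i * dz_i/dmu * dmu/da
dE_da = np.mean((xhat - x) * c * 1.0 * x)

a = a - eta * dE_da               # single gradient-descent update
```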

Option 2

We keep the total error formula the same as in the regular neural network, except now we have to index it because we are going to end up with $n$ of them:

$$E_i=\frac{1}{m}\sum_{j=1}^{m}e_j$$

and do backprop after each sample of the latent space $Z$, but do not update the weights yet:

$$\frac{\partial E_i}{\partial w_{16}}=\frac{\partial\left(\frac{1}{m}\sum_{j=1}^{m}e_j\right)}{\partial w_{16}}$$

where now we have only one $Z$-derivative in the chain, unlike the $n$ in Option 1:

$$\cdots\frac{\partial Z}{\partial \mu}+\cdots$$

and finally update the weights by averaging the gradient:

$$w_{16}^{k+1}=w_{16}^{k}-\frac{\eta}{n}\sum_{i=1}^{n}\frac{\partial E_i}{\partial w_{16}}$$
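
And Option 2 for the same toy setup (backprop once per sample, so each chain has a single $\frac{\partial Z}{\partial \mu}$ term, then average the $n$ gradients before one update):

```python
import numpy as np

rng = np.random.default_rng(1)
a, b, c, eta, n = 0.5, -1.0, 0.8, 0.05, 10

x = rng.uniform(-1.0, 1.0)
mu, sigma = a * x, np.exp(b * x)

grads = []
for _ in range(n):
    eps  = rng.standard_normal()        # one sample of the latent space
    z    = mu + sigma * eps
    xhat = c * z
    dEi_da = (xhat - x) * c * 1.0 * x   # dE_i/da for this sample only
    grads.append(dEi_da)

a = a - eta * np.mean(grads)            # update with the averaged gradient
```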

So in Question 2 – is Option 1 or Option 2 correct? Am I missing anything?

Thank you so much!

Answer

Q1: Your description seems to be pretty much correct.

Q2: The two options are equal:

$$\frac{\partial E}{\partial w}=\frac{\partial}{\partial w}\left(\frac{1}{n}\sum_{i=1}^{n}E_i\right)=\frac{1}{n}\sum_{i=1}^{n}\frac{\partial E_i}{\partial w}$$
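
A quick numerical check with a toy 1-D linear VAE (my own example, not from the paper) confirms this: holding the $\epsilon$ draws fixed, the two options produce the same gradient.

```python
import numpy as np

rng = np.random.default_rng(2)
a, b, c, n = 0.5, -1.0, 0.8, 10
x = rng.uniform(-1.0, 1.0)
mu, sigma = a * x, np.exp(b * x)
eps = rng.standard_normal(n)           # same eps draws for both options

z    = mu + sigma * eps
xhat = c * z

grad_option1 = np.mean((xhat - x) * c * x)                  # one backward pass over the averaged error
grad_option2 = np.mean([(c * zi - x) * c * x for zi in z])  # n backward passes, then averaged

print(np.isclose(grad_option1, grad_option2))               # True
```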

Also, note that $n=1$ is a valid choice:

In our experiments we found that the number of samples $L$ per datapoint can be set to 1 as long as the minibatch size $M$ was large enough, e.g. $M=100$.

Kingma, Diederik P., and Max Welling. “Auto-Encoding Variational Bayes.” arXiv preprint arXiv:1312.6114 (2013).

Attribution
Source: Link, Question Author: Edv Beq, Answer Author: Jan Kukacka
