Variational inference, and in particular its application to variational autoencoders (VAEs; Kingma & Welling, 2014; Rezende et al., 2014), is a powerful tool for learning to approximate complex probability distributions from data.
However, the way it is often introduced initially made it seem somewhat mystical to me.
In this post, I will try to introduce variational inference in a way that shows how it might arise from first principles as a natural solution to learning a neural-network-based latent-variable model.
I will also highlight some of the nuances and subtleties that are often glossed over and motivate why we might want to learn a latent-variable model in the first place compared to other alternatives.
This post will be a bit technical and assumes some familiarity with probability theory and machine learning.
As I have a background in reinforcement learning, I am particularly interested in how variational inference can be used to learn world models of the form $p(s' \mid s, a)$, that is, the probability of transitioning to state $s'$ if an agent starts in state $s$ and takes action $a$.
To abstract away from this specific use case, I will just say we are interested in learning a distribution of the form $p(y \mid x)$. A non-reinforcement-learning example could be learning the distribution of images of digits $y$ conditional on the particular digit $x$ from 0 to 9 that the image represents. Often variational inference is introduced as approximating the unconditioned distribution of a single variable, $p(y)$, for example, the distribution of images of digits without conditioning on a particular digit.
This is the difference between a conditional VAE (Sohn et al., 2015) and a standard VAE.
In practice, this difference basically just involves an extra input to the neural network representing the conditioning variable, so it won’t add much complexity to the discussion.
Before getting into variational inference specifically, it’s worth unpacking what we mean by learning a distribution $p(y \mid x)$.
This could mean a number of different things. In particular, we can distinguish at least three different types of models of $p(y \mid x)$ that we might like to learn, which are outlined in the textbook Reinforcement Learning: An Introduction by Sutton and Barto (2020):
- Sample Model: A model that takes in $x$ as input and stochastically outputs a single sample $y$ with approximately the same distribution as $p(y \mid x)$.
- Expectation Model: A model that takes in $x$ and outputs an estimate of $\mathbb{E}[y \mid x]$.
- Distribution Model: A model that takes in $x$ as input and outputs some representation of a distribution approximating the full distribution $p(y \mid x)$. For example, if $y$ is represented as a vector of binary variables, our model could output a vector of values between 0 and 1 representing the probability of each element of $y$ being 1 for that particular $x$.
The above definition of a distribution model is somewhat underspecified as I haven’t said how the distribution should be represented (do we have a closed-form expression? Does it suffice to be able to evaluate the probability of any $y$? If so, how computationally expensive is this evaluation?).
Nonetheless, this is a good starting point for a taxonomy of different types of models.
Next, I’ll discuss some potential issues with learning expectation or distribution models before moving on to how latent-variable models can help address these issues.
Expectation models and distribution models
Given a dataset of samples $(x_i, y_i)$ indexed by $i \in \{1, \dots, N\}$, we can consider how each of these model types might be implemented and learned.
For concreteness, let’s focus on the problem of modelling binary images of digits $y$ given the category $x$.
For an expectation model, we can just implement our model as a neural network represented by $f_\theta$, which outputs a vector of values between zero and one with one element for every pixel in the image. Let $f_\theta(x)$ be the output of the neural network with parameters $\theta$ when given $x$ as input.
We can train this model by minimizing the pixel-wise mean squared error between the output of the network and the vector corresponding to the pixels of the actual image, that is, we can minimize the following loss:
$$L_{\text{MSE}}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \lVert f_\theta(x_i) - y_i \rVert^2.$$
It’s straightforward to show that this loss is minimized when $f_\theta(x)$ outputs the mean over the dataset of the $y_i$ associated with each $x$, which is exactly what we want for an expectation model.
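As a quick sanity check of the claim that the per-category mean minimizes the squared-error loss, here is a minimal sketch in plain Python. The toy dataset (two categories, three-pixel "images") and the helper names `class_means` and `mse` are made up for illustration; a lookup table of per-category outputs stands in for the neural network.

```python
from collections import defaultdict

def class_means(xs, ys):
    """Per-category mean of the target vectors (the claimed MSE minimizer)."""
    sums, counts = {}, defaultdict(int)
    for x, y in zip(xs, ys):
        if x not in sums:
            sums[x] = [0.0] * len(y)
        sums[x] = [s + v for s, v in zip(sums[x], y)]
        counts[x] += 1
    return {x: [s / counts[x] for s in sums[x]] for x in sums}

def mse(pred, xs, ys):
    """Dataset mean of the squared error, summed over pixels."""
    total = 0.0
    for x, y in zip(xs, ys):
        total += sum((p - v) ** 2 for p, v in zip(pred[x], y))
    return total / len(xs)

# Hypothetical toy data: categories 0 and 1, three binary pixels per image.
xs = [0, 0, 0, 1, 1]
ys = [[1, 0, 1], [1, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0]]

means = class_means(xs, ys)
best = mse(means, xs, ys)

# Shifting every output away from the mean strictly increases the loss.
perturbed = {x: [m + 0.1 for m in v] for x, v in means.items()}
worse = mse(perturbed, xs, ys)
print(means, best, worse)
```

Because the cross term vanishes at the mean, shifting every output by 0.1 raises the loss by exactly 3 × 0.1² = 0.03 here, regardless of the data.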
However, the expectation is obviously a very limited summary of the distribution.
For example, if we give this model $x = 9$ as input, the output will just be a blurry mashup of all the 9s in the dataset. For many downstream tasks, we will want more than this.
Image of the mean 9 from a binary version of MNIST, actually surprisingly recognizable as a 9.
If instead, we want to learn a distribution model, we have to decide on how we are going to represent the distribution.
One simple choice is to take the same neural network $f_\theta$, but in this case, interpret the output as the probability that each pixel is 1.
We can then express the estimated probability of any image $y$ given a category $x$ as
$$p_\theta(y \mid x) = \prod_{j} f_\theta(x)[j]^{\,y[j]} \left(1 - f_\theta(x)[j]\right)^{1 - y[j]},$$
where I’ve used the square bracket notation to index individual elements of a vector, whereas subscripts index dataset elements.
We can then train this model by maximizing the log-likelihood of the data under the model, that is, we can minimize the following loss:
$$L_{\text{NLL}}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \log p_\theta(y_i \mid x_i),$$
which is the negation of the log-likelihood of the data under our model. If we assume our data is drawn from some underlying distribution $p(y \mid x)$, then minimizing this loss is equivalent to minimizing an unbiased estimate of the expectation over $x$ of the KL divergence between $p(y \mid x)$ and $p_\theta(y \mid x)$, that is
$$\mathbb{E}_{x}\left[\mathrm{KL}\big(p(y \mid x) \,\|\, p_\theta(y \mid x)\big)\right] = \mathbb{E}_{x}\left[\int p(y \mid x) \big(\log p(y \mid x) - \log p_\theta(y \mid x)\big)\, dy\right].$$
Note that the first term in the integrand does not depend on $\theta$ and thus does not affect the optimization. KL divergence is always nonnegative, and zero if and only if the two distributions match exactly. It also has intriguing interpretations in terms of the expected number of extra bits required to encode samples from one distribution using a code optimized for the other. Furthermore, we can often bound the utility of the model for downstream applications in terms of the KL divergence between the true distribution and the model distribution. For example, in reinforcement learning, we can bound the expected detriment in return we suffer by optimizing a policy for a learned model instead of the true model in terms of the KL divergence between the two (Ross & Bagnell, 2012). In this sense, KL divergence is a reasonable objective for model learning, though certainly not the only possible objective.
In the specific case where we are dealing with binary vectors, the minimizer of the negative log-likelihood loss is actually the same as the minimizer of the mean squared error loss we used for the expectation model.
Likewise, the resulting model is not much more useful. While technically, we are now representing a distribution, it is a distribution of a specific limited form. In particular, since we represent the probability of each pixel being one independently, we can’t represent any correlations between pixels.
If we sample a particular 9 from the distribution represented by this model it might loosely resemble a 9 but with scattered dots rather than smooth connected lines as each pixel will be sampled as if it belonged to a different 9.
Images sampled according to independent pixel probabilities across all 9s.
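To make the loss of correlations concrete, here is a small sketch with a made-up two-pixel dataset in which the pixels are always equal. The maximum-likelihood factored model (per-pixel means) still puts a quarter of its probability on each mismatched image, even though such images never occur in the data. The names `factored_prob`, `model`, and `empirical` are illustrative only.

```python
from itertools import product

# Hypothetical dataset: two-pixel images whose pixels are perfectly correlated.
data = [(1, 1), (0, 0), (1, 1), (0, 0)]

# Maximum-likelihood factored Bernoulli model: one mean per pixel.
n_pixels = len(data[0])
probs = [sum(img[j] for img in data) / len(data) for j in range(n_pixels)]

def factored_prob(img, probs):
    """Probability of an image when every pixel is sampled independently."""
    p = 1.0
    for v, q in zip(img, probs):
        p *= q if v == 1 else 1.0 - q
    return p

all_images = list(product([0, 1], repeat=n_pixels))
model = {img: factored_prob(img, probs) for img in all_images}
empirical = {img: data.count(img) / len(data) for img in all_images}

for img in all_images:
    print(img, "data:", empirical[img], "factored model:", model[img])
```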
At this point, we could try to improve our distribution model by making it more expressive.
However, if we are committed to outputting some representation of the full conditional distribution in closed form, this will probably be an uphill battle.
The most general distribution over binary images is a vector of probabilities with one element for each possible image, $2^{784}$ elements for a 28x28 image.
Clearly, we don’t want to output a vector of this size, so we would have to make some trade-off between generality and simplicity in the form of our output distribution.
We could spend some time thinking about various restricted distributions that might do a good job of representing reasonable images.
Alternatively, we could give up on trying to output a closed-form expression for the full distribution and instead try to learn something like a sample model.
If we give up on representing a closed-form expression for the distribution, it’s no longer clear whether we can optimize our model by simply minimizing the negative log-likelihood loss.
Doing so requires us to evaluate and differentiate $p_\theta(y \mid x)$, which we can’t do if all our model gives is samples from the distribution.
In the next section, I will discuss autoregressive models, which are sort of an intermediate between distribution models and sample models in the sense that they don’t represent the full distribution in any closed form, but we can still efficiently evaluate the probability of any $y$ for a given input $x$ under the model.
Autoregressive models
For an autoregressive model, we assume each $y$ consists of $K$ features such that $y = (y[1], \dots, y[K])$. In our running digits example this could just be the pixels of the image.
Rather than approximating the full joint distribution of $y$, an autoregressive model learns approximate conditional distributions $p_\theta(y[k] \mid x, y[1], \dots, y[k-1])$ for each feature $y[k]$.
There is an implicit assumption that the individual features have a simple enough form that we can explicitly represent their distribution in closed form (e.g. a vector of probabilities for each of a finite set of values).
In our running example, we could represent each pixel as a Bernoulli variable whose probability of being one depends on the preceding pixels.
Example of an image of a 9 being generated by an autoregressive process, each pixel is sampled with probability that depends on the pixels generated so far. In principle, arbitrary distributions over binary images can be represented in this form.
Given an autoregressive model, we can sample $y$ by sampling each feature from its conditional distribution sequentially.
It is possible to represent arbitrary distributions in this autoregressive form. Furthermore, in this case we can evaluate $p_\theta(y \mid x) = \prod_{k=1}^{K} p_\theta(y[k] \mid x, y[1], \dots, y[k-1])$ for arbitrary $y$.
Thus with an autoregressive model, we can optimize the negative log-likelihood loss.
Although sampling must be done sequentially, evaluation of the probability of a particular $y$ can be done for all features in parallel.
This is crucial for efficient training of large language models (LLMs) for example.
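The asymmetry between evaluation and sampling can be sketched in a few lines of plain Python. Here `cond_prob_one` is a hypothetical hand-written rule standing in for a learned network (each bit tends to copy the previous one); nothing about it comes from a real model.

```python
import math
import random
from itertools import product

def cond_prob_one(prefix):
    """Hypothetical stand-in for a network: P(next bit = 1 | bits so far)."""
    if not prefix:
        return 0.5
    return 0.9 if prefix[-1] == 1 else 0.1

def log_prob(y):
    """Chain rule: sum of log conditionals. Every term depends only on the
    given y, so a real network could compute all of them in one parallel pass."""
    total = 0.0
    for k, bit in enumerate(y):
        p1 = cond_prob_one(y[:k])
        total += math.log(p1 if bit == 1 else 1.0 - p1)
    return total

def sample(length, rng):
    """Sampling is inherently sequential: each bit needs all earlier bits."""
    y = []
    for _ in range(length):
        y.append(1 if rng.random() < cond_prob_one(tuple(y)) else 0)
    return tuple(y)

# The conditionals define a valid joint distribution: probabilities sum to one.
total_mass = sum(math.exp(log_prob(y)) for y in product([0, 1], repeat=4))
draw = sample(4, random.Random(0))
print(total_mass, draw, log_prob(draw))
```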
Although autoregressive models are popular, for example in LLMs, they have some notable drawbacks.
They require features to be sampled sequentially, limiting the potential for parallelism in sampling.
They also require a choice of closed-form distribution to model the individual factors which may be limiting, for example, if the individual factors can take continuous values.
One could commit to a particular choice of distribution, such as a Gaussian, but this would again limit the representation power of the model, as it could not represent multimodal conditional distributions for individual features.
Finally, autoregressive models require a, potentially arbitrary, choice of the order in which features are sampled, which can be unnatural and may result in complicated conditional distributions that are difficult to model.
For example, consider the challenge of autoregressively modelling the distribution of pixels in a complex image in arbitrary order.
To drive home the last point, one can even construct distributions where an autoregressive model faces fundamental computational challenges in representing the conditional distributions while representing the joint distribution of features is easy.
For example, imagine a process which first selects a uniform random integer $n$ and then generates $m = h(n)$ using some function $h$ which is easy to compute but difficult to invert. Now imagine we want to model $p(y)$ where $y$ represents the sequence of bits in the integers $m$ and $n$ concatenated together.
A bitwise autoregressive model of $y$ would have to first generate a random output $m$, and then invert $h$ in order to accurately model $p(n \mid m)$.
On the other hand, a more general model that samples from the joint distribution as a whole would be free to simulate sampling $n$ followed by computing $h(n)$, which is easy.
See the work of Lin et al. (2020) for a more detailed discussion of some related issues with autoregressive models.
Example highlighting how autoregressive modeling can be fundamentally more difficult than modeling the full joint distribution of features. The original process being modeled uses a one-way function. An autoregressive model operating in the reverse order is forced to invert the function to accurately capture the conditional distribution.
Latent-variable models
Another class of sample models, which can help to address the challenges outlined so far, is called latent-variable models.
The basic idea is to factor the nontrivial part of the randomness in our learned distribution into a separate noise variable $z$, generally referred to as the latent variable.
We then use a latent-variable-dependent distribution $p_\theta(y \mid x, z)$ to model the distribution of $y$ given $x$ and $z$.
Normally, $p_\theta(y \mid x, z)$ has a simple closed form once $x$ and $z$ are fixed.
For example, if $y$ is a vector of binary variables, $p_\theta(y \mid x, z)$ could simply be a neural network which takes $x$ and $z$ as input and outputs the mean of each element of $y$.
Note that this means the elements of $y$ are uncorrelated given $x$ and $z$.
The latent variable itself will also be drawn from a simple closed-form prior distribution $p_\theta(z \mid x)$ conditioned on the input $x$.
For example, $p_\theta(z \mid x)$ could also be represented as a neural network which takes $x$ as input and outputs the mean of each element of $z$, where each element of $z$ is once again a Bernoulli variable.
I will use vectors of Bernoulli variables as a running example in what follows but other choices, such as vectors of independent Gaussians or categorical variables, are possible and common.
Illustration of sampling process for a latent-variable model. Rather than modeling the distribution of y conditioned on x directly we first sample a latent-variable z, and then sample y conditioned on both x and z. This allows representation of complex distributions of y conditioned on x despite the distributions of z conditioned on x, and y conditioned on x and z having a simple closed-form. The price we pay for this representational power is that we can no longer easily evaluate or optimize the likelihood of the data under our model.
Since the mapping from $(x, z)$ to the distribution over $y$ is parameterized by a generic neural network capable of representing complex transformations, the marginal $p_\theta(y \mid x)$ can capture complicated dependencies between elements of $y$ despite $p_\theta(z \mid x)$ and $p_\theta(y \mid x, z)$ being simple factored distributions.
In order to draw a sample $y$ from this model, we first sample $z \sim p_\theta(z \mid x)$, then compute the parameters of $p_\theta(y \mid x, z)$, and finally sample $y \sim p_\theta(y \mid x, z)$.
The basic idea here is quite natural: we wish to represent a potentially complex distribution, and we know that neural networks are universal function approximators, so we can use the neural network to map an arbitrary, sufficiently rich, noise variable to a sample from the distribution we wish to represent.
The next question is: how do we train a latent-variable model from data? We’d like to optimize $p_\theta(y \mid x)$ with respect to the negative log-likelihood loss. However, just evaluating $p_\theta(y \mid x)$ in this case requires us to integrate over $z$, which in the case where $z$ is a vector of Bernoulli variables means summing over combinatorially many values.
We want $z$ to be large enough to capture fairly general dependence among the elements of $y$, in which case explicit integration will usually be intractable.
The price we have paid for the generality of using a neural network to model the complexity of the distribution is that we can no longer straightforwardly compute or optimize the log-likelihood of the data under our model.
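The following sketch illustrates both halves of this trade-off on a toy scale. The prior means and the `decoder_means` rule are invented stand-ins for the two networks at a fixed input $x$. The marginal over $y$ is computed by brute-force enumeration of every latent value, which takes $2^d$ terms for a latent of length $d$, and the resulting marginal has correlated pixels even though both the prior and the decoder are factored.

```python
from itertools import product

def bernoulli_vec_prob(v, means):
    """Probability of binary vector v under independent Bernoulli means."""
    p = 1.0
    for bit, m in zip(v, means):
        p *= m if bit == 1 else 1.0 - m
    return p

prior_means = [0.5, 0.3]  # hypothetical p(z[k] = 1 | x) for one fixed x

def decoder_means(z):
    """Hypothetical stand-in for the decoder network at that fixed x."""
    if z[0] == 1:
        return [0.9, 0.9, 0.1]
    return [0.1, 0.1, 0.9] if z[1] == 1 else [0.5, 0.5, 0.5]

def marginal_prob(y):
    """p(y | x) = sum over z of p(z | x) p(y | x, z): 2**d terms."""
    return sum(
        bernoulli_vec_prob(z, prior_means) * bernoulli_vec_prob(y, decoder_means(z))
        for z in product([0, 1], repeat=len(prior_means))
    )

marginal = {y: marginal_prob(y) for y in product([0, 1], repeat=3)}
n_terms = 2 ** len(prior_means)

# The first two pixels are correlated under the marginal.
p_first = sum(p for y, p in marginal.items() if y[0] == 1)
p_second = sum(p for y, p in marginal.items() if y[1] == 1)
p_both = sum(p for y, p in marginal.items() if y[0] == 1 and y[1] == 1)
print(n_terms, p_first * p_second, p_both)
```

With two latent bits the sum has four terms. With a few hundred it would be hopeless, which is the intractability referred to above.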
This is where the idea of variational inference comes in.
Variational inference
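Computing $p_\theta(y \mid x)$ requires an intractable integration over $z$, but computing $p_\theta(y, z \mid x) = p_\theta(z \mid x)\, p_\theta(y \mid x, z)$ is fairly easy; it’s just a forward pass through a neural network.
If we imagine the ground truth distribution was also generated by first sampling $z$, and we could observe the $z$ associated with each sample, we could minimize an empirical estimate of $\mathbb{E}_x\left[\mathrm{KL}\big(p(y, z \mid x) \,\|\, p_\theta(y, z \mid x)\big)\right]$, equivalent to minimizing the following loss:
$$
\begin{aligned}
\frac{1}{N}\sum_{i=1}^{N} \big[\log p(y_i, z_i \mid x_i) - \log p_\theta(y_i, z_i \mid x_i)\big]
&= \frac{1}{N}\sum_{i=1}^{N} \big[\log p(y_i, z_i \mid x_i) - \log p_\theta(z_i \mid x_i) - \log p_\theta(y_i \mid x_i, z_i)\big] \\
&\doteq -\frac{1}{N}\sum_{i=1}^{N} \big[\log p_\theta(z_i \mid x_i) + \log p_\theta(y_i \mid x_i, z_i)\big],
\end{aligned}
$$
where $z_i$ is the observable latent variable associated with each sample and, in the final line, I’ve dropped a term which does not depend on $\theta$ (I write $\doteq$ for equality up to such a term).
This is a stronger requirement than only matching the distributions with respect to $y$.
If we manage to make $p_\theta(y, z \mid x) = p(y, z \mid x)$, this also guarantees $p_\theta(y \mid x) = p(y \mid x)$, since if the joint distributions over $(y, z)$ are the same then the marginal distributions over $y$ must also be the same.
Unfortunately, there is no reason to believe the true distribution $p(y \mid x)$ is actually generated by first sampling a latent variable $z$, and even if it was, this latent variable is certainly not observable as required by the above objective.
Variational inference works around this issue by augmenting the true distribution $p(y \mid x)$ with an approximate posterior distribution over $z$, $q_\phi(z \mid x, y)$.
Intuitively speaking, we ask: if the actual data was generated by first sampling a latent variable $z$, just like in our model, what is the distribution of $z$ values that could have generated the observed $y$ given the observed $x$?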
The unknown original generative process from which the data arose, augmented with an approximate posterior. With the addition of the approximate posterior, this process now generates samples in the joint space of y and z, just like our generative model. Unlike the probability of y alone, we can directly evaluate the joint probability of y and z under our generative model, so we can consider performing maximum likelihood in the joint space.
I see this as the essential idea behind variational inference so it’s worth reiterating.
We have a generative model for some variable $y$ which generates samples using an intermediate latent variable $z$, and we’d like to train this model to generate samples that match the distribution of the data.
If the data also included a latent variable $z$, we could essentially do this by maximum likelihood of the joint distribution over $y$ and $z$.
However, the data does not include such a latent variable, so instead we augment the data distribution with a conditional distribution $q_\phi(z \mid x, y)$ over $z$.
Together with the true distribution $p(y \mid x)$ from which the data was sampled, this leads to a joint distribution $q_\phi(z \mid x, y)\, p(y \mid x)$ which we can now use as a target distribution for our model.
The part of this story that takes a bit of effort to grok, at least for me, is that it doesn’t matter that the true distribution was not generated by first sampling $z$: we can still optimize the joint KL divergence between the true distribution, augmented with an approximate posterior, and our model distribution.
Next, I will explain how this is done more precisely.
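Often the approximate posterior $q_\phi(z \mid x, y)$ is parameterized as a neural network which outputs a distribution with the same form as $p_\theta(z \mid x)$.
If $z$ is a vector of independent Bernoulli variables, $q_\phi(z \mid x, y)$ could be parameterized as an equal-length vector of Bernoulli variables with means output by a neural network, which now takes both $x$ and $y$ as input.
We can now consider minimizing an empirical approximation to $\mathbb{E}_x\left[\mathrm{KL}\big(q_\phi(z \mid x, y)\, p(y \mid x) \,\|\, p_\theta(y, z \mid x)\big)\right]$. Let’s start by expanding as follows:
$$
\begin{aligned}
\mathbb{E}_x\left[\mathrm{KL}\big(q_\phi(z \mid x, y)\, p(y \mid x) \,\|\, p_\theta(y, z \mid x)\big)\right]
&= \mathbb{E}_{x,\, y \sim p(y \mid x)}\left[\mathbb{E}_{z \sim q_\phi(z \mid x, y)}\left[\log q_\phi(z \mid x, y) + \log p(y \mid x) - \log p_\theta(y, z \mid x)\right]\right] \\
&\doteq \mathbb{E}_{x,\, y \sim p(y \mid x)}\left[\mathbb{E}_{z \sim q_\phi(z \mid x, y)}\left[\log q_\phi(z \mid x, y) - \log p_\theta(y, z \mid x)\right]\right] \\
&= \mathbb{E}_{x,\, y \sim p(y \mid x)}\left[\mathbb{E}_{z \sim q_\phi(z \mid x, y)}\left[\log q_\phi(z \mid x, y) - \log p_\theta(z \mid x) - \log p_\theta(y \mid x, z)\right]\right].
\end{aligned}
$$
Note that, in the second last line, I have dropped the term $\log p(y \mid x)$, which does not depend on $\theta$ or $\phi$ and thus is not relevant to their optimization (I write $\doteq$ for equality up to such a term).
Importantly, we optimize this objective with respect to all the parameters of $p_\theta(z \mid x)$, $p_\theta(y \mid x, z)$, and $q_\phi(z \mid x, y)$, not just the parameters of $p_\theta$ as we did in the loss with observed latent variables.
The final expression in this expansion is an expectation over the data distribution of another expectation involving some functions we can evaluate in closed form.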
It’s starting to look like something we can actually optimize, but I’ll defer the question of how precisely to do it until a little later.
First, let’s think about whether this is a reasonable thing to do.
Is this a reasonable thing to do?
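Even though we are using a learned distribution $q_\phi(z \mid x, y)$ instead of some ground truth distribution over $z$, we can still say that if we managed to reduce $\mathrm{KL}\big(q_\phi(z \mid x, y)\, p(y \mid x) \,\|\, p_\theta(y, z \mid x)\big)$ to zero for all $x$, this will suffice to ensure that $\mathrm{KL}\big(p(y \mid x) \,\|\, p_\theta(y \mid x)\big)$ is zero as well, and thus our model matches the data distribution.
In fact, we can actually make a stronger statement than this.
In particular, KL divergence is additive in the following sense: for any two joint distributions $p(u, v)$ and $q(u, v)$,
$$\mathrm{KL}\big(p(u, v) \,\|\, q(u, v)\big) = \mathrm{KL}\big(p(u) \,\|\, q(u)\big) + \mathbb{E}_{u \sim p(u)}\left[\mathrm{KL}\big(p(v \mid u) \,\|\, q(v \mid u)\big)\right].$$
In our situation, this gives
$$
\begin{aligned}
\mathrm{KL}\big(q_\phi(z \mid x, y)\, p(y \mid x) \,\|\, p_\theta(y, z \mid x)\big)
&= \mathrm{KL}\big(p(y \mid x) \,\|\, p_\theta(y \mid x)\big) + \mathbb{E}_{y \sim p(y \mid x)}\left[\mathrm{KL}\big(q_\phi(z \mid x, y) \,\|\, p_\theta(z \mid x, y)\big)\right] \\
&\geq \mathrm{KL}\big(p(y \mid x) \,\|\, p_\theta(y \mid x)\big).
\end{aligned}
$$
Thus, by minimizing the left-hand side we are minimizing an upper bound on the right-hand side.
In other words, however low we can make the left-hand side, we can guarantee the right-hand side is at least as low.
For this reason the (negation of the) inner expectation in the final expression of the expansion above, which is a lower bound on the log evidence $\log p_\theta(y \mid x)$, is often called the Evidence Lower Bound (ELBO).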
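The additivity property is easy to check numerically. The sketch below uses a single binary $y$ and a single binary $z$ at a fixed $x$, with made-up probability tables for the data distribution, approximate posterior, prior, and decoder. It confirms that the joint KL equals the marginal KL plus the expected KL between the approximate posterior and the model's exact posterior.

```python
import math

# All tables are hypothetical and correspond to one fixed x.
p_data = {0: 0.3, 1: 0.7}                               # true p(y | x)
q_post = {0: {0: 0.8, 1: 0.2}, 1: {0: 0.4, 1: 0.6}}     # q(z | x, y), indexed [y][z]
prior = {0: 0.5, 1: 0.5}                                # model p(z | x)
decoder = {0: {0: 0.6, 1: 0.4}, 1: {0: 0.1, 1: 0.9}}    # model p(y | x, z), indexed [z][y]

def kl(p, q):
    """KL divergence between two distributions given as dicts over the same keys."""
    return sum(p[k] * math.log(p[k] / q[k]) for k in p if p[k] > 0)

pairs = [(y, z) for y in (0, 1) for z in (0, 1)]
target_joint = {(y, z): p_data[y] * q_post[y][z] for y, z in pairs}
model_joint = {(y, z): prior[z] * decoder[z][y] for y, z in pairs}

# Model marginal over y and exact model posterior over z (Bayes' theorem).
model_marginal = {y: sum(model_joint[(y, z)] for z in (0, 1)) for y in (0, 1)}
model_post = {y: {z: model_joint[(y, z)] / model_marginal[y] for z in (0, 1)}
              for y in (0, 1)}

joint_kl = kl(target_joint, model_joint)
marginal_kl = kl(p_data, model_marginal)
posterior_gap = sum(p_data[y] * kl(q_post[y], model_post[y]) for y in (0, 1))
print(joint_kl, marginal_kl + posterior_gap)
```

The gap between the two sides of the bound is exactly `posterior_gap`, so the bound is tight only when the approximate posterior matches the model's true posterior.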
But why should we believe it’s even possible to achieve a joint KL divergence close to zero given our restrictive choice of approximate posterior? A particular choice of $p_\theta(z \mid x)$ and $p_\theta(y \mid x, z)$ will correspond to a certain posterior over $z$ which we can work out using Bayes’ theorem:
$$p_\theta(z \mid x, y) = \frac{p_\theta(y \mid x, z)\, p_\theta(z \mid x)}{p_\theta(y \mid x)}.$$
To push the KL divergence to zero we would require $q_\phi(z \mid x, y) = p_\theta(z \mid x, y)$.
Unfortunately, this posterior need not have the factored form we imposed on $q_\phi(z \mid x, y)$, i.e. there may be correlation among the different elements of $z$.
The situation is improved a bit by noting that there need not be one unique $p_\theta(y \mid x, z)$ which minimizes $\mathrm{KL}\big(p(y \mid x) \,\|\, p_\theta(y \mid x)\big)$; there may be many ways to map the latent variable $z$ to $y$ such that our sample model captures the true distribution.
If there exists some mapping that captures $p(y \mid x)$ when we marginalize out $z$ while also leading to a factored posterior $p_\theta(z \mid x, y)$, then that solution will be preferred in terms of the joint KL objective.
But do such solutions even exist?
In general, not necessarily, but for sufficiently expressive neural networks, along with a sufficiently rich $z$, it will be possible to make the KL arbitrarily small.
This is essentially analogous to a universal approximation theorem for VAEs, which I will next demonstrate holds, at least for the case of Bernoulli vectors.
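If we consider the case where all relevant distributions are Bernoullis, then $y$ can take finitely many possible values. In particular, if $y$ is a binary vector of length $n$, then $y$ can take $M = 2^n$ distinct values.
Let’s index these possible values as $y^{(1)}, \dots, y^{(M)}$.
For realistic data, it is likely that the vast majority of these values will have near zero probability under the true distribution, so the effective $M$ could be much smaller.
Now consider the extreme case where $z$ is a vector of $M$ independent Bernoulli variables.
We can then associate each possible value of $y$ with an element of $z$ by choosing
$$p_\theta(y = y^{(m)} \mid x, z) = \mathbb{1}\big[z[m] = 1 \text{ and } z[k] = 0 \text{ for all } k < m\big],$$
where $\mathbb{1}$ is the indicator function. In words, we deterministically map each $z$ to the $y^{(m)}$ corresponding to the index $m$ of the first 1 in $z$. We can then set $p_\theta(z \mid x)$ such that
$$p_\theta(z[m] = 1 \mid x) = \frac{p(y^{(m)} \mid x)}{1 - \sum_{k < m} p(y^{(k)} \mid x)},$$
which ensures that $p_\theta(y^{(m)} \mid x) = p(y^{(m)} \mid x)$ for all $m$. To see this, note that the probability that the first 1 appears at index $m$ is $p_\theta(z[m] = 1 \mid x) \prod_{k < m} \big(1 - p_\theta(z[k] = 1 \mid x)\big)$, and the product telescopes to $1 - \sum_{k < m} p(y^{(k)} \mid x)$. The final element $z[M]$ is 1 with probability one, so a first 1 always exists.
Furthermore, the true posterior will factor as
$$p_\theta(z \mid x, y^{(m)}) = \left(\prod_{k < m} \mathbb{1}\big[z[k] = 0\big]\right) \mathbb{1}\big[z[m] = 1\big] \left(\prod_{k > m} p_\theta(z[k] \mid x)\right),$$
and thus can be represented by a factored approximate posterior $q_\phi(z \mid x, y)$.
Intuitively, this just says that if we observe a particular $y^{(m)}$ generated by the above process, we know that the first $m - 1$ elements of $z$ must have been zero, and the $m$th element must have been one, but we have no information about the remaining elements of $z$ beyond the prior.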
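Here is a small numerical check of a first-one construction of this kind. The specific choice of prior means in `stick_breaking_prior` (each mean is the target probability divided by the probability mass not yet used up) is my assumption about one way to make the construction work, and the four-outcome target distribution is made up. The check enumerates every latent vector, maps it to the index of its first 1, and confirms that the induced distribution over outcomes matches the target.

```python
from itertools import product

def stick_breaking_prior(target):
    """Prior means: target[m] divided by the mass remaining before index m.
    Clamped to 1.0 to guard against floating-point round-off."""
    means, remaining = [], 1.0
    for p in target:
        means.append(min(1.0, p / remaining) if remaining > 1e-12 else 1.0)
        remaining -= p
    return means

def first_one(z):
    """Index of the first 1 in z, or None if z is all zeros."""
    for m, bit in enumerate(z):
        if bit == 1:
            return m
    return None  # (near-)zero prior probability, since the last mean is 1

def induced_distribution(target):
    """Distribution over outcomes obtained by sampling z and taking first_one."""
    means = stick_breaking_prior(target)
    out = [0.0] * len(target)
    for z in product([0, 1], repeat=len(target)):
        pz = 1.0
        for bit, m in zip(z, means):
            pz *= m if bit == 1 else 1.0 - m
        idx = first_one(z)
        if idx is not None:
            out[idx] += pz
    return out

target = [0.1, 0.2, 0.3, 0.4]  # hypothetical true distribution over M = 4 outcomes
prior_means = stick_breaking_prior(target)
induced = induced_distribution(target)
print(prior_means, induced)
```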
Note that this construction requires the prior $p_\theta(z \mid x)$ to be parameterized and conditioned on $x$. It’s an interesting question whether an analogous construction is possible with a fixed prior $p(z)$.
The above construction is not a very efficient way to utilize the elements of our latent variable, requiring a separate element for each possible outcome $y^{(m)}$. There is no doubt much more that could be said about the representational capabilities of latent-variable models trained with variational inference, which is beyond the scope of this post, and I am not currently very knowledgeable about the relevant literature.
It’s also worth noting that the existence of such solutions does not mean we can necessarily find them by stochastic gradient descent.
Nevertheless, the construction presented here is a nice sanity check to show that the variational inference approach I’ve described is capable of modelling arbitrary distributions in principle, at least in the Bernoulli case.
I encourage the interested reader to think about how one could utilize the latent variable more efficiently while still maintaining the factored posterior, as well as how one could define similar constructions for other types of latent variable such as vectors of independent Gaussians.
Optimizing the ELBO
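Now let’s consider the question of how we can optimize the joint KL objective derived above.
We will use stochastic samples of $(x, y)$ observed from the true unknown distribution, which takes care of the outer expectation. However, we still need to deal with the inner expectation.
We cannot simply sample $z \sim q_\phi(z \mid x, y)$ and optimize the inside of the expectation with respect to the samples, as the sampling distribution itself depends on $\phi$.
However, as long as we can evaluate, sample from, and differentiate $q_\phi$, which we guarantee by design, we can use the following general identity to obtain an unbiased gradient estimate:
$$\nabla_\phi\, \mathbb{E}_{z \sim q_\phi(z)}\left[h_\phi(z)\right] = \mathbb{E}_{z \sim q_\phi(z)}\left[h_\phi(z)\, \nabla_\phi \log q_\phi(z) + \nabla_\phi h_\phi(z)\right].$$
In our case, we would take $q_\phi(z) = q_\phi(z \mid x, y)$ and $h_\phi(z) = \log q_\phi(z \mid x, y) - \log p_\theta(z \mid x) - \log p_\theta(y \mid x, z)$.
We can then derive an unbiased estimator of the gradient by sampling $z$ from $q_\phi(z \mid x, y)$ and using the inside of the expectation on the right-hand side, evaluated at that specific sample, as the gradient estimate.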
This estimator is nice in that it’s generic and unbiased, however, it can have high variance to the point that it’s usually not very practical.
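To illustrate the generic estimator and its variance, here is a minimal sketch for a single Bernoulli latent whose probability of being one is a sigmoid of a scalar parameter. The integrand `h` is a made-up function that, for simplicity, does not itself depend on the parameter, so only the score term of the identity is needed. The exact gradient is available in closed form for comparison.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def h(z):
    """Hypothetical integrand; any function of z we can evaluate would do."""
    return 3.0 if z == 1 else 1.0

def exact_grad(phi):
    """d/dphi of E[h(z)] = s*h(1) + (1-s)*h(0), with s = sigmoid(phi)."""
    s = sigmoid(phi)
    return s * (1.0 - s) * (h(1) - h(0))

def score(z, phi):
    """d/dphi of log q(z) for a Bernoulli with P(z = 1) = sigmoid(phi)."""
    s = sigmoid(phi)
    return (1.0 - s) if z == 1 else -s

def score_function_estimate(phi, n_samples, rng):
    """Monte Carlo average of h(z) * score(z) over samples z ~ q."""
    total = 0.0
    for _ in range(n_samples):
        z = 1 if rng.random() < sigmoid(phi) else 0
        total += h(z) * score(z, phi)
    return total / n_samples

phi = 0.3
exact = exact_grad(phi)
# The estimator's expectation, computed by enumerating both values of z.
s = sigmoid(phi)
expected_of_estimator = s * h(1) * score(1, phi) + (1.0 - s) * h(0) * score(0, phi)
estimate = score_function_estimate(phi, 20000, random.Random(0))
print(exact, expected_of_estimator, estimate)
```

Even in this one-dimensional case, individual samples of the estimator have opposite signs depending on the sampled value, which hints at why the variance becomes a problem for high-dimensional latents.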
Luckily, in many situations, it is possible to derive lower-variance unbiased gradient estimates, such as the reparameterization trick (Kingma & Welling, 2014; Rezende et al., 2014) for continuous latent variables. There are plenty of resources on the reparameterization trick, so I will not cover it here. Biased gradient estimators such as the straight-through estimator (Bengio et al., 2013) can also be used for discrete latent variables and often perform well in practice.
Latent dynamics models for reinforcement learning
I have described how to train a basic (conditional) latent-variable model using variational inference. One could directly apply the described approach to learn a world model for reinforcement learning simply by replacing $x$ with the state-action pair $(s, a)$ and $y$ with the following state $s'$, and having the latent variable $z$ represent the noise in the transition dynamics. Often, however, a slightly different setup is used. In particular, rather than using the latent variable only to parameterize transition noise, one can learn a mapping $q(z \mid s)$ and then model the dynamics themselves in latent space as $p(z' \mid z, a)$. An additional learned distribution $p(s' \mid z')$ aims to reconstruct the distribution of next states (see e.g. Watter et al. (2015), Ha and Schmidhuber (2018)). The overall transition distribution is then modelled as
$$p(s' \mid s, a) = \sum_{z} \sum_{z'} q(z \mid s)\, p(z' \mid z, a)\, p(s' \mid z').$$
In this case, we are effectively trying to transform states into a new space in which the dynamics obey the factored structure imposed by $p(z' \mid z, a)$. This approach has a number of benefits, including being able to roll out the model for multiple steps in latent space without explicitly predicting the environment state in each step. This is essentially the approach of Dreamer (Hafner et al., 2020; Hafner et al., 2021; Hafner et al., 2023), which has demonstrated impressive results for model-based reinforcement learning.
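The idea of rolling out in latent space and decoding only at the end can be sketched with tiny discrete tables. Everything here is hypothetical: two environment states, a binary latent, two actions, and hand-picked encoder, latent transition, and decoder probabilities. It is meant only to show the composition of the three pieces, not how any particular published system is implemented.

```python
# Hypothetical tables for a two-state environment with a binary latent.
encoder = {"s0": [0.9, 0.1], "s1": [0.2, 0.8]}            # q(z | s)
transition = {                                            # p(z' | z, a), rows indexed by z
    "stay": [[1.0, 0.0], [0.0, 1.0]],
    "flip": [[0.1, 0.9], [0.9, 0.1]],
}
decoder = [{"s0": 0.95, "s1": 0.05}, {"s0": 0.05, "s1": 0.95}]  # p(s' | z')

def latent_step(belief, action):
    """Push a distribution over z through the latent dynamics for one action."""
    rows = transition[action]
    return [sum(belief[z] * rows[z][z2] for z in range(len(belief)))
            for z2 in range(len(rows[0]))]

def predict_state(state, actions):
    """Encode once, roll out entirely in latent space, decode only at the end."""
    belief = encoder[state]
    for a in actions:
        belief = latent_step(belief, a)
    return {s: sum(belief[z] * decoder[z][s] for z in range(len(belief)))
            for s in decoder[0]}

one_step = predict_state("s0", ["flip"])
three_step = predict_state("s0", ["flip", "stay", "flip"])
print(one_step, three_step)
```

In the multi-step call, the environment state is never predicted at the intermediate steps, which is the benefit described above.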
Closing thoughts
The application of variational inference to learn latent-variable models is a rich and active area of research. I have only scratched the surface of the topic here, but the principles presented should apply quite generally. For example, one may wish to use models which include different latent variables at various points in the generative process to capture more complex dynamics and relationships. In such cases, one can still think in terms of minimizing KL divergence between our generative model and an augmented version of the data distribution which includes approximations to the distribution of each latent variable which is used by our generative model but does not appear in the data. I have found this way of thinking to be a very useful lens for understanding what is happening when seeing a particular variational approach for the first time.
There are also a number of alternative techniques for learning probability distributions which strike different trade-offs in terms of ease of sampling, evaluating probabilities, and learning, many of which I have not discussed at all. Some examples include energy-based models, diffusion models, and flow-based models.
My aim in this post was to introduce the concept of variational inference for learning latent-variable models in a way that makes it seem more natural than the way in which I have usually seen it presented. I also emphasized some of the nuances, like the question of whether it’s even possible to make the bound arbitrarily tight, which I personally consider interesting and important but which are often glossed over. I hope this overview can help others to more rapidly pick up some of the intuitions about variational inference that I have found useful, but which have taken me a long time to develop.
References
Bengio, Y., Léonard, N., & Courville, A. (2013). Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432.
Ha, D., & Schmidhuber, J. (2018). World models. Advances in Neural Information Processing Systems.
Hafner, D., Lillicrap, T., Ba, J., & Norouzi, M. (2020). Dream to control: Learning behaviors by latent imagination. International Conference on Learning Representations.
Hafner, D., Lillicrap, T. P., Norouzi, M., & Ba, J. (2021). Mastering Atari with discrete world models. International Conference on Learning Representations.
Hafner, D., Pasukonis, J., Ba, J., & Lillicrap, T. (2023). Mastering diverse domains through world models. arXiv preprint arXiv:2301.04104.
Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. International Conference on Learning Representations.
Lin, C.-C., Jaech, A., Li, X., Gormley, M. R., & Eisner, J. (2020). Limitations of autoregressive models and their alternatives. arXiv preprint arXiv:2010.11939.
Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. International Conference on Machine Learning.
Ross, S., & Bagnell, J. A. (2012). Agnostic system identification for model-based reinforcement learning. International Conference on Machine Learning.
Sohn, K., Yan, X., & Lee, H. (2015). Learning structured output representation using deep conditional generative models. Advances in neural information processing systems.
Sutton, R. S., & Barto, A. G. (2020). Reinforcement learning: An introduction (second edition). MIT Press.
Watter, M., Springenberg, J., Boedecker, J., & Riedmiller, M. (2015). Embed to control: A locally linear latent dynamics model for control from raw images. Advances in neural information processing systems.