KLMinNaturalGradDescent
Description
This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running natural gradient descent. KLMinNaturalGradDescent is a specific implementation of natural gradient variational inference (NGVI), also known as variational online Newton[KR2023]. For nearly-Gaussian targets, NGVI tends to converge very quickly. If the ensure_posdef option is set to true (the default), the update rule of [LSK2020] is used, which guarantees that the updated precision matrix is always positive definite. Since KLMinNaturalGradDescent behaves as a measure-space algorithm, its use is restricted to variational families for which the updates are tractable, namely full-rank Gaussian variational families (FullRankGaussian).
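The fast convergence on nearly-Gaussian targets is easiest to see when the target is exactly Gaussian. The sketch below is a minimal 1D version of the Gaussian natural gradient update in its variational online Newton form[KR2023]. It assumes exact expectations rather than Monte Carlo estimates. The function `von_step_exact` and its arguments are hypothetical and are not part of AdvancedVI, whose step-size convention may also differ.

```python
def von_step_exact(mu, prec, m, w, stepsize):
    """One natural gradient step for q = N(mu, 1/prec) against the Gaussian
    target pi(theta) proportional to exp(-(theta - m)^2 / (2 w))."""
    # For this target both expectations under q are available in closed form:
    # E_q[d/dtheta -log pi] = (mu - m) / w and E_q[d^2/dtheta^2 -log pi] = 1 / w.
    exp_grad = (mu - m) / w
    exp_hess = 1.0 / w
    # Precision update: convex combination of the old precision and the expected Hessian.
    new_prec = (1.0 - stepsize) * prec + stepsize * exp_hess
    # Mean update: gradient step preconditioned by the *new* precision.
    new_mu = mu - stepsize * exp_grad / new_prec
    return new_mu, new_prec

# With stepsize 1, a single step lands exactly on the target N(m, w).
mu, prec = von_step_exact(mu=-3.0, prec=0.1, m=2.0, w=0.25, stepsize=1.0)
print(mu, prec)  # 2.0 4.0
```

With a step size of 1, the update reduces to a Newton step, and a Newton step is exact for a quadratic negative log density.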
AdvancedVI.KLMinNaturalGradDescent — Type
KLMinNaturalGradDescent(stepsize, n_samples, ensure_posdef, subsampling)
KLMinNaturalGradDescent(; stepsize, n_samples, ensure_posdef, subsampling)
KL divergence minimization by running natural gradient descent[KL2017][KR2023], also called variational online Newton. This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence.
If the ensure_posdef argument is true, the algorithm applies the technique of Lin et al.[LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness. This, however, requires an additional set of matrix-matrix linear system solves, which can be costly.
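In one dimension the matrix products become scalars, which makes the effect of the extra term easy to see. The sketch below assumes the [LSK2020] precision update has the form $S + \gamma G + \frac{\gamma^2}{2} G S^{-1} G$ with $G = \hat{H} - S$, where $\hat{H}$ is the Hessian estimate. The function names are hypothetical. In the matrix case, the factor $S^{-1} G$ is where the additional linear system solves come from.

```python
def plain_prec_update(prec, hess_est, stepsize):
    # Equivalent to (1 - stepsize) * prec + stepsize * hess_est.
    # This can become negative when the Hessian estimate is negative.
    return prec + stepsize * (hess_est - prec)

def posdef_prec_update(prec, hess_est, stepsize):
    # Adds the correction (stepsize^2 / 2) * g * prec^{-1} * g.
    # In 1D the result is prec * (1 + x + x^2 / 2) with x = stepsize * g / prec,
    # and 1 + x + x^2 / 2 = ((1 + x)^2 + 1) / 2 > 0 for every real x.
    g = hess_est - prec
    return prec + stepsize * g + 0.5 * stepsize**2 * g * g / prec

print(plain_prec_update(1.0, -3.0, 0.5))   # -1.0 (not a valid precision)
print(posdef_prec_update(1.0, -3.0, 0.5))  # 1.0
```

Negative Hessian estimates are common for targets that are not log-concave. Such an estimate drives the plain update to an invalid negative precision, while the corrected update stays positive for any step size.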
This algorithm requires second-order information about the target. If the target LogDensityProblem has second-order differentiation capability, Hessians are used. Otherwise, if the target has only first-order capability, only gradients are used, which will probably result in slower convergence and less robust behavior.
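The expected Hessian can still be estimated when only gradients are available. This section does not state which estimator AdvancedVI uses, so the following is an assumption for illustration. One standard route is Stein's lemma: for $\theta \sim \mathcal{N}(\mu, \sigma^2)$, $\mathbb{E}[f''(\theta)] = \mathbb{E}[(\theta - \mu) f'(\theta)] / \sigma^2$. The 1D sketch compares this estimator with the direct Hessian estimator on a hypothetical quartic target.

```python
import random
import statistics

# Hypothetical non-Gaussian target: -log pi(theta) = theta^4 / 4 + theta^2 / 2.
def neg_logpi_grad(theta):
    return theta**3 + theta

def neg_logpi_hess(theta):
    return 3.0 * theta**2 + 1.0

def hessian_estimates(mu, sigma, n_samples, rng):
    """Per-sample estimates of E_q[d^2/dtheta^2 -log pi] for q = N(mu, sigma^2)."""
    second_order, first_order = [], []
    for _ in range(n_samples):
        theta = rng.gauss(mu, sigma)
        # Second-order capability: evaluate the Hessian directly.
        second_order.append(neg_logpi_hess(theta))
        # First-order only (Stein's lemma): E[f''] = E[(theta - mu) f'(theta)] / sigma^2.
        first_order.append((theta - mu) * neg_logpi_grad(theta) / sigma**2)
    return second_order, first_order

rng = random.Random(0)
so, fo = hessian_estimates(mu=0.5, sigma=0.8, n_samples=100_000, rng=rng)
exact = 3.0 * (0.5**2 + 0.8**2) + 1.0  # E[3 theta^2 + 1] = 3 (mu^2 + sigma^2) + 1
print("exact:", exact)
print("Hessian-based:  mean", statistics.fmean(so), "variance", statistics.pvariance(so))
print("gradient-based: mean", statistics.fmean(fo), "variance", statistics.pvariance(fo))
```

Both sample means approach the exact value. For this target, however, the gradient-only estimator has a much larger per-sample variance. That is one plausible source of the slower and less robust behavior described above.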
(Keyword) Arguments
- stepsize::Float64: Step size.
- n_samples::Int: Number of samples used to estimate the natural gradient. (default: 1)
- ensure_posdef::Bool: Ensure that the updated precision preserves positive definiteness. (default: true)
- subsampling::Union{Nothing,<:AbstractSubsampling}: Optional subsampling strategy.
The subsampling strategy is applied only to the target LogDensityProblem, not to the variational approximation q. That is, KLMinNaturalGradDescent does not support amortization or structured variational families.
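Subsampling typically relies on a target whose log density is a sum over $N$ data points. The strategies available through AbstractSubsampling are not described here, so the following form is an assumption for illustration: a minibatch of size $B$ is summed and rescaled by $N/B$. All names below are hypothetical.

```python
import random

def loglik_term(theta, x):
    # Hypothetical per-datapoint log-likelihood (unit-variance Gaussian).
    return -0.5 * (x - theta) ** 2

def full_loglik(theta, data):
    return sum(loglik_term(theta, x) for x in data)

def subsampled_loglik(theta, data, batch_idx):
    # Minibatch sum rescaled by N / B, so its expectation over a uniformly
    # random batch equals the full-data sum.
    scale = len(data) / len(batch_idx)
    return scale * sum(loglik_term(theta, data[i]) for i in batch_idx)

def epoch_batches(n, batch_size, rng):
    # Reshuffle the indices once and split them into disjoint minibatches.
    idx = list(range(n))
    rng.shuffle(idx)
    return [idx[k:k + batch_size] for k in range(0, n, batch_size)]

rng = random.Random(1)
data = [rng.gauss(0.0, 1.0) for _ in range(100)]
batches = epoch_batches(len(data), 10, rng)
# Averaging the estimator over one epoch of equal-sized batches recovers
# the full-data value, up to floating-point rounding.
avg = sum(subsampled_loglik(0.3, data, b) for b in batches) / len(batches)
print(avg, full_loglik(0.3, data))
```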
Output
q: The last iterate of the algorithm.
Callback Signature
The callback function supplied to optimize needs to have the following signature:
callback(; rng, iteration, q, info)
The keyword arguments are as follows:
- rng: Random number generator internally used by the algorithm.
- iteration: The index of the current iteration.
- q: Current variational approximation.
- info: NamedTuple containing the information generated during the current iteration.
Requirements
- The variational family is FullRankGaussian.
- The target distribution has unconstrained support.
- The target LogDensityProblems.logdensity(prob, x) has at least first-order differentiation capability.
The associated objective can be estimated through the following:
AdvancedVI.estimate_objective — Method
estimate_objective([rng,] alg, q, prob; kwargs...)
Estimate the variational objective to be minimized by the algorithm alg for approximating the target prob with the variational approximation q.
Arguments
- rng::Random.AbstractRNG: Random number generator.
- alg::AbstractVariationalAlgorithm: Variational inference algorithm.
- prob: The target log-joint likelihood implementing the LogDensityProblem interface.
- q: Variational approximation.
Keyword Arguments
Depending on the algorithm, additional keyword arguments may apply. Please refer to the respective documentation of each algorithm for more info.
Returns
obj_est: Estimate of the objective value.
Methodology
This algorithm aims to solve the problem
\[ \mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)\]
where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running natural gradient descent over the parameters of the variational approximation. That is, we assume that for every $q_{\lambda} \in \mathcal{Q}$ there is a corresponding vector of parameters $\lambda \in \Lambda$, where the space of parameters is Euclidean such that $\Lambda \subset \mathbb{R}^p$.
Since we usually only have access to the unnormalized density of the target distribution $\pi$, we cannot evaluate the KL divergence directly. Instead, following the ELBO maximization strategy, we minimize a surrogate objective, the negative evidence lower bound (ELBO)[JGJS1999]
\[ \mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) - \mathbb{H}\left(q\right),\]
which differs from the KL divergence only by an additive constant, namely the negative log evidence.
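The relation between $\mathcal{L}$ and the KL divergence can be checked numerically in 1D. The sketch below uses hypothetical function names. It estimates $\mathcal{L}(q)$ by Monte Carlo for a Gaussian $q$ against an unnormalized standard normal target, using the closed-form Gaussian entropy. It then compares the estimate with the closed-form KL divergence minus the log evidence $\log Z = \frac{1}{2}\log(2\pi)$.

```python
import math
import random

def neg_elbo_estimate(mu, sigma, neg_log_target, n_samples, rng):
    """Monte Carlo estimate of E_q[-log pi(theta)] - H(q) for q = N(mu, sigma^2)."""
    # Entropy of a 1D Gaussian, in closed form.
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * sigma**2)
    # Sample average of the unnormalized negative log target.
    energy = sum(neg_log_target(rng.gauss(mu, sigma)) for _ in range(n_samples)) / n_samples
    return energy - entropy

def neg_log_target(theta):
    # Unnormalized standard normal: pi(theta) = exp(-theta^2 / 2), so Z = sqrt(2 pi).
    return 0.5 * theta**2

log_Z = 0.5 * math.log(2.0 * math.pi)
rng = random.Random(0)
for mu in (0.0, 1.0):
    est = neg_elbo_estimate(mu, 1.0, neg_log_target, 100_000, rng)
    kl = 0.5 * mu**2  # closed-form KL(N(mu, 1), N(0, 1))
    print(mu, est, kl - log_Z)  # the estimate should be close to KL - log Z
```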
Suppose we had access to the exact gradients $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$. NGVI attempts to minimize $\mathcal{L}$ via natural gradient descent, which corresponds to iterating the mirror descent update
\[\lambda_{t+1} = \operatorname*{arg\,min}_{\lambda \in \Lambda} {\langle \nabla_{\lambda} \mathcal{L}\left(q_{\lambda_t}\right), \lambda - \lambda_t \rangle} + \frac{1}{\gamma_t} \mathrm{KL}\left(q_{\lambda}, q_{\lambda_t}\right) .\]
This turns out to be equivalent to the update
\[\lambda_{t+1} = \lambda_{t} - \gamma_t {F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t}) ,\]
where $F(\lambda_t)$ is the Fisher information matrix of $q_{\lambda}$ evaluated at $\lambda_t$. That is, natural gradient descent can be viewed as gradient descent with an iterate-dependent preconditioner. Furthermore, ${F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t})$ is referred to as the natural gradient of the KL divergence[A1998], hence the name natural gradient variational inference. Note also that the gradient is taken with respect to the parameters of $q_{\lambda}$. Therefore, NGVI is parametrization dependent: for the same variational family, different parametrizations will result in different behavior. However, the pseudo-metric $\mathrm{KL}\left(q_{\lambda}, q_{\lambda_t}\right)$ is defined over measures. Therefore, NGVI tends to behave like a measure-space algorithm, although, technically speaking, it is not a fully measure-space algorithm.
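The parametrization dependence can be checked on a 1D Gaussian $q = \mathcal{N}(\mu, v)$ and a Gaussian target with mean $m$ and variance $w$. In this setting the negative ELBO, its gradient, and the Fisher information are all available in closed form. The sketch below takes one exact natural gradient step in the $(\mu, v)$ parametrization and one in the $(\mu, \log v)$ parametrization. The function names are hypothetical.

```python
import math

# Negative ELBO for q = N(mu, v) and the unnormalized target
# pi(theta) proportional to exp(-(theta - m)^2 / (2 w)), up to an additive constant:
#   L(mu, v) = ((mu - m)^2 + v) / (2 w) - 0.5 * log(v)

def ngd_step_mu_v(mu, v, m, w, gamma):
    # Parametrization (mu, v): Fisher = diag(1 / v, 1 / (2 v^2)).
    d_mu = (mu - m) / w
    d_v = 1.0 / (2.0 * w) - 1.0 / (2.0 * v)
    return mu - gamma * v * d_mu, v - gamma * 2.0 * v**2 * d_v

def ngd_step_mu_logv(mu, v, m, w, gamma):
    # Parametrization (mu, rho = log v): Fisher = diag(1 / v, 1 / 2).
    d_mu = (mu - m) / w
    d_rho = v / (2.0 * w) - 0.5  # chain rule: dL/drho = v * dL/dv
    rho = math.log(v) - gamma * 2.0 * d_rho
    return mu - gamma * v * d_mu, math.exp(rho)

print(ngd_step_mu_v(0.0, 2.0, 1.0, 1.0, 0.5))     # (1.0, 1.0)
print(ngd_step_mu_logv(0.0, 2.0, 1.0, 1.0, 0.5))  # same mean, different variance
```

Both steps share the same mean update and agree to first order in $\gamma$, because $v(1 - \gamma(v/w - 1))$ and $v\exp(-\gamma(v/w - 1))$ have the same linearization. For finite $\gamma$, however, they differ.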
In practice, we only have access to unbiased estimates of $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$. Regardless, natural gradient descent/mirror descent updates involving stochastic estimates have been derived for some variational families, for instance, Gaussian variational families[KR2023] and mixtures of exponential families[LKS2019]. As of now, only the Gaussian version is implemented.
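The stochastic Gaussian update can be sketched in 1D under several assumptions. The expected gradient and Hessian are replaced by Monte Carlo estimates, the update takes the variational online Newton form[KR2023], and the scalar form of the [LSK2020] correction is applied when ensure_posdef is true. All names are hypothetical, and this is not the AdvancedVI implementation. Because the target is symmetric, the optimal Gaussian has $\mu = 0$. Its precision solves $s = \mathbb{E}_q[3\theta^2 + 1] = 3/s + 1$, giving $s = (1 + \sqrt{13})/2 \approx 2.303$, which provides a check on the sketch.

```python
import math
import random

# Hypothetical non-Gaussian target: -log pi(theta) = theta^4 / 4 + theta^2 / 2.
def neg_logpi_grad(theta):
    return theta**3 + theta

def neg_logpi_hess(theta):
    return 3.0 * theta**2 + 1.0

def ngvi_1d(mu, prec, stepsize, n_samples, n_iters, rng, ensure_posdef=True):
    """Stochastic Gaussian NGVI for q = N(mu, 1/prec), using Monte Carlo
    estimates of E_q[grad] and E_q[hess]. Illustrative sketch only."""
    for _ in range(n_iters):
        sigma = 1.0 / math.sqrt(prec)
        thetas = [rng.gauss(mu, sigma) for _ in range(n_samples)]
        g = sum(neg_logpi_grad(t) for t in thetas) / n_samples
        h = sum(neg_logpi_hess(t) for t in thetas) / n_samples
        delta = h - prec
        new_prec = prec + stepsize * delta
        if ensure_posdef:
            # Scalar form of the positivity-preserving correction of [LSK2020].
            new_prec += 0.5 * stepsize**2 * delta * delta / prec
        # Mean update preconditioned by the new precision.
        mu = mu - stepsize * g / new_prec
        prec = new_prec
    return mu, prec

rng = random.Random(0)
mu, prec = ngvi_1d(mu=1.0, prec=1.0, stepsize=0.1, n_samples=50, n_iters=1000, rng=rng)
print(mu, prec, (1.0 + math.sqrt(13.0)) / 2.0)
```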
- [KR2023] Khan, M. E., & Rue, H. (2023). The Bayesian learning rule. Journal of Machine Learning Research, 24(281), 1-46.
- [LSK2020] Lin, W., Schmidt, M., & Khan, M. E. (2020). Handling the positive-definite constraint in the Bayesian learning rule. In International Conference on Machine Learning. PMLR.
- [LKS2019] Lin, W., Khan, M. E., & Schmidt, M. (2019). Fast and simple natural-gradient variational inference with mixture of exponential-family approximations. In International Conference on Machine Learning. PMLR.
- [A1998] Amari, S. I. (1998). Natural gradient works efficiently in learning. Neural Computation, 10(2), 251-276.
- [JGJS1999] Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine Learning, 37, 183-233.