KLMinSqrtNaturalGradDescent

Description

This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running natural gradient descent. KLMinSqrtNaturalGradDescent is a specific implementation of natural gradient variational inference (NGVI), also known as square-root variational Newton[KMKL2025][LDEBTM2024][LDLNKS2023][T2025]. The algorithm operates under the square-root (or Cholesky-factor) parameterization of the covariance matrix. This contrasts with KLMinNaturalGradDescent, which operates in the precision matrix parameterization and therefore requires a matrix inverse at each step. As a result, each iteration of KLMinSqrtNaturalGradDescent should be comparatively cheaper. Since KLMinSqrtNaturalGradDescent is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (FullRankGaussian), which make the updates tractable.

AdvancedVI.KLMinSqrtNaturalGradDescentType
KLMinSqrtNaturalGradDescent(stepsize, n_samples, subsampling)
KLMinSqrtNaturalGradDescent(; stepsize, n_samples, subsampling)

KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemannian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[KMKL2025][LDEBTM2024][LDLNKS2023][T2025].

This algorithm requires second-order information about the target. If the target LogDensityProblem has second-order differentiation capability, Hessians are used. Otherwise, if the target has only first-order capability, the algorithm falls back to using gradients alone, which will probably result in slower convergence and less robust behavior.

(Keyword) Arguments

  • stepsize::Float64: Step size.
  • n_samples::Int: Number of samples used to estimate the natural gradient. (default: 1)
  • subsampling::Union{Nothing,<:AbstractSubsampling}: Optional subsampling strategy.
Note

The subsampling strategy is only applied to the target LogDensityProblem but not to the variational approximation q. That is, KLMinSqrtNaturalGradDescent does not support amortization or structured variational families.

Output

  • q: The last iterate of the algorithm.

Callback Signature

The callback function supplied to optimize needs to have the following signature:

callback(; rng, iteration, q, info)

The keyword arguments are as follows:

  • rng: Random number generator internally used by the algorithm.
  • iteration: The index of the current iteration.
  • q: Current variational approximation.
  • info: NamedTuple containing the information generated during the current iteration.
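For example, a minimal callback that reports progress every 100 iterations could look like the following sketch. The trailing keyword splat ignores any additional keyword arguments, and returning nothing is the conventional no-op (whether a returned value is recorded alongside info depends on optimize; see its documentation).

function callback(; iteration, q, info, kwargs...)
    # Report progress occasionally; `q` and `info` are available here for
    # inspection or logging as needed.
    iteration % 100 == 0 && @info "VI progress" iteration
    return nothing
end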

Requirements

  • The variational family is FullRankGaussian.
  • The target distribution has unconstrained support.
  • The target LogDensityProblems.logdensity(prob, x) has at least first-order differentiation capability.
source
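To make the interface concrete, here is a minimal, hypothetical usage sketch. The toy target ToyGaussian, the dimension, and the hyperparameter values are made up for illustration; implementing second-order capability (logdensity_gradient_and_hessian) in addition would let the algorithm use Hessians as described above, and the exact optimize signature and return values should be checked against its documentation. A callback such as the one sketched above can typically be passed to optimize through a callback keyword argument.

using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

# A toy target: an unnormalized standard Gaussian in d dimensions with an
# analytic gradient, so the LogDensityProblem has first-order capability.
struct ToyGaussian
    d::Int
end

LogDensityProblems.dimension(prob::ToyGaussian) = prob.d
LogDensityProblems.capabilities(::Type{ToyGaussian}) =
    LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity(::ToyGaussian, x) = -sum(abs2, x)/2
LogDensityProblems.logdensity_and_gradient(::ToyGaussian, x) =
    (-sum(abs2, x)/2, -x)

d      = 5
prob   = ToyGaussian(d)
q_init = FullRankGaussian(zeros(d), LowerTriangular(Matrix(1.0I, d, d)))
alg    = KLMinSqrtNaturalGradDescent(; stepsize=1e-2, n_samples=4)

# Assuming the generic `optimize(alg, max_iter, prob, q_init; ...)` entry
# point, which returns the final approximation as its first value.
q, info, _ = AdvancedVI.optimize(alg, 1_000, prob, q_init)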

The associated objective value can be estimated through the following:

AdvancedVI.estimate_objectiveMethod
estimate_objective([rng,] alg, q, prob; kwargs...)

Estimate the variational objective to be minimized by the algorithm alg for approximating the target prob with the variational approximation q.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • alg::AbstractVariationalAlgorithm: Variational inference algorithm.
  • prob: The target log-joint likelihood implementing the LogDensityProblem interface.
  • q: Variational approximation.

Keyword Arguments

Depending on the algorithm, additional keyword arguments may apply. Please refer to the respective documentation of each algorithm for more info.

Returns

  • obj_est: Estimate of the objective value.
source
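Continuing the hypothetical sketch above, the objective (here, the negative ELBO) of the returned approximation could be estimated as follows. The n_samples keyword is an assumption about what this algorithm's estimator accepts; consult the algorithm documentation for the supported keywords.

# Monte Carlo estimate of the objective under the approximation `q`.
obj_est = estimate_objective(alg, q, prob; n_samples=256)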

Methodology

This algorithm aims to solve the problem

\[ \mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)\]

where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running (stochastic) natural gradient descent over the space of parameters. That is, for every $q_{\lambda} \in \mathcal{Q}$, we assume there is a corresponding vector of parameters $\lambda \in \Lambda$, where the parameter space is Euclidean, $\Lambda \subseteq \mathbb{R}^p$.

Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence. Instead, the ELBO maximization strategy minimizes a surrogate objective, the negative evidence lower bound[JGJS1999]

\[ \mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) - \mathbb{H}\left(q\right),\]

which is equivalent to the KL up to an additive constant (the evidence).
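Concretely, writing $\bar{\pi} = Z \pi$ for the unnormalized density that we can actually evaluate, with $Z$ denoting the evidence,

\[ \mathrm{KL}\left(q, \pi\right) = \mathbb{E}_{\theta \sim q}\left[ \log q\left(\theta\right) - \log \pi\left(\theta\right) \right] = \mathbb{E}_{\theta \sim q}\left[ -\log \bar{\pi}\left(\theta\right) \right] - \mathbb{H}\left(q\right) + \log Z ,\]

so the KL divergence and the negative ELBO evaluated with $\bar{\pi}$ differ only by the constant $\log Z$, which does not depend on $q$.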

While KLMinSqrtNaturalGradDescent is a natural gradient variational inference algorithm, it can be derived in a variety of different ways. In fact, the update rule has been concurrently developed by several research groups[KMKL2025][LDEBTM2024][LDLNKS2023][T2025]. Here, we will present the derivation by Kumar et al. [KMKL2025]. Consider the idealized natural gradient descent algorithm discussed in the documentation of KLMinNaturalGradDescent. It can be viewed as a discretization of the continuous-time dynamics given by the differential equation

\[\dot{\lambda}_t = -{F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}\left(q_{\lambda_t}\right) .\]

This is also known as the natural gradient flow. Notice that the flow is over the parameters $\lambda_t$; therefore, the natural gradient flow depends on the way we parametrize $q_{\lambda}$. For Gaussian variational families, if we specifically choose the square-root (or Cholesky) parametrization such that $q_{\lambda_t} = \mathrm{Normal}(m_t, C_t C_t^{\top})$, the flow of $\lambda_t = (m_t, C_t)$ is given as

\[\begin{align*} \dot{m}_t &= C_t C_t^{\top} \mathbb{E}_{q_{\lambda_t}} \left[ \nabla \log \pi \right] \\ \dot{C}_t &= C_t M\left( \mathrm{I}_d + C_t^{\top} \mathbb{E}_{q_{\lambda_t}}\left[ \nabla^2 \log \pi \right] C_t \right) , \end{align*} \]

where $M$ is a $\mathrm{tril}$-like function that keeps the strictly lower triangular part of its argument and halves the diagonal:

\[{[ M(A) ]}_{ij} = \begin{cases} 0 & \text{if $i < j$} \\ \frac{1}{2} A_{ii} & \text{if $i = j$} \\ A_{ij} & \text{if $i > j$} . \end{cases}\]

KLMinSqrtNaturalGradDescent corresponds to the forward Euler discretization of this flow.
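As an illustration of the resulting update rule, the following is a minimal sketch of a single forward Euler step with step size h, where the exact expectations are replaced by Monte Carlo averages over samples from $q_{\lambda_t}$. The helper names tril_half and sqrt_ngd_step and the use of ForwardDiff to obtain gradients and Hessians are illustrative choices, not the package's internal implementation.

using LinearAlgebra, ForwardDiff

# M(A): keep the strictly lower-triangular part of A and half of its diagonal.
tril_half(A) = tril(A) - Diagonal(A) / 2

# One forward Euler step of the square-root natural gradient flow.
# `logπ` is the (unnormalized) target log density, `m` the mean, and
# `C` the lower-triangular scale factor of q = Normal(m, C*C').
function sqrt_ngd_step(logπ, m, C, h; n_samples=4)
    d     = length(m)
    θs    = [m + C * randn(d) for _ in 1:n_samples]        # samples from q
    g_avg = sum(ForwardDiff.gradient(logπ, θ) for θ in θs) / n_samples
    H_avg = sum(ForwardDiff.hessian(logπ, θ) for θ in θs) / n_samples
    m_new = m + h * (C * (C' * g_avg))
    C_new = C + h * (C * tril_half(I + C' * H_avg * C))
    (m_new, LowerTriangular(C_new))
end

For instance, with logπ(θ) = -sum(abs2, θ)/2, repeatedly applying sqrt_ngd_step drives m towards zero and C*C' towards the identity, as expected for a standard Gaussian target.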

  • [KMKL2025] Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization Guarantees for Square-Root Natural-Gradient Variational Inference. Transactions on Machine Learning Research.
  • [LDEBTM2024] Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective. In International Conference on Machine Learning.
  • [LDLNKS2023] Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023). Simplifying Momentum-Based Positive-Definite Submanifold Optimization with Applications to Deep Learning. In International Conference on Machine Learning.
  • [T2025] Tan, L. S. (2025). Analytic Natural Gradient Updates for Cholesky Factor in Gaussian Variational Approximation. Journal of the Royal Statistical Society: Series B.
  • [JGJS1999] Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An Introduction to Variational Methods for Graphical Models. Machine Learning, 37, 183-233.