KLMinRepGradProxDescent

Description

This algorithm is a variant of KLMinRepGradDescent specialized to location-scale families. Like its parent, it aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence over the space of variational parameters. Instead of plain stochastic gradient descent, however, it uses stochastic proximal gradient descent with the proximal operator of the entropy of location-scale variational families, as discussed in [D2020][KMG2024][DGG2023]. The remainder of this section only discusses details specific to KLMinRepGradProxDescent; for general usage and additional details, please refer to the docs of KLMinRepGradDescent.

AdvancedVI.KLMinRepGradProxDescentType
KLMinRepGradProxDescent(adtype; entropy_zerograd, optimizer, n_samples, averager)

KL divergence minimization by running stochastic proximal gradient descent with the reparameterization gradient in the Euclidean space of variational parameters of a location-scale family.

This algorithm only supports subtypes of MvLocationScale. Also, since stochastic proximal gradient descent does not use the gradient of the entropy, the entropy estimator must have a zero-mean gradient. Thus, only entropy estimators with a "ZeroGradient" suffix are allowed.

Arguments

  • adtype: Automatic differentiation backend.

Keyword Arguments

  • entropy_zerograd: Estimator of the entropy with a zero-mean gradient to be used. Must be one of ClosedFormEntropyZeroGrad, StickingTheLandingEntropyZeroGrad. (default: ClosedFormEntropyZeroGrad())
  • optimizer::Optimisers.AbstractRule: Optimization algorithm to be used. Only DoG, DoWG and Optimisers.Descent are supported. (default: DoWG())
  • n_samples::Int: Number of Monte Carlo samples to be used for estimating each gradient.
  • averager::AbstractAverager: Parameter averaging strategy. (default: PolynomialAveraging())
  • subsampling::Union{<:Nothing,<:AbstractSubsampling}: Data point subsampling strategy. If nothing, subsampling is not used. (default: nothing)

Output

  • q_averaged: The variational approximation formed by the averaged SGD iterates.

Callback

The callback function callback has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:

  • rng: Random number generator internally used by the algorithm.
  • iteration: The index of the current iteration.
  • restructure: Function that restructures the variational approximation from the variational parameters. Calling restructure(params) reconstructs the current variational approximation.
  • params: Current variational parameters.
  • averaged_params: Variational parameters averaged according to the averaging strategy.
  • gradient: The estimated (possibly stochastic) gradient.

Requirements

  • The variational family is MvLocationScale.
  • The target distribution and the variational approximation have the same support.
  • The target LogDensityProblems.logdensity(prob, x) must be differentiable with respect to x by the selected AD backend.
  • Additional requirements on q may apply depending on the choice of entropy_zerograd.
source
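
To make the above concrete, a minimal end-to-end sketch might look like the following. The target (a standard 2-dimensional Gaussian implementing the LogDensityProblems interface), the choice of MeanFieldGaussian for q, the iteration count, and the exact return values of optimize are illustrative assumptions, not prescriptions:

```julia
using AdvancedVI, ADTypes, ForwardDiff
using LinearAlgebra, LogDensityProblems

# Hypothetical target for illustration: an unnormalized standard 2-d Gaussian.
struct StdNormalTarget end
LogDensityProblems.logdensity(::StdNormalTarget, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::StdNormalTarget) = 2
LogDensityProblems.capabilities(::Type{StdNormalTarget}) =
    LogDensityProblems.LogDensityOrder{0}()

# A location-scale variational family, as required by KLMinRepGradProxDescent.
q0 = MeanFieldGaussian(zeros(2), Diagonal(ones(2)))

# Proximal SGD with the default zero-gradient entropy estimator and averager.
alg = KLMinRepGradProxDescent(AutoForwardDiff(); n_samples=8)

# Run the optimization; the returned approximation is formed from the
# averaged iterates (assumed return order shown here).
q_averaged, info, state = AdvancedVI.optimize(alg, 3_000, StdNormalTarget(), q0)
```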

The associated objective value can be estimated through the following:

AdvancedVI.estimate_objectiveMethod
estimate_objective([rng,] alg, q, prob; n_samples, entropy)

Estimate the ELBO of the variational approximation q against the target log-density prob.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • alg::Union{<:KLMinRepGradDescent,<:KLMinRepGradProxDescent,<:KLMinScoreGradDescent}: Variational inference algorithm.
  • q: Variational approximation.
  • prob: The target log-joint likelihood implementing the LogDensityProblem interface.

Keyword Arguments

  • n_samples::Int: Number of Monte Carlo samples for estimating the objective. (default: same as the number of samples used for estimating the gradient during optimization.)
  • entropy::AbstractEntropyEstimator: Entropy estimator. (default: MonteCarloEntropy())

Returns

  • obj_est: Estimate of the objective value.
source
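
As a short self-contained sketch of estimate_objective, the snippet below evaluates the ELBO of a mean-field Gaussian against a hypothetical unnormalized standard Gaussian target (the target struct and sample count are assumptions for illustration). A larger n_samples lowers the variance of the estimate:

```julia
using AdvancedVI, ADTypes, ForwardDiff
using LinearAlgebra, LogDensityProblems

# Hypothetical target for illustration: an unnormalized standard 2-d Gaussian.
struct StdNormal2 end
LogDensityProblems.logdensity(::StdNormal2, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::StdNormal2) = 2
LogDensityProblems.capabilities(::Type{StdNormal2}) =
    LogDensityProblems.LogDensityOrder{0}()

alg = KLMinRepGradProxDescent(AutoForwardDiff())
q   = MeanFieldGaussian(zeros(2), Diagonal(ones(2)))

# Monte Carlo estimate of the ELBO of `q` against the target.
elbo_est = estimate_objective(alg, q, StdNormal2(); n_samples=10_000)
```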

Methodology

Recall that KLMinRepGradDescent maximizes the ELBO. Now, the ELBO can be re-written as follows:

\[ \mathrm{ELBO}\left(q\right) \triangleq \mathcal{E}\left(q\right) + \mathbb{H}\left(q\right),\]

where

\[ \mathcal{E}\left(q\right) = \mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right)\]

is often referred to as the negative energy functional. KLMinRepGradProxDescent addresses the fact that directly optimizing the whole ELBO with stochastic gradient methods can be unstable due to the non-smoothness of $\mathbb{H}\left(q\right)$[D2020]. For this, KLMinRepGradProxDescent relies on proximal stochastic gradient descent, where the problematic term $\mathbb{H}\left(q\right)$ is handled separately via a proximal operator. Specifically, KLMinRepGradProxDescent first estimates the gradient of the energy $\mathcal{E}\left(q\right)$ alone via the reparameterization gradient estimator. Let us denote this estimate as $\widehat{\nabla_{\lambda} \mathcal{E}}\left(q_{\lambda}\right)$. KLMinRepGradProxDescent then iterates the step

\[ \lambda_{t+1} = \mathrm{prox}_{-\gamma_t \mathbb{H}}\big( \lambda_{t} + \gamma_t \widehat{\nabla_{\lambda} \mathcal{E}}(q_{\lambda_t}) \big) , \]

where

\[\mathrm{prox}_{h}(\lambda_t) = \argmin_{\lambda \in \Lambda}\left\{ h(\lambda) + \frac{1}{2} {\lVert \lambda - \lambda_t \rVert}_2^2 \right\}\]

is the proximal operator of $h$. As long as $\mathrm{prox}_{-\gamma_t \mathbb{H}}$ can be evaluated efficiently, this scheme side-steps the fact that $\mathbb{H}(q_{\lambda})$ is difficult to handle via gradient descent. For location-scale families, it turns out that the proximal operator of the entropy can be computed efficiently[D2020]; this is implemented as ProximalLocationScaleEntropy. The resulting scheme has been empirically shown to be more robust[D2020][KMG2024].
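To make the proximal step concrete, consider a mean-field location-scale family with diagonal scale $\sigma$, whose entropy is $\sum_i \log \sigma_i$ plus a constant. The location part is untouched by $\mathrm{prox}_{-\gamma \mathbb{H}}$, and for each scale entry the one-dimensional problem in the definition above has a closed-form positive root. The following self-contained sketch (not the library's implementation, which lives in ProximalLocationScaleEntropy) evaluates this update:

```julia
# Closed-form proximal operator of the scaled negative entropy acting on the
# diagonal scale of a location-scale family: for each scale entry σ, solve
#   min_s  -γ * log(s) + (s - σ)^2 / 2,
# whose stationarity condition s^2 - σ*s - γ = 0 has the positive root below.
function prox_entropy_scale(σ::AbstractVector, γ::Real)
    return @. (σ + sqrt(σ^2 + 4γ)) / 2
end

σ = [0.5, 1.0, 2.0]
γ = 0.1
σ_new = prox_entropy_scale(σ, γ)
# Each updated scale is strictly larger than the input (the entropy term
# pushes the scales away from zero), and γ = 0 recovers σ exactly.
```

Note that the update keeps the scales strictly positive for any step size $\gamma_t > 0$, which is one source of the robustness observed empirically.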

  • D2020Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In International Conference on Machine Learning.
  • KMG2024Kim, K., Ma, Y., & Gardner, J. (2024). Linear Convergence of Black-Box Variational Inference: Should We Stick the Landing?. In International Conference on Artificial Intelligence and Statistics (pp. 235-243). PMLR.
  • DGG2023Domke, J., Gower, R., & Garrigos, G. (2023). Provable convergence guarantees for black-box variational inference. Advances in neural information processing systems, 36, 66289-66327.