KLMinRepGradProxDescent
Description
This algorithm is a slight variation of KLMinRepGradDescent specialized to location-scale families. It therefore also minimizes the exclusive (or reverse) Kullback-Leibler (KL) divergence over the space of variational parameters, but instead uses stochastic proximal gradient descent with the proximal operator of the entropy of location-scale variational families, as discussed in [D2020][KMG2024][DGG2023]. The remainder of this section only discusses details specific to KLMinRepGradProxDescent; for general usage and additional details, please refer to the docs of KLMinRepGradDescent.
AdvancedVI.KLMinRepGradProxDescent — Type

KLMinRepGradProxDescent(adtype; entropy_zerograd, optimizer, n_samples, averager)

KL divergence minimization by running stochastic proximal gradient descent with the reparameterization gradient in the Euclidean space of variational parameters of a location-scale family.
This algorithm only supports subtypes of MvLocationScale. Also, since stochastic proximal gradient descent does not use the gradient of the entropy, the entropy estimator to be used must have a zero-mean gradient. Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed.
Arguments
adtype: Automatic differentiation backend.
Keyword Arguments
- entropy_zerograd: Estimator of the entropy with a zero-mean gradient to be used. Must be one of ClosedFormEntropyZeroGrad, StickingTheLandingEntropyZeroGrad. (default: ClosedFormEntropyZeroGrad())
- optimizer::Optimisers.AbstractRule: Optimization algorithm to be used. Only DoG, DoWG and Optimisers.Descent are supported. (default: DoWG())
- n_samples::Int: Number of Monte Carlo samples to be used for estimating each gradient.
- averager::AbstractAverager: Parameter averaging strategy. (default: PolynomialAveraging())
- subsampling::Union{<:Nothing,<:AbstractSubsampling}: Data point subsampling strategy. If nothing, subsampling is not used. (default: nothing)
Output
q_averaged: The variational approximation formed by the averaged SGD iterates.
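As a sketch of typical usage, the following sets up a toy standard-Gaussian target implementing the LogDensityProblems interface and runs the algorithm on it. The problem type `StdGauss`, the sample count, and the iteration budget are hypothetical choices for illustration, and the exact `optimize` signature and return values may differ across AdvancedVI versions.

```julia
using AdvancedVI, ADTypes, ForwardDiff
using LinearAlgebra, LogDensityProblems

# Toy target: a standard 2-dimensional Gaussian implementing the
# LogDensityProblems interface (hypothetical example problem).
struct StdGauss end
LogDensityProblems.logdensity(::StdGauss, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::StdGauss) = 2
LogDensityProblems.capabilities(::Type{StdGauss}) =
    LogDensityProblems.LogDensityOrder{0}()

alg = KLMinRepGradProxDescent(AutoForwardDiff(); n_samples=8)

# Initial mean-field Gaussian approximation (a location-scale family).
q0 = MeanFieldGaussian(zeros(2), Diagonal(ones(2)))

# Run for 1_000 iterations; the averaged iterate is the main output.
q_averaged, info, _ = optimize(alg, 1_000, StdGauss(), q0)
```

Since the target here is already Gaussian, `q_averaged` should concentrate its location near the origin and its scale near the identity.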
Callback
The callback function callback has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- rng: Random number generator internally used by the algorithm.
- iteration: The index of the current iteration.
- restructure: Function that restructures the variational approximation from the variational parameters. Calling restructure(params) reconstructs the current variational approximation.
- params: Current variational parameters.
- averaged_params: Variational parameters averaged according to the averaging strategy.
- gradient: The estimated (possibly stochastic) gradient.
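For instance, a callback that records the gradient norm at each iteration could be sketched as follows. Absorbing unused arguments with `kwargs...` keeps the sketch robust to the exact set of keyword arguments the algorithm passes; the name `grad_norms` is hypothetical.

```julia
using LinearAlgebra

grad_norms = Float64[]

# Record the norm of the stochastic gradient at every iteration.
# Returning `nothing` is always safe; any other handling of the return
# value depends on the AdvancedVI version in use.
function callback(; iteration, gradient, kwargs...)
    push!(grad_norms, norm(gradient))
    return nothing
end
```

The callback is then supplied to the optimization routine as a keyword argument, e.g. `optimize(...; callback)`.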
Requirements
- The variational family is MvLocationScale.
- The target distribution and the variational approximation have the same support.
- The target LogDensityProblems.logdensity(prob, x) must be differentiable with respect to x by the selected AD backend.
- Additional requirements on q may apply depending on the choice of entropy_zerograd.
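For reference, the two standard ways of constructing a location-scale variational approximation satisfying the first requirement are sketched below (assuming the `MeanFieldGaussian` and `FullRankGaussian` convenience constructors exported by AdvancedVI):

```julia
using AdvancedVI, LinearAlgebra

d = 5

# Mean-field Gaussian: diagonal scale matrix.
q_mf = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))

# Full-rank Gaussian: lower-triangular scale (Cholesky-like) factor.
q_fr = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
```

Both are subtypes of MvLocationScale and are therefore accepted by KLMinRepGradProxDescent.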
The associated objective value can be estimated through the following:
AdvancedVI.estimate_objective — Method

estimate_objective([rng,] alg, q, prob; n_samples, entropy)

Estimate the ELBO of the variational approximation q against the target log-density prob.
Arguments
- rng::Random.AbstractRNG: Random number generator.
- alg::Union{<:KLMinRepGradDescent,<:KLMinRepGradProxDescent,<:KLMinScoreGradDescent}: Variational inference algorithm.
- q: Variational approximation.
- prob: The target log-joint likelihood implementing the LogDensityProblem interface.
Keyword Arguments
- n_samples::Int: Number of Monte Carlo samples for estimating the objective. (default: Same as the number of samples used for estimating the gradient during optimization.)
- entropy::AbstractEntropyEstimator: Entropy estimator. (default: MonteCarloEntropy())
Returns
obj_est: Estimate of the objective value.
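A typical use is to assess an optimized approximation with more Monte Carlo samples than were used during training, which reduces the variance of the estimate. Here `alg`, `q`, and `prob` are hypothetical names standing for the algorithm, the (averaged) approximation, and the target from an earlier optimization run:

```julia
using AdvancedVI

# Estimate the ELBO of `q` against `prob` with 10_000 samples;
# a larger sample count gives a lower-variance estimate.
elbo_est = estimate_objective(alg, q, prob; n_samples=10_000)
```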
Methodology
Recall that KLMinRepGradDescent maximizes the ELBO. Now, the ELBO can be re-written as follows:
\[ \mathrm{ELBO}\left(q\right) \triangleq \mathcal{E}\left(q\right) + \mathbb{H}\left(q\right),\]
where
\[ \mathcal{E}\left(q\right) = \mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right)\]
is often referred to as the negative energy functional. KLMinRepGradProxDescent addresses the fact that minimizing the whole ELBO by gradient descent can be unstable due to the non-smoothness of $\mathbb{H}\left(q\right)$[D2020]. For this, KLMinRepGradProxDescent relies on stochastic proximal gradient descent, where the problematic term $\mathbb{H}\left(q\right)$ is handled separately via a proximal operator. Specifically, KLMinRepGradProxDescent estimates the gradient of the energy $\mathcal{E}\left(q\right)$ alone via the reparameterization gradient estimator. Let us denote this estimate as $\widehat{\nabla_{\lambda} \mathcal{E}}\left(q_{\lambda}\right)$. Then KLMinRepGradProxDescent iterates the step
\[ \lambda_{t+1} = \mathrm{prox}_{-\gamma_t \mathbb{H}}\big( \lambda_{t} + \gamma_t \widehat{\nabla_{\lambda} \mathcal{E}}(q_{\lambda_t}) \big) , \]
where
\[\mathrm{prox}_{h}(\lambda_t) = \argmin_{\lambda \in \Lambda}\left\{ h(\lambda) + \frac{1}{2} {\lVert \lambda - \lambda_t \rVert}_2^2 \right\}\]
is the proximal operator for the entropy. As long as $\mathrm{prox}_{-\gamma_t \mathbb{H}}$ can be evaluated efficiently, this scheme side-steps the fact that $\mathbb{H}(\lambda)$ is difficult to handle via gradient descent. For location-scale families, it turns out the proximal operator of the entropy can be evaluated efficiently[D2020]; this is implemented as ProximalLocationScaleEntropy. This scheme has been empirically shown to be more robust[D2020][KMG2024].
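To make the proximal step concrete, consider a single scale parameter $\sigma$ of a mean-field family, whose contribution to the entropy is $\log \sigma$ up to a constant. The proximal objective $-\gamma \log s + \frac{1}{2}(s - \sigma)^2$ then has a closed-form minimizer. The following is a simplified scalar sketch of this fact, not the ProximalLocationScaleEntropy implementation itself:

```julia
# Scalar sketch of the entropy proximal step for one scale parameter σ.
# Setting the derivative of  -γ*log(s) + (s - σ)^2 / 2  to zero gives the
# quadratic  s^2 - σ*s - γ = 0,  whose positive root is:
prox_entropy_scale(σ, γ) = (σ + sqrt(σ^2 + 4γ)) / 2

σ, γ = 0.5, 0.1
s = prox_entropy_scale(σ, γ)

# First-order optimality holds: -γ/s + (s - σ) ≈ 0. Note the step keeps
# the scale strictly positive for any σ and any step size γ > 0.
@assert abs(-γ / s + (s - σ)) < 1e-12
```

This illustrates why the proximal scheme is attractive: the entropy term is handled exactly in closed form, and the scale parameters can never be pushed to zero or negative values by a noisy gradient step.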
- D2020: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In International Conference on Machine Learning.
- KMG2024: Kim, K., Ma, Y., & Gardner, J. (2024). Linear convergence of black-box variational inference: Should we stick the landing? In International Conference on Artificial Intelligence and Statistics (pp. 235-243). PMLR.
- DGG2023: Domke, J., Gower, R., & Garrigos, G. (2023). Provable convergence guarantees for black-box variational inference. Advances in Neural Information Processing Systems, 36, 66289-66327.