KLMinRepGradProxDescent
Description
This algorithm is a variant of KLMinRepGradDescent specialized to location-scale families. Like KLMinRepGradDescent, it minimizes the exclusive (or reverse) Kullback-Leibler (KL) divergence over the space of variational parameters, but it does so via stochastic proximal gradient descent, using the proximal operator of the entropy of location-scale variational families as discussed in [D2020][KMG2024][DGG2023]. The remainder of this section only discusses details specific to KLMinRepGradProxDescent; for general usage and additional details, please refer to the docs of KLMinRepGradDescent.
AdvancedVI.KLMinRepGradProxDescent — Type

KLMinRepGradProxDescent(adtype; entropy_zerograd, optimizer, n_samples, averager, subsampling)
KL divergence minimization by running stochastic proximal gradient descent with the reparameterization gradient in the Euclidean space of variational parameters of a location-scale family.
This algorithm only supports subtypes of MvLocationScale. Also, since stochastic proximal gradient descent handles the entropy through the proximal operator rather than through its gradient, the entropy estimator to be used must have a zero-mean gradient. Thus, only the entropy estimators with a "ZeroGrad" suffix are allowed.
Arguments

- adtype: Automatic differentiation backend.
Keyword Arguments
- entropy_zerograd: Estimator of the entropy with a zero-mean gradient to be used. Must be one of ClosedFormEntropyZeroGrad, StickingTheLandingEntropyZeroGrad. (default: ClosedFormEntropyZeroGrad())
- optimizer::Optimisers.AbstractRule: Optimization algorithm to be used. Only DoG, DoWG, and Optimisers.Descent are supported. (default: DoWG())
- n_samples::Int: Number of Monte Carlo samples to be used for estimating each gradient.
- averager::AbstractAverager: Parameter averaging strategy. (default: PolynomialAveraging())
- subsampling::Union{<:Nothing,<:AbstractSubsampling}: Data point subsampling strategy. If nothing, subsampling is not used. (default: nothing)
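For instance, a concrete configuration spelling out the documented keywords might look like the following sketch. The ForwardDiff backend and the sample count are illustrative choices rather than defaults, and the sketch assumes the listed names are exported by AdvancedVI:

```julia
using ADTypes, ForwardDiff
using AdvancedVI

# Configure the algorithm; AutoForwardDiff() is an illustrative AD backend choice.
alg = KLMinRepGradProxDescent(
    AutoForwardDiff();
    entropy_zerograd = ClosedFormEntropyZeroGrad(),
    optimizer        = DoWG(),
    n_samples        = 8,
    averager         = PolynomialAveraging(),
)
```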
Output
- q_averaged: The variational approximation formed by the averaged SGD iterates.
Callback
The callback function callback has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)
The arguments are as follows:
- rng: Random number generator internally used by the algorithm.
- iteration: The index of the current iteration.
- restructure: Function that restructures the variational approximation from the variational parameters. Calling restructure(params) reconstructs the current variational approximation.
- params: Current variational parameters.
- averaged_params: Variational parameters averaged according to the averaging strategy.
- gradient: The estimated (possibly stochastic) gradient.
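As an illustration, a callback matching the signature above could look like the sketch below. How the callback is supplied to the optimization routine (typically via a callback keyword) is covered in the KLMinRepGradDescent docs; the diagnostics logged here are arbitrary examples:

```julia
# A sketch of a callback matching the documented keyword signature.
# Unused keywords are absorbed by `kwargs...`.
function my_callback(; iteration, params, averaged_params, restructure, kwargs...)
    # Reconstruct the current averaged approximation from the averaged parameters.
    # `q_avg` could be used to compute custom diagnostics here.
    q_avg = restructure(averaged_params)
    if iteration % 100 == 0
        @info "progress" iteration norm_params = sqrt(sum(abs2, params))
    end
    return nothing
end
```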
Requirements
- The variational family is MvLocationScale.
- The target distribution and the variational approximation have the same support.
- The target LogDensityProblems.logdensity(prob, x) must be differentiable with respect to x by the selected AD backend (a minimal example is sketched below this list).
- Additional requirements on q may apply depending on the choice of entropy_zerograd.
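To illustrate the differentiability requirement, here is a hypothetical standard-normal target (the name StdNormalTarget is ours, not part of the package) implementing the LogDensityProblems interface; its log-density is smooth, so any of the supported AD backends can differentiate it:

```julia
using LogDensityProblems

# A hypothetical target: an isotropic standard normal in `dims` dimensions.
struct StdNormalTarget
    dims::Int
end

# Log-density evaluated at `x`; this is the function the AD backend differentiates.
function LogDensityProblems.logdensity(prob::StdNormalTarget, x)
    return -sum(abs2, x) / 2 - prob.dims * log(2π) / 2
end

LogDensityProblems.dimension(prob::StdNormalTarget) = prob.dims

# Only zeroth-order evaluations are provided; gradients come from the AD backend.
function LogDensityProblems.capabilities(::Type{StdNormalTarget})
    return LogDensityProblems.LogDensityOrder{0}()
end
```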
The associated objective value can be estimated through the following:
AdvancedVI.estimate_objective — Method

estimate_objective([rng,] alg, q, prob; n_samples, entropy)

Estimate the ELBO of the variational approximation q against the target log-density prob.
Arguments
- rng::Random.AbstractRNG: Random number generator.
- alg::Union{<:KLMinRepGradDescent,<:KLMinRepGradProxDescent,<:KLMinScoreGradDescent}: Variational inference algorithm.
- q: Variational approximation.
- prob: The target log-joint likelihood implementing the LogDensityProblems interface.
Keyword Arguments
- n_samples::Int: Number of Monte Carlo samples for estimating the objective. (default: same as the number of samples used for estimating the gradient during optimization.)
- entropy::AbstractEntropyEstimator: Entropy estimator. (default: MonteCarloEntropy())
Returns
- obj_est: Estimate of the objective value.
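For example, assuming alg, q, and prob were constructed as in the sketches above (with q a member of the supported MvLocationScale family), the ELBO can be estimated as follows; the sample size is an illustrative choice:

```julia
using Random
using AdvancedVI

rng = Random.default_rng()

# Use a larger Monte Carlo budget than during optimization for a
# lower-variance estimate of the ELBO.
elbo_estimate = AdvancedVI.estimate_objective(rng, alg, q, prob; n_samples = 1024)
```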
Methodology
Recall that KLMinRepGradDescent maximizes the ELBO. Now, the ELBO can be re-written as follows:
\[ \mathrm{ELBO}\left(q\right) \triangleq \mathcal{E}\left(q\right) + \mathbb{H}\left(q\right),\]
where
\[ \mathcal{E}\left(q\right) = \mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right)\]
is often referred to as the negative energy functional. KLMinRepGradProxDescent addresses the fact that optimizing the whole ELBO with plain (stochastic) gradient steps can be unstable due to the non-smoothness of $\mathbb{H}\left(q\right)$ [D2020]. To this end, KLMinRepGradProxDescent relies on proximal stochastic gradient descent, where the problematic term $\mathbb{H}\left(q\right)$ is handled separately via a proximal operator. Specifically, KLMinRepGradProxDescent estimates only the gradient of the energy $\mathcal{E}\left(q\right)$ via the reparameterization gradient estimator; let us denote this estimate by $\widehat{\nabla_{\lambda} \mathcal{E}}\left(q_{\lambda}\right)$. KLMinRepGradProxDescent then iterates the step
\[ \lambda_{t+1} = \mathrm{prox}_{-\gamma_t \mathbb{H}}\big( \lambda_{t} + \gamma_t \widehat{\nabla_{\lambda} \mathcal{E}}(q_{\lambda_t}) \big) , \]
where
\[\mathrm{prox}_{h}(\lambda_t) = \operatorname*{argmin}_{\lambda \in \Lambda}\left\{ h(\lambda) + \frac{1}{2} {\lVert \lambda - \lambda_t \rVert}_2^2 \right\}\]
is the proximal operator of $h$, here applied with $h = -\gamma_t \mathbb{H}$, the scaled negative entropy. As long as $\mathrm{prox}_{-\gamma_t \mathbb{H}}$ can be evaluated efficiently, this scheme side-steps the fact that $\mathbb{H}(q_{\lambda})$ is difficult to handle via gradient descent. For location-scale families, the proximal operator of the entropy can indeed be evaluated efficiently [D2020]; it is implemented as ProximalLocationScaleEntropy. This scheme has been empirically shown to be more robust [D2020][KMG2024].
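As a concrete illustration, consider a mean-field Gaussian with location vector $m$ and coordinate-wise scale $\sigma$, for which $\mathbb{H}(q) = \sum_i \log \sigma_i + \mathrm{const}$. Then $\mathrm{prox}_{-\gamma_t \mathbb{H}}$ leaves the location untouched and acts coordinate-wise on the scale with the closed form $\sigma_i \mapsto (\sigma_i + \sqrt{\sigma_i^2 + 4\gamma_t})/2$. The sketch below shows one resulting proximal step under this simplified parameterization; it is a standalone illustration, not the ProximalLocationScaleEntropy implementation in AdvancedVI:

```julia
# Standalone illustration of one proximal SGD step for a mean-field Gaussian
# variational family with a location vector and a positive scale vector.
# This is a simplified sketch, not AdvancedVI's ProximalLocationScaleEntropy.

# Proximal operator of h(λ) = -γ * H(q_λ), where H(q) = Σᵢ log σᵢ + const.
# It acts only on the scale: each coordinate minimizes
#   -γ*log(σ) + (σ - σₜ)²/2,
# whose positive root is σ = (σₜ + sqrt(σₜ² + 4γ)) / 2.
function prox_entropy(location, scale, γ)
    new_scale = @. (scale + sqrt(scale^2 + 4γ)) / 2
    return location, new_scale
end

# One iteration: ascend the (stochastic) energy gradient estimate,
# then apply the entropy proximal operator.
function proximal_sgd_step(location, scale, grad_location, grad_scale, γ)
    location_half = @. location + γ * grad_location
    scale_half    = @. scale + γ * grad_scale
    return prox_entropy(location_half, scale_half, γ)
end
```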
- [D2020] Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In International Conference on Machine Learning.
- [KMG2024] Kim, K., Ma, Y., & Gardner, J. (2024). Linear convergence of black-box variational inference: Should we stick the landing? In International Conference on Artificial Intelligence and Statistics (pp. 235-243). PMLR.
- [DGG2023] Domke, J., Gower, R., & Garrigos, G. (2023). Provable convergence guarantees for black-box variational inference. Advances in Neural Information Processing Systems, 36, 66289-66327.