KLMinWassFwdBwd

Description

This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running proximal gradient descent (also known as forward-backward splitting) in Wasserstein space[DBCS2023]. (This algorithm is also sometimes referred to as "Wasserstein VI".) Since KLMinWassFwdBwd is a measure-space algorithm, its use is restricted to the full-rank Gaussian variational family (FullRankGaussian), for which the measure-valued operations are tractable.

AdvancedVI.KLMinWassFwdBwd (Type)
KLMinWassFwdBwd(n_samples, stepsize, subsampling)
KLMinWassFwdBwd(; n_samples, stepsize, subsampling)

KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[DBCS2023].

This algorithm requires second-order information about the target. If the target LogDensityProblem has second-order differentiation capability, Hessians are used. Otherwise, if the target has only first-order capability, the algorithm uses gradients alone, which will probably result in slower convergence and less robust behavior.
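When only gradients are available, one standard way to recover curvature information for a Gaussian $q = \mathcal{N}(m, \Sigma)$ is Stein's lemma, $\mathbb{E}_q[\nabla V(\theta)(\theta - m)^{\top}] = \mathbb{E}_q[\nabla^2 V(\theta)]\,\Sigma$ with potential $V = -\log \pi$. The sketch below estimates the expected Hessian this way. `stein_hessian_estimate` is a hypothetical helper, not part of AdvancedVI, and whether the package uses exactly this estimator in its first-order mode is an assumption.

```julia
using LinearAlgebra, Random

# Hypothetical helper: estimate E_q[∇²V] for q = N(m, Σ) from gradients only,
# using Stein's lemma E_q[∇V(θ)(θ - m)ᵀ] = E_q[∇²V(θ)] Σ.
function stein_hessian_estimate(rng, gradV, m::AbstractVector, Σ::AbstractMatrix, n::Int)
    L = cholesky(Symmetric(Σ)).L
    d = length(m)
    acc = zeros(d, d)
    for _ in 1:n
        θ = m + L * randn(rng, d)        # θ ~ N(m, Σ)
        acc .+= gradV(θ) * (θ - m)'      # outer product ∇V(θ)(θ - m)ᵀ
    end
    H = (acc / n) / Σ                    # right-multiply the average by Σ⁻¹
    return (H + H') / 2                  # symmetrize the Monte Carlo estimate
end

# Usage: Gaussian target with potential V(θ) = ½(θ - μ)ᵀ P (θ - μ), so ∇²V = P.
P = [2.0 0.5; 0.5 1.0]
μ = [0.5, 0.5]
gradV(θ) = P * (θ - μ)
m, Σ = [1.0, -1.0], [1.0 0.3; 0.3 0.5]   # current approximation q = N(m, Σ)
H_est = stein_hessian_estimate(MersenneTwister(1), gradV, m, Σ, 200_000)
```

The estimate is only accurate up to Monte Carlo error, which is one reason the gradient-only mode tends to be noisier than using Hessians directly.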

(Keyword) Arguments

  • n_samples::Int: Number of samples used to estimate the Wasserstein gradient. (default: 1)
  • stepsize::Float64: Step size of stochastic proximal gradient descent.
  • subsampling::Union{Nothing,<:AbstractSubsampling}: Optional subsampling strategy.
Note

The subsampling strategy is applied only to the target LogDensityProblem, not to the variational approximation q. That is, KLMinWassFwdBwd does not support amortization or structured variational families.

Output

  • q: The last iterate of the algorithm.

Callback Signature

The callback function supplied to optimize needs to have the following signature:

callback(; rng, iteration, q, info)

The keyword arguments are as follows:

  • rng: Random number generator internally used by the algorithm.
  • iteration: The index of the current iteration.
  • q: Current variational approximation.
  • info: NamedTuple containing the information generated during the current iteration.
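A minimal sketch of a callback with this keyword signature. `make_logging_callback`, `every`, and the `elapsed` field of `info` are hypothetical names used only for illustration, and returning `nothing` from the callback is an assumption about what `optimize` accepts; the loop at the end merely simulates how an optimizer might invoke the callback, using mock arguments.

```julia
# Hypothetical factory returning a callback with the documented keyword
# signature, together with the vector it records iteration indices into.
function make_logging_callback(every::Int)
    seen = Int[]
    function callback(; rng, iteration, q, info)
        if iteration % every == 0
            push!(seen, iteration)
            println("iteration $iteration, info fields: ", keys(info))
        end
        return nothing   # assumption: `optimize` accepts `nothing` here
    end
    return callback, seen
end

# Simulated calls with mock arguments (no AdvancedVI needed).
callback, seen = make_logging_callback(100)
for t in 1:300
    callback(; rng = nothing, iteration = t, q = nothing, info = (elapsed = 0.0,))
end
```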

Requirements

  • The variational family is FullRankGaussian.
  • The target distribution has unconstrained support.
  • The target LogDensityProblems.logdensity(prob, x) has at least first-order differentiation capability.

The associated objective value can be estimated through the following:

AdvancedVI.estimate_objective (Method)
estimate_objective([rng,] alg, q, prob; kwargs...)

Estimate the variational objective to be minimized by the algorithm alg for approximating the target prob with the variational approximation q.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • alg::AbstractVariationalAlgorithm: Variational inference algorithm.
  • q: Variational approximation.
  • prob: The target log-joint likelihood implementing the LogDensityProblem interface.

Keyword Arguments

Depending on the algorithm, additional keyword arguments may apply. Please refer to the respective documentation of each algorithm for more info.

Returns

  • obj_est: Estimate of the objective value.
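For intuition, the objective minimized by this algorithm is the free energy $\mathcal{F}(q) = \mathbb{E}_q[-\log \pi(\theta)] + \mathbb{E}_q[\log q(\theta)]$ described under Methodology. For a Gaussian $q = \mathcal{N}(m, \Sigma)$ the second term has the closed form $-\tfrac{1}{2}\log\det(2\pi e \Sigma)$, so only the first term needs sampling. The following is a hypothetical standalone helper, not the implementation behind `estimate_objective`, whose estimator may differ.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not AdvancedVI's `estimate_objective`): Monte Carlo
# estimate of F(q) = E_q[-log π(θ)] + E_q[log q(θ)] for q = N(m, Σ).
# `logtarget` is the unnormalized log density of the target π.
function free_energy_estimate(rng, logtarget, m, Σ, n)
    d = length(m)
    L = cholesky(Symmetric(Σ)).L
    # Potential energy U(q): sampled via θ = m + Lz with z ~ N(0, I).
    U = sum(-logtarget(m + L * randn(rng, d)) for _ in 1:n) / n
    # E_q[log q(θ)] = -½ logdet(2πe Σ), exact for Gaussians.
    negentropy = -0.5 * logdet(2π * ℯ * Σ)
    return U + negentropy
end

# Usage: unnormalized Gaussian target π(θ) = exp(-½(θ - μ)ᵀ P (θ - μ)),
# whose log normalizing constant is log Z = ½ logdet(2π P⁻¹).
P = [2.0 0.5; 0.5 1.0]
μ = [0.5, 0.5]
logtarget(θ) = -0.5 * dot(θ - μ, P * (θ - μ))
logZ = 0.5 * logdet(2π * inv(P))
rng = MersenneTwister(1)
F_at_target = free_energy_estimate(rng, logtarget, μ, inv(P), 100_000)
F_elsewhere = free_energy_estimate(rng, logtarget, [2.0, -1.0], Matrix(1.0I, 2, 2), 100_000)
```

Since $-\mathcal{F}$ is a lower bound on $\log Z$ that is tight exactly when $q$ equals the normalized target, `F_at_target` should be close to `-logZ`, while any other $q$ gives a larger value.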

Methodology

This algorithm aims to solve the problem

\[ \mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{KL}\left(q, \pi\right)\]

where $\mathcal{Q}$ is some family of distributions, often called the variational family. Since we usually have access only to the unnormalized density of the target distribution $\pi$, the KL divergence cannot be evaluated directly. Instead, we minimize a surrogate objective, the free energy functional, which is the negated evidence lower bound (ELBO)[JGJS1999] and is defined as

\[ \mathcal{F}\left(q\right) \triangleq \mathcal{U}\left(q\right) + \mathcal{H}\left(q\right),\]

where

\[\begin{aligned} \mathcal{U}\left(q\right) &= \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) &&\text{(``potential energy'')} \\ \mathcal{H}\left(q\right) &= \mathbb{E}_{\theta \sim q} \log q\left(\theta\right) . &&\text{(``Boltzmann entropy'')} \end{aligned}\]
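The free energy is a valid surrogate because it differs from the KL divergence only by a constant. Writing $Z = \int \pi(\theta)\,\mathrm{d}\theta$ for the unknown normalizing constant of the unnormalized target density $\pi$ and $\bar{\pi} = \pi / Z$ for the normalized target,

\[ \mathrm{KL}\left(q, \bar{\pi}\right) = \mathbb{E}_{\theta \sim q} \log q\left(\theta\right) - \mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right) + \log Z = \mathcal{H}\left(q\right) + \mathcal{U}\left(q\right) + \log Z = \mathcal{F}\left(q\right) + \log Z . \]

Since $\log Z$ does not depend on $q$, minimizing $\mathcal{F}$ over $\mathcal{Q}$ is equivalent to minimizing the KL divergence, and $\mathrm{KL} \geq 0$ gives $-\mathcal{F}(q) \leq \log Z$, which is why $-\mathcal{F}$ is called the evidence lower bound.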

For solving this problem, KLMinWassFwdBwd relies on proximal stochastic gradient descent (PSGD), also known as "forward-backward splitting", which iterates

\[ q_{t+1} = \mathrm{JKO}_{\gamma_t \mathcal{H}}\big( q_{t} - \gamma_t \widehat{\nabla_{\mathrm{BW}} \mathcal{U}} (q_{t}) \big) , \]

where $\widehat{\nabla_{\mathrm{BW}} \mathcal{U}}$ is a stochastic estimate of the Bures-Wasserstein gradient of $\mathcal{U}$, and the JKO (proximal) operator is defined as

\[\mathrm{JKO}_{\gamma_t \mathcal{H}}(\mu) = \argmin_{\nu \in \mathcal{Q}} \left\{ \mathcal{H}(\nu) + \frac{1}{2 \gamma_t} \mathrm{W}_2 {(\mu, \nu)}^2 \right\} ,\]

and $\mathrm{W}_2$ is the Wasserstein-2 distance. When $\mathcal{Q}$ is the full Wasserstein space $\mathcal{P}_2(\mathbb{R}^d)$ of probability measures on $\mathbb{R}^d$ with finite second moment, this proximal operator is a step of the Jordan-Kinderlehrer-Otto (JKO) scheme[JKO1998], which was originally developed to study gradient flows under the Wasserstein metric. Within this context, KLMinWassFwdBwd can be viewed as a numerical realization of the JKO scheme obtained by restricting $\mathcal{Q}$ to a tractable parametric variational family, namely the Bures-Wasserstein space of Gaussians on $\mathbb{R}^d$. Specifically, Diao et al.[DBCS2023] derived the JKO update for multivariate Gaussians, which is what KLMinWassFwdBwd implements. KLMinWassFwdBwd is also the exact measure-space analog of KLMinRepGradProxDescent.
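For Gaussian iterates $q_t = \mathcal{N}(m_t, \Sigma_t)$ both steps have closed forms. The Bures-Wasserstein gradient of $\mathcal{U}$ is the affine map $\theta \mapsto g + H(\theta - m_t)$ with $g = \mathbb{E}_{q_t}[\nabla V]$ and $H = \mathbb{E}_{q_t}[\nabla^2 V]$ for $V = -\log \pi$, so the forward step moves the mean to $m_t - \gamma_t g$ and the covariance to $\Sigma' = M \Sigma_t M^{\top}$ with $M = I - \gamma_t H$. The backward step keeps the mean and sets $\Sigma_{t+1} = \tfrac{1}{2}\big(\Sigma' + 2\gamma_t I + (\Sigma'(\Sigma' + 4\gamma_t I))^{1/2}\big)$. The sketch below assumes estimates of $g$ and $H$ are already available; `jko_entropy_cov`, `fwdbwd_step`, and `run_fwdbwd` are hypothetical names, and this is not AdvancedVI's implementation.

```julia
using LinearAlgebra

# Backward step: closed-form JKO of the entropy for a Gaussian. The mean is
# unchanged and the covariance Σ is mapped to ½(Σ + 2γI + (Σ(Σ + 4γI))^{1/2}).
function jko_entropy_cov(Σ::AbstractMatrix, γ::Real)
    S = Symmetric(Σ * Σ + 4γ * Σ)   # Σ(Σ + 4γI); symmetric because the factors commute
    return Matrix(Symmetric(0.5 * (Σ + 2γ * I + Matrix(sqrt(S)))))
end

# One forward-backward step for q = N(m, Σ), given estimates g ≈ E_q[∇V] and
# H ≈ E_q[∇²V] of the potential V = -log π.
function fwdbwd_step(m, Σ, g, H, γ)
    M = I - γ * H                    # linear part of θ ↦ θ - γ(g + H(θ - m))
    m_half = m - γ * g               # forward step on the mean
    Σ_half = M * Σ * M'              # forward step on the covariance
    return m_half, jko_entropy_cov(Σ_half, γ)   # backward step (mean unchanged)
end

# Usage: Gaussian target π(θ) ∝ exp(-½(θ - μ)ᵀ P (θ - μ)) with the exact
# expectations E_q[∇V] = P(m - μ) and ∇²V = P in place of stochastic estimates.
function run_fwdbwd(m, Σ, P, μ, γ, iters)
    for _ in 1:iters
        m, Σ = fwdbwd_step(m, Σ, P * (m - μ), P, γ)
    end
    return m, Σ
end

P = [2.0 0.5; 0.5 1.0]
μ = [0.5, 0.5]
m, Σ = run_fwdbwd([1.0, -2.0], Matrix(1.0I, 2, 2), P, μ, 0.1, 500)
```

Because a Gaussian target lies in the Bures-Wasserstein space, these deterministic iterates converge to its mean $\mu$ and covariance $P^{-1}$. With stochastic estimates of $g$ and $H$, the iterates would typically fluctuate around these values unless the step size is decreased.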

  • [DBCS2023] Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In International Conference on Machine Learning. PMLR.
  • [JKO1998] Jordan, R., Kinderlehrer, D., & Otto, F. (1998). The variational formulation of the Fokker–Planck equation. SIAM Journal on Mathematical Analysis, 29(1).
  • [JGJS1999] Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine Learning, 37, 183-233.