FisherMinBatchMatch

Description

This algorithm, known as batch-and-match (BaM), aims to minimize the covariance-weighted 2nd-order Fisher divergence by running a proximal point-type method[CMPMGBS24]. On certain low-dimensional problems, BaM can converge very quickly without any tuning. Since FisherMinBatchMatch is a measure-space algorithm, its use is restricted to the full-rank Gaussian variational family (FullRankGaussian), for which the measure-valued operations are tractable.

AdvancedVI.FisherMinBatchMatch (Type)
FisherMinBatchMatch(n_samples, subsampling)
FisherMinBatchMatch(; n_samples, subsampling)

Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm, which is a proximal point-type optimization scheme.

(Keyword) Arguments

  • n_samples::Int: Number of samples (batchsize) used to compute the moments required for the batch-and-match update. (default: 32)
  • subsampling::Union{Nothing,<:AbstractSubsampling}: Optional subsampling strategy. (default: nothing)

Warning

FisherMinBatchMatch with subsampling enabled results in a biased algorithm and may not properly optimize the covariance-weighted Fisher divergence.

Note

FisherMinBatchMatch requires a sufficiently large n_samples to converge quickly.

Note

The subsampling strategy is applied only to the target LogDensityProblem, not to the variational approximation q. That is, FisherMinBatchMatch does not support amortization or structured variational families.

Output

  • q: The last iterate of the algorithm.

Callback Signature

The callback function supplied to optimize needs to have the following signature:

callback(; rng, iteration, q, info)

The keyword arguments are as follows:

  • rng: Random number generator internally used by the algorithm.
  • iteration: The index of the current iteration.
  • q: Current variational approximation.
  • info: NamedTuple containing the information generated during the current iteration.
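For illustration, here is a minimal sketch of a function with this keyword signature that records every `every`-th iterate of q. The names `make_recording_callback`, `history`, and `every` are hypothetical, and returning `nothing` is an assumption about what optimize accepts from a callback. The callback is invoked by hand with placeholder values so that the sketch runs without AdvancedVI.

```julia
using Random

# Hypothetical helper: returns a callback with the documented keyword signature
# that stores (iteration, q) for every `every`-th iteration in `history`.
function make_recording_callback(history::Vector{Any}; every::Int = 100)
    function callback(; rng, iteration, q, info)
        # `rng` and `info` are accepted to match the signature but are not used here.
        if iteration % every == 0
            push!(history, (iteration = iteration, q = q))
        end
        return nothing  # assumption: no extra information is reported back
    end
    return callback
end

# Manual invocation with placeholder values (q is a string stand-in, not a FullRankGaussian).
history = Any[]
cb = make_recording_callback(history; every = 2)
for t in 1:5
    cb(; rng = Random.default_rng(), iteration = t, q = "q_$t", info = NamedTuple())
end
# history now holds the entries for iterations 2 and 4.
```

Returning a closure keeps the recording state outside the algorithm, so the same pattern works for any quantity computable from the keyword arguments.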

Requirements

  • The variational family is FullRankGaussian.
  • The target distribution has unconstrained support.
  • The target's LogDensityProblems.logdensity(prob, x) supports at least first-order differentiation.

The associated objective value can be estimated through the following:

AdvancedVI.estimate_objective (Method)
estimate_objective([rng,] alg, q, prob; kwargs...)

Estimate the variational objective to be minimized by the algorithm alg for approximating the target prob with the variational approximation q.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • alg::AbstractVariationalAlgorithm: Variational inference algorithm.
  • q: Variational approximation.
  • prob: The target log-joint likelihood implementing the LogDensityProblem interface.

Keyword Arguments

Depending on the algorithm, additional keyword arguments may apply. Please refer to the respective documentation of each algorithm for more info.

Returns

  • obj_est: Estimate of the objective value.

Methodology

This algorithm aims to solve the problem

\[ \mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{F}_{\mathrm{cov}}(q, \pi),\]

where $\pi$ is the target distribution, $\mathcal{Q}$ is some family of distributions, often called the variational family, and $\mathrm{F}_{\mathrm{cov}}$ is a divergence defined as

\[\mathrm{F}_{\mathrm{cov}}(q, \pi) = \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}_{\mathrm{Cov}(q)}^2 ,\]

where ${\lVert x \rVert}_{A}^2 = x^{\top} A x$ is a weighted norm. $\mathrm{F}_{\mathrm{cov}}$ can be viewed as a variant of the canonical 2nd-order Fisher divergence defined as

\[\mathrm{F}_{2}(q, \pi) = \sqrt{ \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}^2 }.\]
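Both divergences can be estimated by Monte Carlo from samples of q. The sketch below does this for a Gaussian $q = \mathcal{N}(\mu, \Sigma)$, whose score is $-\Sigma^{-1}(z - \mu)$ and whose covariance $\mathrm{Cov}(q) = \Sigma$ is known in closed form. The helpers `gaussian_score` and `fisher_estimates` are hypothetical, not AdvancedVI functions, and the target score is passed as a plain function. The two estimators differ only in whether the squared score residual is weighted by $\Sigma$.

```julia
using LinearAlgebra, Random

# Score of a Gaussian density: ∇ log N(z; μ, Σ) = -Σ⁻¹ (z - μ).
gaussian_score(z, μ, Σ) = -(Σ \ (z - μ))

# Hypothetical helper: Monte Carlo estimates of F_cov(q, π) and F_2(q, π)
# for q = N(μ, Σ), where `grad_logπ(z)` returns the target score ∇ log π(z).
function fisher_estimates(rng::AbstractRNG, μ, Σ, grad_logπ, n_samples::Int)
    L = cholesky(Symmetric(Σ)).L   # Σ = L Lᵀ, used to draw z ~ q
    fcov_sum = 0.0                  # running sum of ‖r‖²_{Cov(q)} = rᵀ Σ r
    f2_sum = 0.0                    # running sum of ‖r‖² = rᵀ r
    for _ in 1:n_samples
        z = μ + L * randn(rng, length(μ))
        r = gaussian_score(z, μ, Σ) - grad_logπ(z)   # r = ∇ log(q/π)(z)
        fcov_sum += dot(r, Σ * r)
        f2_sum += dot(r, r)
    end
    # F_cov is a mean of squared weighted norms; F_2 takes the square root of a mean.
    return fcov_sum / n_samples, sqrt(f2_sum / n_samples)
end

# Example: q = N(μ, Σ) and target π = N(m, Σ) with the same covariance.
Σ = [2.0 0.0; 0.0 0.5]
μ = [1.0, 1.0]
m = [0.0, 0.0]
target_score(z) = gaussian_score(z, m, Σ)
fcov, f2 = fisher_estimates(Random.MersenneTwister(42), μ, Σ, target_score, 1_000)
```

Because the two Gaussians share a covariance, the residual $\nabla \log \frac{q}{\pi}(z) = \Sigma^{-1}(\mu - m)$ does not depend on $z$, so both estimates are exact up to rounding: $\mathrm{F}_{\mathrm{cov}} = (\mu - m)^{\top}\Sigma^{-1}(\mu - m) = 2.5$ and $\mathrm{F}_{2} = \lVert \Sigma^{-1}(\mu - m) \rVert = \sqrt{4.25}$.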

The use of the weighted norm ${\lVert \cdot \rVert}_{\mathrm{Cov}(q)}^2$ facilitates a proximal point-type method for minimizing $\mathrm{F}_{\mathrm{cov}}(q, \pi)$. In particular, BaM iterates the update

\[ q_{t+1} = \argmin_{q \in \mathcal{Q}} \left\{ \mathrm{F}_{\mathrm{cov}}(q, \pi) + \frac{2}{\lambda_t} \mathrm{KL}\left(q_t, q\right) \right\} .\]

Here, $\lambda_t > 0$ is a step size. Since $\mathrm{F}_{\mathrm{cov}}(q, \pi)$ is intractable, it is replaced with a Monte Carlo approximation using n_samples samples drawn from $q_t$. Furthermore, restricting $\mathcal{Q}$ to a Gaussian variational family makes the update admit a closed-form solution[CMPMGBS24]. Notice that the update does not depend on how $q_t$ is parameterized, which is what makes FisherMinBatchMatch a measure-space algorithm.
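To make the closed form concrete, here is a sketch of one BaM iteration from $q_t = \mathcal{N}(\mu_t, \Sigma_t)$, written from the Gaussian update reported in [CMPMGBS24] as we understand it; it is not AdvancedVI's implementation. With batch means $\bar{z}, \bar{g}$ and batch covariances $C, \Gamma$ of the samples and of the target scores, it forms $U = \lambda_t \Gamma + \frac{\lambda_t}{1+\lambda_t}\bar{g}\bar{g}^{\top}$ and $V = \Sigma_t + \lambda_t C + \frac{\lambda_t}{1+\lambda_t}(\mu_t - \bar{z})(\mu_t - \bar{z})^{\top}$, takes $\Sigma_{t+1}$ as the positive-definite solution of $\Sigma U \Sigma + \Sigma = V$, and sets $\mu_{t+1} = \frac{1}{1+\lambda_t}\mu_t + \frac{\lambda_t}{1+\lambda_t}(\Sigma_{t+1}\bar{g} + \bar{z})$. The names `bam_step` and `psd_sqrt` are hypothetical, and the target score is passed as a plain function.

```julia
using LinearAlgebra, Random, Statistics

# Symmetric square root of a symmetric positive semi-definite matrix.
function psd_sqrt(A::AbstractMatrix)
    E = eigen(Symmetric(A))
    return E.vectors * Diagonal(sqrt.(max.(E.values, 0.0))) * E.vectors'
end

# Hypothetical helper: one batch-and-match step from q_t = N(μ, Σ) with step size λ,
# using `n_samples` draws and a target score function `grad_logπ`.
function bam_step(rng::AbstractRNG, μ::AbstractVector, Σ::AbstractMatrix,
                  grad_logπ, λ::Real, n_samples::Int)
    d = length(μ)
    # Batch: draw z_b ~ q_t and evaluate the target scores g_b = ∇ log π(z_b).
    Z = μ .+ cholesky(Symmetric(Σ)).L * randn(rng, d, n_samples)   # d × B
    G = reduce(hcat, [grad_logπ(Z[:, b]) for b in 1:n_samples])    # d × B
    zbar = vec(mean(Z; dims = 2))
    gbar = vec(mean(G; dims = 2))
    C = (Z .- zbar) * (Z .- zbar)' / n_samples   # batch covariance of samples
    Γ = (G .- gbar) * (G .- gbar)' / n_samples   # batch covariance of scores
    # Match: assemble U and V, then solve Σ' U Σ' + Σ' = V for positive-definite Σ'.
    ρ = λ / (1 + λ)
    U = λ * Γ + ρ * gbar * gbar'
    V = Σ + λ * C + ρ * (μ - zbar) * (μ - zbar)'
    # Symmetric form of the solution: 2 V^{1/2} (I + (I + 4 V^{1/2} U V^{1/2})^{1/2})^{-1} V^{1/2}.
    Vh = psd_sqrt(V)
    W = psd_sqrt(I + 4 * Vh * U * Vh)
    Σnew = 2 * Vh * ((I + W) \ Vh)
    Σnew = (Σnew + Σnew') / 2                    # remove rounding asymmetry
    μnew = μ / (1 + λ) + ρ * (Σnew * gbar + zbar)
    return μnew, Σnew
end

# Example: Gaussian target π = N(m, S), started from N(0, I) with a very large step size.
S = [2.0 0.5 0.0; 0.5 1.0 0.3; 0.0 0.3 1.5]
m = [1.0, -1.0, 0.5]
target_score(z) = -(S \ (z - m))
μ1, Σ1 = bam_step(Random.MersenneTwister(1), zeros(3), Matrix(1.0I, 3, 3), target_score, 1e6, 32)
```

Two properties make this easy to sanity-check. Starting the step at a Gaussian target returns that target for any step size (up to rounding), because both the empirical divergence and the KL term are then zero. With a very large step size, as in the example, the KL term becomes negligible; for a Gaussian target and at least $d + 1$ samples in general position, the target is the unique zero of the empirical divergence, so a single step lands very close to $\mathcal{N}(m, S)$.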

Historically, the idea of using a proximal point-type update for minimizing a Fisher divergence-like objective was first proposed as Gaussian score matching[MGMYBS23]. BaM can be viewed as a successor to this algorithm.

  • [CMPMGBS24] Cai, D., Modi, C., Pillaud-Vivien, L., Margossian, C. C., Gower, R. M., Blei, D. M., & Saul, L. K. (2024). Batch and match: black-box variational inference with a score-based divergence. In Proceedings of the International Conference on Machine Learning.
  • [MGMYBS23] Modi, C., Gower, R., Margossian, C., Yao, Y., Blei, D., & Saul, L. (2023). Variational inference with Gaussian score matching. In Advances in Neural Information Processing Systems, 36.