General Usage

Each VI algorithm provides the following:

  1. The variational families it supports.
  2. The variational objective corresponding to the algorithm.

Each variational family is subject to its own constraints, so please refer to the documentation of the VI algorithm of interest.

Optimizing a Variational Objective

After constructing a variational objective (denoted objective below) and initializing a variational approximation, the objective can be optimized by calling optimize:

AdvancedVI.optimize (Function)
optimize(problem, objective, q_init, max_iter, objargs...; kwargs...)

Optimize the variational objective objective for the target problem by estimating (possibly stochastic) gradients.

The trainable parameters of the variational approximation must be extractable through Optimisers.destructure, which requires the variational approximation to be marked as a functor through Functors.@functor.
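As a minimal sketch of this requirement, the hypothetical family DiagGaussian below (not part of AdvancedVI) stores a location vector and a log-scale vector. Marking it with Functors.@functor lets Optimisers.destructure flatten both fields into one parameter vector and return a restructure function that rebuilds the family from such a vector. Depending on the installed Functors version, structs may already be treated as functors by default, in which case the macro is redundant.

```julia
using Functors, Optimisers

# Hypothetical mean-field Gaussian family (NOT part of AdvancedVI), used only
# to illustrate the destructure/restructure requirement.
struct DiagGaussian{V<:AbstractVector{<:Real}}
    location::V
    log_scale::V
end

# Mark the struct as a functor so its array fields become trainable leaves.
Functors.@functor DiagGaussian

q = DiagGaussian(zeros(2), zeros(2))

# Flatten all trainable fields (in field order) into one vector,
# and obtain the function that inverts the flattening.
params, restructure = Optimisers.destructure(q)

# Rebuild a DiagGaussian from a new parameter vector.
q_new = restructure([1.0, 2.0, -1.0, -2.0])
```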

Arguments

  • problem: The target log-joint likelihood implementing the LogDensityProblem interface.
  • objective::AbstractVariationalObjective: Variational objective.
  • q_init: Initial variational distribution. The variational parameters must be extractable through Optimisers.destructure.
  • max_iter::Int: Maximum number of iterations.
  • objargs...: Arguments to be passed to objective.

Keyword Arguments

  • adtype::ADTypes.AbstractADType: Automatic differentiation backend.
  • optimizer::Optimisers.AbstractRule: Optimizer used for inference. (Default: Adam.)
  • averager::AbstractAverager: Parameter averaging strategy. (Default: NoAveraging().)
  • operator::AbstractOperator: Operator applied to the parameters after each optimization step. (Default: IdentityOperator().)
  • rng::AbstractRNG: Random number generator. (Default: Random.default_rng().)
  • show_progress::Bool: Whether to show the progress bar. (Default: true.)
  • callback: Callback function called after every iteration. See further information below. (Default: nothing.)
  • prog: Progress bar configuration. (Default: ProgressMeter.Progress(max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress).)
  • state::NamedTuple: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.)

Returns

  • averaged_params: Variational parameters generated by the algorithm averaged according to averager.
  • params: Last variational parameters generated by the algorithm.
  • stats: Statistics gathered during optimization.
  • state: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
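A minimal end-to-end sketch of a call to optimize follows. The target IsoNormal is a hypothetical unnormalized isotropic Gaussian written against the LogDensityProblems interface. The objective RepGradELBO and the family FullRankGaussian are assumptions: they are not described in this section, so check that they exist with these constructors in the installed AdvancedVI version. The return values are unpacked in the order listed under Returns.

```julia
using AdvancedVI, ADTypes, ForwardDiff, LinearAlgebra, LogDensityProblems, Optimisers

# Hypothetical target: an unnormalized isotropic Gaussian log-density.
struct IsoNormal
    dim::Int
end
LogDensityProblems.logdensity(::IsoNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(p::IsoNormal) = p.dim
LogDensityProblems.capabilities(::Type{IsoNormal}) = LogDensityProblems.LogDensityOrder{0}()

d = 2
problem = IsoNormal(d)

# ASSUMED constructors (not documented in this section): a reparameterization-
# gradient ELBO with 10 Monte Carlo samples, and a full-rank Gaussian at N(0, I).
objective = AdvancedVI.RepGradELBO(10)
q_init = AdvancedVI.FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))

# 1_000 iterations of Adam with step size 1e-3; ForwardDiff computes gradients.
averaged_params, params, stats, state = AdvancedVI.optimize(
    problem, objective, q_init, 1_000;
    adtype = AutoForwardDiff(),
    optimizer = Optimisers.Adam(1e-3),
    show_progress = false,
)
```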

Callback

The callback function callback has the following signature:

callback(; stat, state, params, averaged_params, restructure, gradient)

The arguments are as follows:

  • stat: Statistics gathered during the current iteration. The content will vary depending on objective.
  • state: Collection of the internal states used for optimization.
  • params: Variational parameters.
  • averaged_params: Variational parameters averaged according to the averaging strategy.
  • restructure: Function that reconstructs the variational approximation from the variational parameters. Calling restructure(params) returns the variational approximation.
  • gradient: The estimated (possibly stochastic) gradient.

callback can optionally return a NamedTuple containing additional information computed within it, which is appended to the statistics of the current iteration. Otherwise, it should return nothing.
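For instance, the hypothetical callback track_norms below records the norm of the gradient estimate and the distance between the last and the averaged parameters. It relies only on the keyword arguments listed above and assumes that gradient, params, and averaged_params are plain numeric vectors.

```julia
# Hypothetical callback: per-iteration gradient norm and the distance between
# the last and the averaged parameters.
function track_norms(; stat, state, params, averaged_params, restructure, gradient)
    # Euclidean norm of the (possibly stochastic) gradient estimate.
    grad_norm = sqrt(sum(abs2, gradient))
    # Euclidean distance between the last and the averaged iterate.
    avg_gap = sqrt(sum(abs2, params .- averaged_params))
    # These fields are appended to this iteration's statistics.
    return (grad_norm = grad_norm, avg_gap = avg_gap)
end
```

It would be passed to optimize through the callback keyword, e.g. callback=track_norms.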


Estimating the Objective

In some cases, it is useful to estimate the objective value directly. This can be done with the following function:

AdvancedVI.estimate_objective (Function)
estimate_objective([rng,] obj, q, prob; kwargs...)

Estimate the variational objective obj targeting prob with respect to the variational approximation q.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • obj::AbstractVariationalObjective: Variational objective.
  • prob: The target log-joint likelihood implementing the LogDensityProblem interface.
  • q: Variational approximation.

Keyword Arguments

Depending on the objective, additional keyword arguments may apply. Please refer to the respective documentation of each variational objective for more info.

Returns

  • obj_est: Estimate of the objective value.
Info

Note that estimate_objective is not expected to be differentiated through, and may not result in optimal statistical performance.
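A short sketch of estimating the objective for a fixed approximation follows. IsoNormal is a hypothetical target, and RepGradELBO and FullRankGaussian are assumed constructors that are not documented in this section. No objective-specific keyword arguments are passed.

```julia
using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

# Hypothetical target: an unnormalized isotropic Gaussian log-density.
struct IsoNormal
    dim::Int
end
LogDensityProblems.logdensity(::IsoNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(p::IsoNormal) = p.dim
LogDensityProblems.capabilities(::Type{IsoNormal}) = LogDensityProblems.LogDensityOrder{0}()

d = 2
prob = IsoNormal(d)

# ASSUMED constructors, not documented in this section.
obj = AdvancedVI.RepGradELBO(10)
q = AdvancedVI.FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))

# A fixed seed makes repeated calls return the same Monte Carlo estimate.
obj_est = AdvancedVI.estimate_objective(Random.Xoshiro(1), obj, q, prob)
```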

Advanced Usage

Each variational objective is a subtype of the following abstract type:

AdvancedVI.AbstractVariationalObjective (Type)
AbstractVariationalObjective

Abstract type for the VI algorithms supported by AdvancedVI.

Implementations

To be supported by AdvancedVI, a VI algorithm must subtype AbstractVariationalObjective and implement estimate_objective. It should also provide gradients by implementing estimate_gradient!. If the estimator is stateful, it can additionally implement init to initialize the state.


Furthermore, AdvancedVI interacts with each variational objective only by querying gradient estimates. Therefore, to create a custom objective that can be optimized through AdvancedVI, it suffices to implement the following function:

AdvancedVI.estimate_gradient! (Function)
estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state)

Estimate (possibly stochastic) gradients of the variational objective obj targeting prob with respect to the variational parameters params.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • obj::AbstractVariationalObjective: Variational objective.
  • adtype::ADTypes.AbstractADType: Automatic differentiation backend.
  • out::DiffResults.MutableDiffResult: Buffer containing the objective value and gradient estimates.
  • prob: The target log-joint likelihood implementing the LogDensityProblem interface.
  • params: Variational parameters to evaluate the gradient on.
  • restructure: Function that reconstructs the variational approximation from params.
  • obj_state: Previous state of the objective.

Returns

  • out::MutableDiffResult: Buffer containing the objective value and gradient estimates.
  • obj_state: The updated state of the objective.
  • stat::NamedTuple: Statistics and logs generated during estimation.
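The following sketch shows the shape of a custom objective. NaiveNegELBO and neg_elbo are hypothetical names: the objective estimates the negative ELBO of a mean-field Gaussian with the reparameterization trick, assuming params is laid out as [location; log_scale]. Two simplifications are deliberate and should not be read as AdvancedVI's own approach: ForwardDiff is called directly instead of dispatching on adtype, and restructure is ignored because the parameter layout is assumed. The negative ELBO is stored because Optimisers rules perform descent. A complete implementation would also define estimate_objective, as noted under Implementations.

```julia
using AdvancedVI, ADTypes, DiffResults, ForwardDiff, LogDensityProblems, Random, Statistics

# Hypothetical objective (NOT part of AdvancedVI): negative ELBO of a
# mean-field Gaussian, estimated with the reparameterization trick.
struct NaiveNegELBO <: AdvancedVI.AbstractVariationalObjective
    n_samples::Int
end

# Monte Carlo negative ELBO for λ = [μ; log σ] (assumed layout) and standard-
# normal noise ε with one column per sample. The Gaussian entropy is included
# only up to an additive constant, which does not affect the gradient.
function neg_elbo(prob, λ, ε)
    d = size(ε, 1)
    μ, logσ = λ[1:d], λ[(d + 1):end]
    z = μ .+ exp.(logσ) .* ε
    energy = mean(LogDensityProblems.logdensity(prob, zᵢ) for zᵢ in eachcol(z))
    return -(energy + sum(logσ))
end

function AdvancedVI.estimate_gradient!(
    rng::Random.AbstractRNG,
    obj::NaiveNegELBO,
    adtype::ADTypes.AbstractADType,
    out::DiffResults.MutableDiffResult,
    prob,
    params,
    restructure,
    obj_state,
)
    d = LogDensityProblems.dimension(prob)
    # Draw the noise once, so it is held fixed while differentiating.
    ε = randn(rng, d, obj.n_samples)
    # Simplification: ForwardDiff is used directly instead of dispatching on `adtype`.
    ForwardDiff.gradient!(out, λ -> neg_elbo(prob, λ, ε), params)
    # Stateless objective: `obj_state` is passed through unchanged.
    return out, obj_state, (elbo = -DiffResults.value(out),)
end

# Usage with a hypothetical 2-dimensional standard-normal target.
struct IsoNormal
    dim::Int
end
LogDensityProblems.logdensity(::IsoNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(p::IsoNormal) = p.dim

λ0 = zeros(4)  # [μ; log σ] for d = 2
out = DiffResults.GradientResult(λ0)
out, _, stat = AdvancedVI.estimate_gradient!(
    Random.Xoshiro(1), NaiveNegELBO(8), ADTypes.AutoForwardDiff(),
    out, IsoNormal(2), λ0, identity, nothing,
)
```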

If an objective needs to be stateful, one can implement the following function to initialize the state.

AdvancedVI.init (Function)
init(rng, obj, adtype, prob, params, restructure)

Initialize the state of the variational objective obj given the initial variational parameters params. This function needs to be implemented only if obj is stateful.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • obj::AbstractVariationalObjective: Variational objective.
  • adtype::ADTypes.AbstractADType: Automatic differentiation backend.
  • prob: The target log-joint likelihood implementing the LogDensityProblem interface.
  • params: Initial variational parameters.
  • restructure: Function that reconstructs the variational approximation from params.
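As a sketch of a stateful objective, the hypothetical EMATracked below initializes its state to an iteration counter and an exponential moving average of past objective values. The helper update_state is also hypothetical: an estimate_gradient! method for this objective would call it and return the updated state as its second return value, so that it is passed back as obj_state at the next iteration.

```julia
using AdvancedVI, ADTypes, Random

# Hypothetical stateful objective (NOT part of AdvancedVI) that tracks an
# exponential moving average (EMA) of past objective values.
struct EMATracked <: AdvancedVI.AbstractVariationalObjective
    decay::Float64
end

# Called once before optimization; returns the initial objective state.
function AdvancedVI.init(
    rng::Random.AbstractRNG,
    obj::EMATracked,
    adtype::ADTypes.AbstractADType,
    prob,
    params,
    restructure,
)
    return (iteration = 0, ema = 0.0)
end

# Hypothetical helper: next state from the previous state and the objective
# value just computed.
function update_state(obj::EMATracked, state, value)
    return (
        iteration = state.iteration + 1,
        ema = obj.decay * state.ema + (1 - obj.decay) * value,
    )
end
```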
init(avg, params)

Initialize the state of the averaging strategy avg with the initial parameters params.

Arguments

  • avg::AbstractAverager: Averaging strategy.
  • params: Initial variational parameters.