General Usage

AdvancedVI provides multiple variational inference (VI) algorithms. Each algorithm defines its own subtype of AdvancedVI.AbstractAlgorithm together with a set of corresponding methods (see Algorithm Interface below). The algorithm can then be executed by invoking optimize (see Optimize below).

Optimize

Once an algorithm is represented by a subtype of AbstractAlgorithm, running it only requires calling the function optimize:

AdvancedVI.optimize (Function)
optimize(
    [rng::Random.AbstractRNG = Random.default_rng(),]
    algorithm::AbstractAlgorithm,
    max_iter::Int,
    prob,
    q_init,
    args...;
    kwargs...
)

Run the variational inference algorithm on the target problem prob, which implements the LogDensityProblems interface. For more details on usage, refer to the documentation of the corresponding algorithm.

Arguments

  • rng: Random number generator.
  • algorithm: Variational inference algorithm.
  • max_iter::Int: Maximum number of iterations.
  • prob: Target problem implementing the LogDensityProblems interface.
  • q_init: Initial variational distribution.
  • args...: Arguments to be passed to algorithm.

Keyword Arguments

  • show_progress::Bool: Whether to show the progress bar. (Default: true.)
  • state::Union{<:Any,Nothing}: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.)
  • callback: Callback function called after every iteration. See further information below. (Default: nothing.)
  • progress::ProgressMeter.AbstractProgress: Progress bar configuration. (Default: ProgressMeter.Progress(max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress).)
  • kwargs...: Keyword arguments to be passed to algorithm.

Returns

  • output: The output of the variational inference algorithm.
  • info: Array of NamedTuples, where each NamedTuple contains information generated at each iteration.
  • state: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.

Callback

The signature of the callback function depends on the algorithm in use, so refer to the documentation of each algorithm. In all cases, however, a callback should return either nothing or a NamedTuple containing information generated during the current iteration. The contents of this NamedTuple are concatenated into the corresponding entry of the info array returned at the end of the call to optimize, and are also displayed on the progress meter.
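As a minimal sketch, the callback below returns a NamedTuple every tenth iteration and nothing otherwise. Because the keyword arguments passed to a callback depend on the algorithm, it assumes (hypothetically) that an `iteration` keyword is provided and absorbs any others with `kwargs...`. The last lines mimic, outside of AdvancedVI, how the returned NamedTuple is merged into that iteration's `info` entry; the `elbo` field is a made-up placeholder.

```julia
# Hypothetical callback: the keywords actually passed depend on the algorithm.
function my_callback(; iteration::Int, kwargs...)
    # Report something every 10 iterations; otherwise add nothing to `info`.
    if iteration % 10 == 0
        return (checkpoint=true, msg="iteration $(iteration)")
    else
        return nothing
    end
end

# Mimic how a returned NamedTuple is concatenated into the iteration's `info` entry.
info = (iteration=10, elbo=-1.23)  # `elbo` is a placeholder value for illustration
extra = my_callback(; iteration=10)
info = isnothing(extra) ? info : merge(info, extra)
```

Returning nothing on most iterations keeps the info array and the progress meter uncluttered.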


Each algorithm may interact differently with the arguments of optimize. Therefore, refer to the documentation of each algorithm for a detailed description of its behavior and requirements.

Algorithm Interface

A variational inference algorithm supported by AdvancedVI should define its own subtype of AbstractAlgorithm.
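For illustration, the sketch below declares such a subtype. So that it runs without the package, it uses a local stand-in for `AdvancedVI.AbstractAlgorithm`; real code would subtype `AdvancedVI.AbstractAlgorithm` directly. The name `MyAlgorithm` and its fields are hypothetical.

```julia
# Local stand-in for AdvancedVI.AbstractAlgorithm (assumption: only so the snippet
# runs on its own; real code would write `<: AdvancedVI.AbstractAlgorithm`).
abstract type AbstractAlgorithm end

# Hypothetical algorithm: its fields hold the hyperparameters that init, step,
# and output will need later.
struct MyAlgorithm <: AbstractAlgorithm
    stepsize::Float64   # learning rate
    n_samples::Int      # Monte Carlo samples per gradient estimate
end

alg = MyAlgorithm(0.1, 8)
```

Keeping hyperparameters in the algorithm value, rather than in the optimization state, lets dispatch on the concrete type select the right init, step, and output methods.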

The functionality of each algorithm is then implemented through the following methods:

AdvancedVI.init (Method)
init(rng, alg, prob, q_init)

Initialize alg given the initial variational approximation q_init and the target prob.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • alg::AbstractAlgorithm: Variational inference algorithm.
  • prob: Target problem.
  • q_init: Initial variational approximation.

AdvancedVI.step (Function)
step(rng, alg, state, callback, objargs...; kwargs...)

Perform a single step of alg given the previous state.

Arguments

  • rng::Random.AbstractRNG: Random number generator.
  • alg::AbstractAlgorithm: Variational inference algorithm.
  • state: Previous state of the algorithm.
  • callback: Callback function to be called during the step.

Returns

  • state: New state generated by performing the step.
  • terminate::Bool: Whether to terminate the algorithm after the step.
  • info::NamedTuple: Information generated during the step.

AdvancedVI.output (Function)
output(alg, state)

Generate an output variational approximation using the last state of alg.

Arguments

  • alg::AbstractAlgorithm: Variational inference algorithm used to compute the state.
  • state: The last state generated by the algorithm.

Returns

  • out: The output of the algorithm.
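To make the three methods concrete, here is a toy algorithm sketched under several assumptions: it runs without AdvancedVI by using a local stand-in for AbstractAlgorithm, `ToySGD` is hypothetical, the variational family is N(μ, I) with only the mean μ learned, and `prob` is assumed to be a NamedTuple carrying a hypothetical `grad_logdensity` function rather than an object implementing the LogDensityProblems interface. Since step does not receive prob, the sketch keeps it inside the state created by init.

```julia
using Random, LinearAlgebra

# Local stand-in for AdvancedVI.AbstractAlgorithm so the snippet runs on its own.
abstract type AbstractAlgorithm end

# Hypothetical algorithm: fits the mean μ of q = N(μ, I) by stochastic gradient
# ascent on the ELBO, using the reparameterization z = μ + ε with ε ~ N(0, I).
struct ToySGD <: AbstractAlgorithm
    stepsize::Float64
    n_samples::Int
end

# Assumption: `prob.grad_logdensity(z)` returns ∇ log p(z) (hypothetical field).
function init(rng::Random.AbstractRNG, alg::ToySGD, prob, q_init::AbstractVector{<:Real})
    return (prob=prob, μ=Vector{Float64}(q_init), iteration=0)
end

function step(rng::Random.AbstractRNG, alg::ToySGD, state, callback, objargs...; kwargs...)
    prob, μ, iteration = state.prob, state.μ, state.iteration + 1
    # Monte Carlo estimate of ∇_μ ELBO = E_ε[∇ log p(μ + ε)]; the entropy of
    # N(μ, I) does not depend on μ, so it contributes no gradient.
    g = sum(prob.grad_logdensity(μ .+ randn(rng, length(μ))) for _ in 1:alg.n_samples) ./
        alg.n_samples
    μ_new = μ .+ alg.stepsize .* g
    info = (grad_norm=norm(g),)
    # Let the callback append extra information to this iteration's `info`.
    if !isnothing(callback)
        extra = callback(; iteration=iteration, μ=μ_new)
        info = isnothing(extra) ? info : merge(info, extra)
    end
    terminate = false  # this toy algorithm never requests early termination
    return (prob=prob, μ=μ_new, iteration=iteration), terminate, info
end

# The output is the final variational mean, i.e. q = N(μ, I).
output(alg::ToySGD, state) = state.μ
```

With a target such as N(m, I), repeatedly calling step from the state returned by init should move μ toward m, and output then returns the fitted mean.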

The role of each method becomes clear once we look at how optimize interacts with an algorithm. In simplified form, optimize operates as follows:

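function optimize(
    rng,  # optional; defaults to Random.default_rng()
    algorithm::AbstractAlgorithm,
    max_iter::Int,
    prob,
    q_init,
    objargs...;
    state=nothing,
    callback=nothing,
    kwargs...,
)
    info_total = NamedTuple[]
    # Initialize the algorithm, unless warm-starting from a previous run's state
    state = isnothing(state) ? init(rng, algorithm, prob, q_init) : state
    for t in 1:max_iter
        info = (iteration=t,)
        # Advance the algorithm by one step
        state, terminate, info′ = step(
            rng, algorithm, state, callback, objargs...; kwargs...
        )
        info = merge(info′, info)

        if terminate
            break
        end

        push!(info_total, info)
    end
    # Extract the final variational approximation from the last state
    out = output(algorithm, state)
    return out, info_total, state
end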
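The control flow of the simplified loop above can be checked with a self-contained sketch. It assumes a hypothetical `CountdownAlg` whose state is a counter and which requests termination once the counter reaches `n`, plus a local copy of the loop named `optimize_sketch` (also hypothetical, not part of AdvancedVI). It shows two details: the info of the terminating step is not recorded, and the returned state can warm-start a new run.

```julia
using Random

# Local stand-ins so the sketch runs without AdvancedVI (all names hypothetical).
abstract type AbstractAlgorithm end

# Toy algorithm: the state is a counter; terminate once it reaches `n`.
struct CountdownAlg <: AbstractAlgorithm
    n::Int
end

init(rng, alg::CountdownAlg, prob, q_init) = q_init
function step(rng, alg::CountdownAlg, state, callback, objargs...; kwargs...)
    new_state = state + 1
    return new_state, new_state >= alg.n, (counter=new_state,)
end
output(alg::CountdownAlg, state) = state

# Local copy of the simplified optimize loop above.
function optimize_sketch(rng, algorithm, max_iter, prob, q_init, objargs...;
                         state=nothing, callback=nothing, kwargs...)
    info_total = NamedTuple[]
    state = isnothing(state) ? init(rng, algorithm, prob, q_init) : state
    for t in 1:max_iter
        info = (iteration=t,)
        state, terminate, info′ = step(rng, algorithm, state, callback, objargs...; kwargs...)
        info = merge(info′, info)
        terminate && break
        push!(info_total, info)
    end
    return output(algorithm, state), info_total, state
end

rng = Random.default_rng()
# Cold start from q_init = 0: the step reaching 5 terminates and is not recorded.
out, info, st = optimize_sketch(rng, CountdownAlg(5), 100, nothing, 0)
# Warm start: continue from the returned state for 2 more iterations.
out2, info2, st2 = optimize_sketch(rng, CountdownAlg(100), 2, nothing, 0; state=st)
```

Here the first run records only four iterations even though five steps were taken, and the second run resumes counting from the stored state instead of from q_init.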