# General Usage
AdvancedVI provides multiple variational inference (VI) algorithms. Each algorithm defines its own subtype of `AdvancedVI.AbstractAlgorithm` with a set of corresponding methods (see the Algorithm Interface section below). The algorithm can then be executed by invoking `optimize` (see the Optimize section below).
## Optimize

Given the subtype of `AbstractAlgorithm` associated with each algorithm, it suffices to call the function `optimize`:

`AdvancedVI.optimize` — Function

```julia
optimize(
    [rng::Random.AbstractRNG = Random.default_rng(),]
    algorithm::AbstractAlgorithm,
    max_iter::Int,
    prob,
    q_init,
    args...;
    kwargs...,
)
```

Run the variational inference `algorithm` on the problem `prob`, which must implement the LogDensityProblems interface. For more details on usage, refer to the documentation corresponding to `algorithm`.
**Arguments**

- `rng`: Random number generator.
- `algorithm`: Variational inference algorithm.
- `max_iter::Int`: Maximum number of iterations.
- `prob`: Target `LogDensityProblem`.
- `q_init`: Initial variational distribution.
- `args...`: Arguments to be passed to `algorithm`.
**Keyword Arguments**

- `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.)
- `state::Union{<:Any,Nothing}`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.)
- `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.)
- `progress::ProgressMeter.AbstractProgress`: Progress bar configuration. (Default: `ProgressMeter.Progress(max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress)`.)
- `kwargs...`: Keyword arguments to be passed to `algorithm`.
**Returns**

- `output`: The output of the variational inference algorithm.
- `info`: Array of `NamedTuple`s, where each `NamedTuple` contains information generated at each iteration.
- `state`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
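As an illustration, a typical call to `optimize` might look like the following sketch. The toy target, the algorithm constructor (`KLMinRepGradDescent`), and the AD backend type (`AutoForwardDiff` from ADTypes) are assumptions made for the sake of the example; substitute the algorithm you actually use and consult its documentation for the exact constructor arguments.

```julia
using ADTypes, AdvancedVI, Distributions, LinearAlgebra, LogDensityProblems, Random

# A toy target implementing the LogDensityProblems interface:
# an unnormalized standard bivariate Gaussian.
struct ToyGaussian end
LogDensityProblems.logdensity(::ToyGaussian, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyGaussian) = 2
LogDensityProblems.capabilities(::Type{ToyGaussian}) =
    LogDensityProblems.LogDensityOrder{0}()

prob   = ToyGaussian()
q_init = MvNormal(zeros(2), I)                   # initial variational approximation
alg    = KLMinRepGradDescent(AutoForwardDiff())  # assumed constructor; see the algorithm's docs

# Run up to 1000 iterations.
q_out, info, state = optimize(Random.default_rng(), alg, 1000, prob, q_init)

# Warm-start a follow-up run from the final state of the previous one.
q_out, info, state = optimize(alg, 1000, prob, q_out; state=state)
```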
## Callback

The signature of the callback function depends on the `algorithm` in use, so refer to the documentation of each algorithm. In general, a callback should return either `nothing` or a `NamedTuple` containing information generated during the current iteration. The contents of the `NamedTuple` are concatenated into the corresponding entry of the `info` array returned at the end of the call to `optimize`, and are also displayed on the progress meter.
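For example, a minimal callback that ignores the algorithm-specific keyword arguments and records the wall-clock time of each iteration could be sketched as follows:

```julia
# The exact keyword arguments passed to the callback depend on the algorithm,
# so a portable callback should accept a catch-all `kwargs...`.
record_walltime(; kwargs...) = (walltime = time(),)

# Passed via the `callback` keyword of `optimize` (illustrative):
# q, info, state = optimize(alg, max_iter, prob, q_init; callback=record_walltime)
```

Each entry of the returned `info` array will then contain a `walltime` field alongside the information generated by the algorithm itself.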
Each algorithm may interact differently with the arguments of `optimize`. Therefore, refer to the documentation of each algorithm for a detailed description of its behavior and requirements.
## Algorithm Interface

A variational inference algorithm supported by `AdvancedVI` should define its own subtype of `AbstractAlgorithm`:

`AdvancedVI.AbstractAlgorithm` — Type

Abstract type for a variational inference algorithm.
The functionality of each algorithm is then implemented through the following methods:

`AdvancedVI.init` — Method

```julia
init(rng, alg, prob, q_init)
```

Initialize `alg` given the initial variational approximation `q_init` and the target `prob`.

**Arguments**

- `rng::Random.AbstractRNG`: Random number generator.
- `alg::AbstractAlgorithm`: Variational inference algorithm.
- `prob`: Target problem.
- `q_init`: Initial variational approximation.
`AdvancedVI.step` — Function

```julia
step(rng, alg, state, callback, objargs...; kwargs...)
```

Perform a single step of `alg` given the previous `state`.

**Arguments**

- `rng::Random.AbstractRNG`: Random number generator.
- `alg::AbstractAlgorithm`: Variational inference algorithm.
- `state`: Previous state of the algorithm.
- `callback`: Callback function to be called during the step.

**Returns**

- `state`: New state generated by performing the step.
- `terminate::Bool`: Whether to terminate the algorithm after the step.
- `info::NamedTuple`: Information generated during the step.
`AdvancedVI.output` — Function

```julia
output(alg, state)
```

Generate an output variational approximation using the last `state` of `alg`.

**Arguments**

- `alg::AbstractAlgorithm`: Variational inference algorithm used to compute the state.
- `state`: The last state generated by the algorithm.

**Returns**

- `out`: The output of the algorithm.
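To make the interface concrete, here is a sketch of a hypothetical do-nothing algorithm (`IdentityVI`, a name invented for this example) that simply returns its initial approximation. A real algorithm would update the state in `step` using, for instance, a stochastic gradient estimate of its objective.

```julia
using AdvancedVI, Random

struct IdentityVI <: AdvancedVI.AbstractAlgorithm end

# The state can be any collection of values the algorithm needs across steps.
function AdvancedVI.init(rng::Random.AbstractRNG, alg::IdentityVI, prob, q_init)
    (prob=prob, q=q_init)
end

function AdvancedVI.step(
    rng::Random.AbstractRNG, alg::IdentityVI, state, callback, objargs...; kwargs...
)
    # A real algorithm would update `state.q` here.
    info = (objective=NaN,)
    if !isnothing(callback)
        info′ = callback(; kwargs...)
        info  = isnothing(info′) ? info : merge(info, info′)
    end
    return state, false, info  # never requests termination
end

AdvancedVI.output(alg::IdentityVI, state) = state.q
```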
The role of each method should become clear from how `optimize` interacts with each algorithm. The operation of `optimize` can be simplified as follows:

```julia
function optimize([rng,] algorithm, max_iter, prob, q_init, objargs...; callback, kwargs...)
    info_total = NamedTuple[]
    state = init(rng, algorithm, prob, q_init)
    for t in 1:max_iter
        info = (iteration=t,)
        state, terminate, info′ = step(
            rng, algorithm, state, callback, objargs...; kwargs...
        )
        info = merge(info′, info)
        if terminate
            break
        end
        push!(info_total, info)
    end
    out = output(algorithm, state)
    return out, info_total, state
end
```