API: Turing.Variational
Turing.Variational.q_fullrank_gaussian — Method
q_fullrank_gaussian(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector} = nothing,
    scale::Union{Nothing,<:LowerTriangular} = nothing,
    kwargs...
)
Find a numerically non-degenerate Gaussian q with a full-rank scale factor (traditionally referred to as the "full-rank family") for approximating the target model.
Arguments
- model: The target DynamicPPL.Model.
Keyword Arguments
- location: The location parameter of the initialization. If nothing, a vector of zeros is used.
- scale: The scale parameter of the initialization. If nothing, an identity matrix is used.
The remaining keyword arguments are passed to q_locationscale.
Returns
- q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution matching the support of model.
Turing.Variational.q_initialize_scale — Method
q_initialize_scale(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model,
    location::AbstractVector,
    scale::AbstractMatrix,
    basedist::Distributions.UnivariateDistribution;
    num_samples::Int = 10,
    num_max_trials::Int = 10,
    reduce_factor::Real = one(eltype(scale)) / 2
)
Given an initial location-scale distribution q formed by location, scale, and basedist, shrink scale until the expectation of the log-densities of model taken over q is finite. If the log-densities are still not finite after num_max_trials trials, throw an error.
For reference, a location-scale distribution $q$ formed by location, scale, and basedist is a distribution whose sampling process $z \sim q$ can be represented as
    u = rand(basedist, d)
    z = scale * u + location
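As a minimal runnable sketch of the two lines above, the following assumes a standard normal base distribution, so that rand(basedist, d) can be replaced by randn from the standard library. The values of location and scale are illustrative only and do not come from Turing.

```julia
using LinearAlgebra
using Random

# Illustrative values only; in Turing these come from the function arguments.
rng = Random.MersenneTwister(1)
location = [1.0, -2.0]
scale = LowerTriangular([2.0 0.0; 0.5 1.0])
d = length(location)

# One draw z ~ q: sample the base distribution, then apply the affine map.
u = randn(rng, d)            # stands in for rand(basedist, d) with basedist = Normal()
z = scale * u + location

# With a standard normal base distribution, q is Gaussian with
# mean `location` and covariance scale * scale'.
covariance = scale * scale'
```

Shrinking scale pulls every draw z toward location, which is why q_initialize_scale can use it to move samples into a region where the log-density is finite.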
Arguments
- model: The target DynamicPPL.Model.
- location: The location parameter of the initialization.
- scale: The scale parameter of the initialization.
- basedist: The base distribution of the location-scale family.
Keyword Arguments
- num_samples: Number of samples used to compute the average log-density at each trial.
- num_max_trials: Maximum number of trials before an error is thrown.
- reduce_factor: Factor for shrinking the scale. After n trials, the scale is scale * reduce_factor^n.
Returns
- scale_adj: The adjusted scale matrix matching the type of scale.
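The shrinking procedure described for q_initialize_scale can be sketched with the standard library alone. This is not Turing's implementation: shrink_scale_sketch and box_logdensity are hypothetical names, the model is replaced by a plain function returning a log-density, and the base distribution is assumed to be a standard normal.

```julia
using LinearAlgebra
using Random
using Statistics

# Hypothetical helper, not part of Turing: `logdensity` is any function mapping
# a vector z to a log-density value, standing in for the model.
function shrink_scale_sketch(
    rng::Random.AbstractRNG,
    logdensity,
    location::AbstractVector,
    scale::AbstractMatrix;
    num_samples::Int=10,
    num_max_trials::Int=10,
    reduce_factor::Real=one(eltype(scale)) / 2,
)
    d = length(location)
    scale_adj = scale
    for _ in 1:num_max_trials
        # Monte Carlo estimate of the expected log-density under the current q.
        avg = mean(logdensity(scale_adj * randn(rng, d) + location) for _ in 1:num_samples)
        isfinite(avg) && return scale_adj
        # Shrink the scale; multiplying by a scalar keeps the matrix type.
        scale_adj = reduce_factor * scale_adj
    end
    error("log-density is still not finite after $num_max_trials trials")
end

# Toy target supported on the box (-1, 1)^2: the log-density is -Inf outside it.
box_logdensity(z) = all(abs.(z) .< 1) ? 0.0 : -Inf

scale_adj = shrink_scale_sketch(
    Random.MersenneTwister(1), box_logdensity, zeros(2), Diagonal(fill(4.0, 2))
)
```

A single sample outside the support makes the average -Inf, so the scale keeps halving until all num_samples draws of a trial land inside the box.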
Turing.Variational.q_locationscale — Method
q_locationscale(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector} = nothing,
    scale::Union{Nothing,<:Diagonal,<:LowerTriangular} = nothing,
    meanfield::Bool = true,
    basedist::Distributions.UnivariateDistribution = Normal()
)
Find a numerically non-degenerate variational distribution q for approximating the target model within the location-scale variational family formed by the type of scale and basedist.
The distribution can be manually specified by setting location, scale, and basedist. Otherwise, a standard Gaussian is chosen by default. Whether or not the default choice is used, the scale may be adjusted via q_initialize_scale so that the log-densities of model are finite over the samples from q. If meanfield is set to true, the scale of q is restricted to be a diagonal matrix and only the diagonal of scale is used.
For reference, a location-scale distribution $q$ formed by location, scale, and basedist is a distribution whose sampling process $z \sim q$ can be represented as
    u = rand(basedist, d)
    z = scale * u + location
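To illustrate what the meanfield restriction means for the resulting distribution, the sketch below compares a dense lower-triangular scale with its diagonal counterpart. It assumes that "only the diagonal of scale is used" amounts to Diagonal(diag(scale)); Turing's internal conversion may differ in detail, and the numbers are made up for the example.

```julia
using LinearAlgebra

# Illustrative dense lower-triangular scale; values are made up for the example.
scale_full = LowerTriangular([1.0 0.0; 0.8 0.5])

# meanfield = true: keep only the diagonal of the scale.
scale_mf = Diagonal(diag(scale_full))

# Implied covariances scale * scale' for a standard Gaussian base distribution.
cov_full = scale_full * scale_full'   # off-diagonal entries are non-zero
cov_mf = scale_mf * scale_mf'         # diagonal: coordinates are independent
```

The mean-field family therefore cannot represent correlations between coordinates, while the LowerTriangular scale can.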
Arguments
- model: The target DynamicPPL.Model.
Keyword Arguments
- location: The location parameter of the initialization. If nothing, a vector of zeros is used.
- scale: The scale parameter of the initialization. If nothing, an identity matrix is used.
- meanfield: Whether to use the mean-field approximation. If true, scale is converted into a Diagonal matrix. Otherwise, it is converted into a LowerTriangular matrix.
- basedist: The base distribution of the location-scale family.
The remaining keyword arguments are passed to q_initialize_scale.
Returns
- q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution matching the support of model.
Turing.Variational.q_meanfield_gaussian — Method
q_meanfield_gaussian(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector} = nothing,
    scale::Union{Nothing,<:Diagonal} = nothing,
    kwargs...
)
Find a numerically non-degenerate mean-field Gaussian q for approximating the target model.
Arguments
- model: The target DynamicPPL.Model.
Keyword Arguments
- location: The location parameter of the initialization. If nothing, a vector of zeros is used.
- scale: The scale parameter of the initialization. If nothing, an identity matrix is used.
The remaining keyword arguments are passed to q_locationscale.
Returns
- q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution matching the support of model.
Turing.Variational.vi — Method
vi(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model,
    q,
    n_iterations::Int;
    objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
        10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
    ),
    show_progress::Bool = Turing.PROGRESS[],
    optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
    averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
    operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
    adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
    kwargs...
)
Approximate the target model via variational inference by optimizing objective, starting from the initialization q. This is a thin wrapper around AdvancedVI.optimize.
Arguments
- model: The target DynamicPPL.Model.
- q: The initial variational approximation.
- n_iterations: Number of optimization steps.
Keyword Arguments
- objective: Variational objective to be optimized.
- show_progress: Whether to show the progress bar.
- optimizer: Optimization algorithm.
- averager: Parameter averaging strategy.
- operator: Operator applied after each optimization step.
- adtype: Automatic differentiation backend.
See the docs of AdvancedVI.optimize for additional keyword arguments.
Returns
- q: Variational distribution formed by the last iterate of the optimization run.
- q_avg: Variational distribution formed by the averaged iterates according to averager.
- state: Collection of states used for optimization. This can be used to resume from a past call to vi.
- info: Information generated during the optimization run.
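A typical call sequence might look like the sketch below. The toy model demo, its data, and the iteration count are hypothetical, and the code assumes an installed Turing version whose signatures match the ones documented here. The return value is kept as a single object; it carries the items listed under Returns.

```julia
using Turing

# Hypothetical toy model for illustration: unknown mean and scale of Gaussian data.
@model function demo(y)
    μ ~ Normal(0, 1)
    σ ~ Exponential(1)
    for i in eachindex(y)
        y[i] ~ Normal(μ, σ)
    end
end

model = demo([0.9, 1.1, 1.3, 0.7])

# Initial mean-field Gaussian approximation matching the support of the model.
q_init = Turing.Variational.q_meanfield_gaussian(model)

# Run 100 optimization steps with the default objective, optimizer, and averager.
result = Turing.Variational.vi(model, q_init, 100; show_progress=false)
```

Swapping q_meanfield_gaussian for q_fullrank_gaussian changes only the variational family; the call to vi stays the same.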