API: Turing.Variational
Turing.Variational.q_fullrank_gaussian — Method
q_fullrank_gaussian(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector} = nothing,
scale::Union{Nothing,<:LowerTriangular} = nothing,
kwargs...
)
Find a numerically non-degenerate Gaussian q whose scale is a full-rank (lower-triangular) factor, traditionally referred to as the "full-rank" family, for approximating the target model.
Arguments
model: The target DynamicPPL.Model.
Keyword Arguments
location: The location parameter of the initialization. If nothing, a vector of zeros is used.
scale: The scale parameter of the initialization. If nothing, an identity matrix is used.
The remaining keyword arguments are passed to q_locationscale.
Returns
q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution, transformed so that it matches the support of model.
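As a rough illustration of what "full-rank" means for a location-scale Gaussian, the sketch below draws samples z = L * u + μ with a non-diagonal lower-triangular L and checks that their covariance approaches L * L'. It uses only the Julia standard library; sample_locationscale is a hypothetical helper, not part of Turing, and the sketch ignores the support-matching transformation applied to the returned q.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical helper (not a Turing function): draw n samples from the
# location-scale family z = scale * u + location with a standard normal base.
function sample_locationscale(rng::AbstractRNG, location::AbstractVector,
                              scale::AbstractMatrix, n::Int)
    d = length(location)
    u = randn(rng, d, n)          # each column is one base draw u ~ N(0, I)
    return scale * u .+ location  # each column is one draw z
end

rng = MersenneTwister(1)
μ = [1.0, -2.0]
L = LowerTriangular([1.0 0.0; 0.8 0.5])  # full-rank (non-diagonal) scale factor

Z = sample_locationscale(rng, μ, L, 200_000)
Σ_hat = cov(Z; dims = 2)  # empirical covariance across samples
Σ = L * L'                # covariance implied by the scale factor
```

The off-diagonal entry of L is what lets the two coordinates be correlated; a mean-field (diagonal) scale could not represent the covariance Σ above.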
Turing.Variational.q_initialize_scale — Method
q_initialize_scale(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model,
location::AbstractVector,
scale::AbstractMatrix,
basedist::Distributions.UnivariateDistribution;
num_samples::Int = 10,
num_max_trials::Int = 10,
reduce_factor::Real = one(eltype(scale)) / 2
)
Given an initial location-scale distribution q formed by location, scale, and basedist, shrink scale until the expected log-density of model under q, estimated from num_samples samples, is finite. If it is still not finite after num_max_trials trials, throw an error.
For reference, a location-scale distribution $q$ formed by location, scale, and basedist is one whose sampling process $z \sim q$ can be written as follows, where d is the dimension of location:
u = rand(basedist, d)
z = scale * u + location
Arguments
model: The target DynamicPPL.Model.
location: The location parameter of the initialization.
scale: The scale parameter of the initialization.
basedist: The base distribution of the location-scale family.
Keyword Arguments
num_samples: Number of samples used to compute the average log-density at each trial.
num_max_trials: Maximum number of trials before throwing an error.
reduce_factor: Factor for shrinking the scale. After n trials, the scale is scale * reduce_factor^n.
Returns
scale_adj: The adjusted scale matrix, matching the type of scale.
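The shrinking procedure described above can be sketched as a short loop. This is a hypothetical re-implementation for illustration only: shrink_scale is not a Turing function, it assumes a standard normal base distribution (generated with randn) and a plain Julia function logdensity(z) in place of a DynamicPPL.Model, and Turing's actual implementation may differ in its details.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical sketch of the scale-shrinking loop described for
# q_initialize_scale. Assumes a standard normal base distribution and a plain
# function `logdensity(z)`; not Turing's actual implementation.
function shrink_scale(rng::AbstractRNG, logdensity, location::AbstractVector,
                      scale::AbstractMatrix; num_samples::Int = 10,
                      num_max_trials::Int = 10,
                      reduce_factor::Real = one(eltype(scale)) / 2)
    d = length(location)
    for _ in 1:num_max_trials
        # Monte Carlo estimate of E_q[log p(z)] with z = scale * u + location
        zs = scale * randn(rng, d, num_samples) .+ location
        avg = mean(logdensity(z) for z in eachcol(zs))
        isfinite(avg) && return scale
        # After n failed trials the scale is scale * reduce_factor^n;
        # multiplying by a scalar keeps Diagonal/LowerTriangular types.
        scale = scale * reduce_factor
    end
    error("log-densities are not finite after $num_max_trials trials")
end

rng = MersenneTwister(42)
# Toy target whose log-density is -Inf outside the ball of radius 3
logp(z) = norm(z) < 3 ? -0.5 * sum(abs2, z) : -Inf
scale0 = Diagonal([10.0, 10.0])
scale_adj = shrink_scale(rng, logp, zeros(2), scale0)
```

With an initial scale of 10, almost every draw lands outside the support of the toy target, so the loop halves the scale until all draws in a batch give finite log-densities.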
Turing.Variational.q_locationscale — Method
q_locationscale(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector} = nothing,
scale::Union{Nothing,<:Diagonal,<:LowerTriangular} = nothing,
meanfield::Bool = true,
basedist::Distributions.UnivariateDistribution = Normal()
)
Find a numerically non-degenerate variational distribution q for approximating the target model within the location-scale variational family formed by the type of scale and basedist.
The distribution can be specified manually by setting location, scale, and basedist; otherwise, a standard Gaussian is used by default. In either case, the scale may be adjusted via q_initialize_scale so that the log-densities of model are finite over the samples from q. If meanfield is true, the scale of q is restricted to a diagonal matrix and only the diagonal of scale is used.
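The effect of meanfield on a user-supplied scale can be sketched with the standard library alone. The helper to_family_scale below is hypothetical (not a Turing function); it only mirrors the documented behavior that a mean-field family keeps the diagonal of scale, while the full-rank family treats scale as a lower-triangular factor.

```julia
using LinearAlgebra

# Hypothetical helper mirroring the documented behavior: with meanfield = true
# only the diagonal of `scale` is kept; otherwise it is treated as a
# lower-triangular (full-rank) factor. Entries above the diagonal are dropped.
to_family_scale(scale::AbstractMatrix, meanfield::Bool) =
    meanfield ? Diagonal(diag(scale)) : LowerTriangular(Matrix(scale))

S = [2.0 0.0; 0.5 1.0]
D = to_family_scale(S, true)   # mean-field: the off-diagonal 0.5 is discarded
T = to_family_scale(S, false)  # full-rank: the lower-triangular factor is kept
```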
For reference, a location-scale distribution $q$ formed by location, scale, and basedist is one whose sampling process $z \sim q$ can be written as follows, where d is the dimension of location:
u = rand(basedist, d)
z = scale * u + location
Arguments
model: The target DynamicPPL.Model.
Keyword Arguments
location: The location parameter of the initialization. If nothing, a vector of zeros is used.
scale: The scale parameter of the initialization. If nothing, an identity matrix is used.
meanfield: Whether to use the mean-field approximation. If true, scale is converted into a Diagonal matrix. Otherwise, it is converted into a LowerTriangular matrix.
basedist: The base distribution of the location-scale family.
The remaining keywords are passed to q_initialize_scale.
Returns
q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution, transformed so that it matches the support of model.
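To give an intuition for "matching the support of model", the sketch below assumes a model with a single positive parameter, for which log is the usual unconstraining map: the Gaussian lives on the unconstrained (log) scale and its samples are mapped back with exp. This does not call Bijectors; the actual transformation depends on each prior's support, and the numbers are arbitrary.

```julia
using Random, Statistics

rng = MersenneTwister(3)
location, scale = 0.5, 0.3      # 1-D Gaussian on the unconstrained (log) scale
u = randn(rng, 100_000)         # base draws u ~ N(0, 1)
y = scale .* u .+ location      # unconstrained location-scale samples
x = exp.(y)                     # mapped onto the positive support
```

Every transformed sample is positive, and the median of x is close to exp(location), since exp is monotone and the Gaussian's median is its location.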
Turing.Variational.q_meanfield_gaussian — Method
q_meanfield_gaussian(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector} = nothing,
scale::Union{Nothing,<:Diagonal} = nothing,
kwargs...
)
Find a numerically non-degenerate mean-field Gaussian q for approximating the target model.
Arguments
model: The target DynamicPPL.Model.
Keyword Arguments
location: The location parameter of the initialization. If nothing, a vector of zeros is used.
scale: The scale parameter of the initialization. If nothing, an identity matrix is used.
The remaining keyword arguments are passed to q_locationscale.
Returns
q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution, transformed so that it matches the support of model.
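Because a mean-field Gaussian has a diagonal scale, its log-density factorizes into a sum of one-dimensional terms. The sketch below computes the density of a location-scale distribution on the unconstrained space by the change of variables u = scale \ (z - location), giving log q(z) = sum of log φ(u_i) minus log|det(scale)|, and checks it against the per-coordinate sum. locationscale_logpdf is a hypothetical helper, and the log-determinant shortcut assumes a diagonal or triangular scale.

```julia
using LinearAlgebra

# Log-density of the standard normal base distribution φ
stdnormal_logpdf(u::Real) = -0.5 * (u^2 + log(2π))

# Hypothetical helper: log-density of z = scale * u + location with u ~ N(0, I).
# Assumes a diagonal or triangular scale, whose determinant is the product of
# its diagonal entries.
function locationscale_logpdf(z::AbstractVector, location::AbstractVector,
                              scale::AbstractMatrix)
    u = scale \ (z - location)
    return sum(stdnormal_logpdf, u) - sum(x -> log(abs(x)), diag(scale))
end

μ = [0.0, 1.0, -1.0]
σ = [1.0, 0.5, 2.0]
z = [0.3, 0.2, 1.5]

# Mean-field: the joint log-density equals a sum of independent 1-D terms
lq_joint = locationscale_logpdf(z, μ, Diagonal(σ))
lq_sum = sum(stdnormal_logpdf((z[i] - μ[i]) / σ[i]) - log(σ[i]) for i in eachindex(z))
```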
Turing.Variational.vi — Method
vi(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
),
show_progress::Bool = Turing.PROGRESS[],
optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
kwargs...
)
Approximate the target model via variational inference by optimizing objective, starting from the initialization q. This is a thin wrapper around AdvancedVI.optimize.
Arguments
model: The target DynamicPPL.Model.
q: The initial variational approximation.
n_iterations: Number of optimization steps.
Keyword Arguments
objective: Variational objective to be optimized.
show_progress: Whether to show the progress bar.
optimizer: Optimization algorithm.
averager: Parameter averaging strategy.
operator: Operator applied after each optimization step.
adtype: Automatic differentiation backend.
See the docs of AdvancedVI.optimize for additional keyword arguments.
Returns
q: Variational distribution formed by the last iterate of the optimization run.
q_avg: Variational distribution formed by the averaged iterates according to averager.
state: Collection of states used for optimization. This can be used to resume from a past call to vi.
info: Information generated during the optimization run.
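To illustrate what optimizing the default reparameterization-gradient ELBO objective with 10 samples does, the toy sketch below fits a 1-D Gaussian q = N(m, s^2), with s = exp(ω), to the target N(3, 2^2). It is a from-scratch analogy, not what vi runs: it uses plain stochastic gradient ascent with the closed-form entropy gradient instead of DoWG and the proximal entropy operator, and uniform averaging of late iterates instead of PolynomialAveraging. The target, step size, and iteration counts are arbitrary assumptions.

```julia
using Random, Statistics

# Gradient of the toy target log p(z) = -(z - 3)^2 / 8 (i.e. N(3, 2^2) up to a constant)
dlogp(z) = -(z - 3) / 4

# Fit q = N(m, s^2), s = exp(ω), by stochastic gradient ascent on
#   ELBO(m, ω) = E_u[log p(m + s * u)] + ω + const,   u ~ N(0, 1),
# where ω is the closed-form Gaussian entropy up to a constant.
function toy_vi(rng::AbstractRNG; n_iterations = 3000, n_samples = 10, lr = 0.02)
    m, ω = 0.0, 0.0
    m_avg, ω_avg, n_avg = 0.0, 0.0, 0
    for t in 1:n_iterations
        s = exp(ω)
        u = randn(rng, n_samples)
        g = dlogp.(m .+ s .* u)          # ∇z log p at reparameterized draws
        grad_m = mean(g)                 # d/dm of the ELBO estimate
        grad_ω = mean(g .* s .* u) + 1   # chain rule through s, plus entropy gradient
        m += lr * grad_m
        ω += lr * grad_ω
        if t > n_iterations ÷ 2          # uniform running average of late iterates
            n_avg += 1
            m_avg += (m - m_avg) / n_avg
            ω_avg += (ω - ω_avg) / n_avg
        end
    end
    return (m_last = m, s_last = exp(ω), m_avg = m_avg, s_avg = exp(ω_avg))
end

res = toy_vi(MersenneTwister(7))
```

The m_last/s_last and m_avg/s_avg fields loosely parallel the last-iterate and averaged approximations listed under Returns; both should end up near the target's mean 3 and standard deviation 2.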