using Random
using Turing
using Turing: Variational
using AdvancedVI
using Plots
Random.seed!(42);

Variational Inference
This post will look at variational inference (VI), an optimisation approach to approximate Bayesian inference, and how to use it in Turing.jl as an alternative to other approaches such as MCMC. It will focus on the usage of VI in Turing rather than on the principles and theory underlying VI. If you are interested in understanding the mathematics, you can check out our write-up or any of the many other resources available online.
Let’s start with a minimal example. Consider a Turing.Model, which we denote as model. Approximating the posterior associated with model via VI is as simple as
m = model(data...)            # instantiate model on the data
q_init = q_fullrank_gaussian  # how to generate the initial variational approximation
result = vi(m, q_init, 1000)  # perform VI with the default algorithm on `m` for 1000 iterations

Thus, it’s no more work than standard MCMC sampling in Turing. The default algorithm uses stochastic gradient descent to minimise the (exclusive) KL divergence. This approach is commonly referred to as automatic differentiation variational inference (ADVI)[1], stochastic gradient VI[2], or black-box variational inference[3] with the reparameterization gradient[4][5][6].
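To make the reparameterization gradient concrete, here is a minimal, self-contained sketch (pure Julia, no Turing; the function name, step size, and iteration count are illustrative, not part of any Turing API). It fits a one-dimensional Gaussian q(z) = N(μ, σ²) to an unnormalised target by stochastic gradient ascent on the ELBO:

```julia
using Random

# Fit q(z) = N(μ, σ²) to an unnormalised log-target by stochastic gradient
# ascent on the ELBO, using the reparameterization z = μ + σε with ε ~ N(0, 1).
# `logp_grad` is the gradient of the log-target; ρ = log σ keeps σ positive.
function fit_gaussian_vi(logp_grad; n_iters=2_000, η=0.05)
    μ, ρ = 0.0, 0.0
    μ_avg, ρ_avg = 0.0, 0.0               # averages over the later iterates
    for t in 1:n_iters
        σ = exp(ρ)
        ε = randn()
        z = μ + σ * ε                     # reparameterized sample from q
        g = logp_grad(z)                  # ∇_z log p(z)
        μ += η * g                        # ∂z/∂μ = 1
        ρ += η * (g * σ * ε + 1)          # ∂z/∂ρ = σε; entropy term ∂(log σ)/∂ρ = 1
        if t > n_iters ÷ 2                # average away the gradient noise
            μ_avg += μ / (n_iters ÷ 2)
            ρ_avg += ρ / (n_iters ÷ 2)
        end
    end
    return μ_avg, exp(ρ_avg)
end

Random.seed!(1)
# Target ∝ exp(-(z - 3)² / 2), i.e. N(3, 1): VI should recover μ ≈ 3, σ ≈ 1.
μ, σ = fit_gaussian_vi(z -> -(z - 3))
```

Because the family contains the target here, the exclusive KL can be driven to essentially zero; the iterate averaging inside the loop mirrors the averaging that vi performs by default.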
To get a bit more into what we can do with VI, let’s look at a more concrete example. We will reproduce the tutorial on Bayesian linear regression using VI instead of MCMC. After that, we will discuss how to customise the behaviour of vi for more advanced usage.
Let’s first import the relevant packages:
Bayesian Linear Regression Example
Let’s start by setting up our example. We will re-use the Bayesian linear regression example. As we’ll see, there is really no additional work required to apply variational inference to a more complex Model.
using DataFrames
using FillArrays
using RDatasets
using LinearAlgebra
# Import the "mtcars" dataset.
data = RDatasets.dataset("datasets", "mtcars");
# Show the first six rows of the dataset.
first(data, 6)

| Row | Model | MPG | Cyl | Disp | HP | DRat | WT | QSec | VS | AM | Gear | Carb |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| | String31 | Float64 | Int64 | Float64 | Int64 | Float64 | Float64 | Float64 | Int64 | Int64 | Int64 | Int64 |
| 1 | Mazda RX4 | 21.0 | 6 | 160.0 | 110 | 3.9 | 2.62 | 16.46 | 0 | 1 | 4 | 4 |
| 2 | Mazda RX4 Wag | 21.0 | 6 | 160.0 | 110 | 3.9 | 2.875 | 17.02 | 0 | 1 | 4 | 4 |
| 3 | Datsun 710 | 22.8 | 4 | 108.0 | 93 | 3.85 | 2.32 | 18.61 | 1 | 1 | 4 | 1 |
| 4 | Hornet 4 Drive | 21.4 | 6 | 258.0 | 110 | 3.08 | 3.215 | 19.44 | 1 | 0 | 3 | 1 |
| 5 | Hornet Sportabout | 18.7 | 8 | 360.0 | 175 | 3.15 | 3.44 | 17.02 | 0 | 0 | 3 | 2 |
| 6 | Valiant | 18.1 | 6 | 225.0 | 105 | 2.76 | 3.46 | 20.22 | 1 | 0 | 3 | 1 |
# Function to split samples.
function split_data(df, at=0.70)
r = size(df, 1)
index = Int(round(r * at))
train = df[1:index, :]
test = df[(index + 1):end, :]
return train, test
end
# A handy helper function to rescale our dataset.
function standardise(x)
return (x .- mean(x; dims=1)) ./ std(x; dims=1)
end
function standardise(x, orig)
return (x .- mean(orig; dims=1)) ./ std(orig; dims=1)
end
# Another helper function to unstandardize our datasets.
function unstandardize(x, orig)
return x .* std(orig; dims=1) .+ mean(orig; dims=1)
end
function unstandardize(x, mean_train, std_train)
return x .* std_train .+ mean_train
end

unstandardize (generic function with 2 methods)
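As a quick check that these helpers invert each other, the following round-trip (using only the formulas above, on a small hypothetical matrix) recovers the original data, and the standardised columns have zero mean and unit standard deviation:

```julia
using Statistics

x = [1.0 10.0; 2.0 20.0; 3.0 30.0]
z = (x .- mean(x; dims=1)) ./ std(x; dims=1)   # same formula as standardise(x)
x2 = z .* std(x; dims=1) .+ mean(x; dims=1)    # same formula as unstandardize(z, x)
```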
# Remove the model column.
select!(data, Not(:Model))
# Split our dataset 70%/30% into training/test sets.
train, test = split_data(data, 0.7)
train_unstandardized = copy(train)
# Standardise both datasets.
std_train = standardise(Matrix(train))
std_test = standardise(Matrix(test), Matrix(train))
# Save dataframe versions of our dataset.
train_cut = DataFrame(std_train, names(data))
test_cut = DataFrame(std_test, names(data))
# Create our labels. These are the values we are trying to predict.
train_label = train_cut[:, :MPG]
test_label = test_cut[:, :MPG]
# Get the list of columns to keep.
remove_names = filter(x -> !in(x, ["MPG"]), names(data))
# Filter the test and train sets.
train = Matrix(train_cut[:, remove_names]);
test = Matrix(test_cut[:, remove_names]);

# Bayesian linear regression.
@model function linear_regression(x, y, n_obs, n_vars, ::Type{T}=Vector{Float64}) where {T}
# Set variance prior.
σ² ~ truncated(Normal(0, 100); lower=0)
# Set intercept prior.
intercept ~ Normal(0, 3)
# Set the priors on our coefficients.
coefficients ~ MvNormal(Zeros(n_vars), 10.0 * I)
# Calculate all the mu terms.
mu = intercept .+ x * coefficients
return y ~ MvNormal(mu, σ² * I)
end;
n_obs, n_vars = size(train)
m = linear_regression(train, train_label, n_obs, n_vars)

DynamicPPL.Model{typeof(linear_regression), (:x, :y, :n_obs, :n_vars, Symbol("##arg#232")), (), (), Tuple{Matrix{Float64}, Vector{Float64}, Int64, Int64, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext, false}(linear_regression, (x = [-0.10474629747086901 -0.5742368526626253 … 1.0702591020190138 1.1186855116463725; -0.10474629747086901 -0.5742368526626253 … 1.0702591020190138 1.1186855116463725; … ; -1.256955569650429 -0.8899286603079263 … -0.8918825850158448 -1.2630320292781625; 1.0474629747086912 0.6758710573112484 … -0.8918825850158448 -0.46912618230331743], y = [0.16337978215187793, 0.16337978215187793, 0.452211897027519, 0.2275646965686868, -0.20568347574477455, -0.30196084736998785, -0.9117175343296745, 0.7089515546947551, 0.452211897027519, -0.12545233272376316 … -0.4303306762036062, -0.7673014768918542, -1.5375204498935633, -1.5375204498935633, -0.8475326199128657, 1.9926498430309372, 1.6717252709468917, 2.2333432720939714, 0.2436109251728893, -0.7191627910792473], n_obs = 22, n_vars = 10, var"##arg#232" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DynamicPPL.DefaultContext())
Basic Usage
To run VI, we must first set a variational family. For instance, the most commonly used family is the mean-field Gaussian family. For this, Turing provides functions that automatically construct the initialisation corresponding to the model m. For example, for the mean-field Gaussian family, we can use:
@doc(Variational.q_meanfield_gaussian)

q_meanfield_gaussian(
[rng::Random.AbstractRNG,]
ldf::DynamicPPL.LogDensityFunction;
location::Union{Nothing,<:AbstractVector} = nothing,
scale::Union{Nothing,<:Diagonal} = nothing,
kwargs...
)
Find a numerically non-degenerate mean-field Gaussian q for approximating the target ldf::LogDensityFunction.
If scale is set to nothing, the default is a zero-mean Gaussian with a Diagonal scale matrix (the "mean-field" approximation) no larger than 0.6*I (covariance of 0.6^2*I). This guarantees that samples from the initial variational approximation fall in the range (-2, 2) with 99.9% probability, which mimics the behaviour of the Turing.InitFromUniform() strategy. Whether or not the default is used, the scale may be adjusted via q_initialize_scale so that the log-densities of the model are finite over samples from q.
Arguments
ldf: The target log-density function.
Keyword Arguments
location: The location parameter of the initialization. If nothing, a vector of zeros is used.
scale: The scale parameter of the initialization. If nothing, an identity matrix is used.
The remaining keyword arguments are passed to q_locationscale.
Returns
An AdvancedVI.LocationScale distribution matching the support of ldf.
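The 0.6*I default mentioned in the docstring can be sanity-checked numerically: for Z ~ N(0, 0.6²), P(-2 < Z < 2) = P(|Z|/0.6 < 3.33) ≈ 0.999. A quick Monte Carlo check (illustrative only):

```julia
using Random, Statistics

Random.seed!(0)
# Fraction of N(0, 0.6²) draws that land inside (-2, 2); should be ≈ 0.999.
frac = mean(abs.(0.6 .* randn(1_000_000)) .< 2)
```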
As we can see, the precise initialisation can be customized through the keyword arguments.
Let’s run VI with the default setting:
n_iters = 1000
result_avg = vi(m, q_meanfield_gaussian, n_iters; show_progress=false)

[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.
VIResult
├ q : MvLocationScale
├ info : 1000-element Vector{@NamedTuple{elbo::Float64, iteration::Int64}}
│ final iteration:
│ ├ elbo = -55.039471145948
│ └ iteration = 1000
└ (2 more fields: state, ldf)
The default setting uses the AdvancedVI.RepGradELBO objective, which corresponds to a variant of what is known as automatic differentiation VI[7], stochastic gradient VI[8], or black-box VI[9] with the reparameterization gradient[10][11][12]. The default optimiser is AdvancedVI.DoWG[13] combined with a proximal operator. (The use of proximal operators with VI on a location-scale family is discussed in detail by J. Domke[14][15] and others[16].) We will take a deeper look at the returned values and the keyword arguments in the following subsections. First, here is the full documentation for vi:
@doc(Variational.vi)

vi(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model,
family,
max_iter::Int;
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(
adtype; n_samples=10
),
unconstrained::Bool=requires_unconstrained_space(algorithm),
fix_transforms::Bool=false,
show_progress::Bool = Turing.PROGRESS[],
kwargs...
)
Approximate the target model via the variational inference algorithm algorithm using a variational family specified by family. This is a thin wrapper around AdvancedVI.optimize.
The default algorithm, KLMinRepGradProxDescent (relevant docs), assumes that family returns an AdvancedVI.MvLocationScale, which is true if family is q_fullrank_gaussian or q_meanfield_gaussian. For other variational families, refer to the AdvancedVI documentation to determine the best algorithm and other options.
Arguments
model: The target DynamicPPL.Model.
family: A function which is used to generate an initial variational approximation. Existing choices in Turing are q_locationscale, q_meanfield_gaussian, and q_fullrank_gaussian.
max_iter: Maximum number of steps.

Any additional arguments are passed on to AdvancedVI.optimize.
Keyword Arguments
adtype: Automatic differentiation backend to be applied to the log-density. The default value for algorithm also uses this backend for differentiating the variational objective.
algorithm: Variational inference algorithm. The default is KLMinRepGradProxDescent; please refer to the AdvancedVI docs for all the options.
show_progress: Whether to show the progress bar.
unconstrained: Whether to transform the posterior to be unconstrained for running the variational inference algorithm. The default value depends on the chosen algorithm (most algorithms require unconstrained space).
fix_transforms: Whether to precompute the transforms needed to convert model parameters to (possibly unconstrained) vectors. This can lead to performance improvements, but if any transforms depend on model parameters, setting fix_transforms=true can silently yield incorrect results.

Any additional keyword arguments are passed on both to the function initial_approx and to AdvancedVI.optimize.
See the docs of AdvancedVI.optimize for additional keyword arguments.
Returns
A VIResult object: please see its docstring for information.
Values Returned by vi
vi returns a single object, which is a VIResult.
The main output of the algorithm is the q field, which is the variational approximation itself. This is the average of the parameters generated by the optimisation algorithm, computed using what is known as polynomial averaging[17].
result_avg.q

AdvancedVI.MvLocationScale{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Distributions.Normal{Float64}, Vector{Float64}}(
location: [-1.370577232223292, -0.002238039832517606, 0.36984333017795934, -0.10267782351028602, -0.0809317588605051, 0.6051904509755113, -0.0035549406962756976, 0.08379502327038048, -0.06794334563789456, 0.12488444636147826, 0.1889591491241175, -0.6093667705331203]
scale: [0.3261788171235718 0.0 … 0.0 0.0; 0.0 0.10439317655945271 … 0.0 0.0; … ; 0.0 0.0 … 0.11152949244728395 0.0; 0.0 0.0 … 0.0 0.11076139211388376]
dist: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
)
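As an illustration of iterate averaging (AdvancedVI's polynomial averaging weights the iterates by a polynomial schedule; the plain running mean below is the simplest special case, shown only to convey the idea):

```julia
# Online running mean of a sequence of iterates (scalars or vectors).
function running_mean(iterates)
    avg = float(zero(first(iterates)))
    for (t, x) in enumerate(iterates)
        avg = avg .+ (x .- avg) ./ t    # standard online mean update
    end
    return avg
end
```

Averaging the optimiser's iterates smooths out the gradient noise, which is why the averaged q usually outperforms the final iterate.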
Usually, averaging will lead to better performance than simply using the last iterate. If you want the last iterate, you can disable averaging in the algorithm:
result_last = vi(
m,
q_meanfield_gaussian,
n_iters;
show_progress=false,
algorithm=KLMinRepGradDescent(
AutoForwardDiff();
operator=AdvancedVI.ClipScale(),
averager=AdvancedVI.NoAveraging()
),
);
result_last.q
AdvancedVI.MvLocationScale{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Distributions.Normal{Float64}, Vector{Float64}}(
location: [-0.21504162166797902, 0.012446511297543199, 0.13903845336976778, -0.027392498457864435, -0.06570618926664787, 0.4449812540648866, -0.14239638987492198, 0.10732677783503532, -0.009381506573922577, 0.13017362091806114, 0.10816612142257093, -0.4218323658540159]
scale: [0.25703090224020275 0.0 … 0.0 0.0; 0.0 0.17533886392073122 … 0.0 0.0; … ; 0.0 0.0 … 0.13098781526771058 0.0; 0.0 0.0 … 0.0 0.15777128688699263]
dist: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
)
To see the difference, we can compare the ELBO of the two:
@info("Objective of q_avg and q_last",
ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), result_avg.q, result_avg.ldf),
ELBO_q_last = estimate_objective(AdvancedVI.RepGradELBO(32), result_last.q, result_last.ldf)
)

┌ Info: Objective of q_avg and q_last
│   ELBO_q_avg = 51.454598200575475
└   ELBO_q_last = 68.6237892021251
We can see that the averaged q attains a better objective value (for RepGradELBO, estimate_objective returns the negative ELBO, so smaller is better).
result_avg.info also contains information generated during optimisation that can be useful for diagnostics. For the default setting, which uses RepGradELBO, it contains the ELBO estimated at each step, which can be plotted as follows:
Plots.plot([i.elbo for i in result_avg.info], xlabel="Iterations", ylabel="ELBO", label="info")

Since the ELBO is estimated from a small number of samples, it appears noisy. Furthermore, at each step the ELBO is evaluated on q_last, not q_avg, which is the actual output we care about. To obtain more accurate ELBO estimates evaluated on q_avg, we have to define a custom callback function.
Custom Callback Functions
To inspect the progress of optimisation in more detail, one can define a custom callback function. For example, the following callback function estimates the ELBO on q_avg every 10 steps with a larger number of samples:
function callback(; iteration, averaged_params, restructure, state, kwargs...)
if mod(iteration, 10) == 1
q_avg = restructure(averaged_params)
obj = AdvancedVI.RepGradELBO(128) # 128 samples for ELBO estimation
elbo_avg = -estimate_objective(obj, q_avg, state.prob)
(elbo_avg = elbo_avg,)
else
nothing
end
end;

The NamedTuple returned by callback will be appended to the corresponding entry of info, and it will also be displayed on the progress meter if show_progress is set to true.
The custom callback can be supplied to vi as a keyword argument:
result_mf = vi(m, q_meanfield_gaussian, n_iters; show_progress=false, callback=callback);
Let’s plot the result:
info_mf = result_mf.info
iters = 1:10:length(info_mf)
elbo_mf = [i.elbo_avg for i in info_mf[iters]]
Plots.plot([i.elbo for i in info_mf], xlabel="Iterations", ylabel="ELBO", label="info", linewidth=0.4)
Plots.plot!(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="callback", ylims=(-200, Inf), linewidth=2)

We can see that the ELBO values are less noisy and progress more smoothly due to averaging.
Using Different Optimisers
The default optimiser is a proximal variant of DoWG[18]. For Gaussian variational families, this works well as a default option. Sometimes, however, the step size of AdvancedVI.DoWG can be too large, resulting in unstable behaviour. (In this case, we recommend trying AdvancedVI.DoG[19].) Or, for whatever reason, it might be desirable to use a different optimiser. Our implementation supports any optimiser that implements the Optimisers.jl interface.
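To give a flavour of what "parameter-free" means, here is a rough, illustrative sketch of a DoG-style step-size rule (step size ≈ maximum distance travelled from the initial point divided by the root of the accumulated squared gradient norms). This is a simplification of the published algorithms, not AdvancedVI's implementation:

```julia
using LinearAlgebra

# Gradient descent with a DoG-style adaptive step size: no learning rate
# to tune, only a tiny initial movement radius r0.
function dog_descent(grad, x0; n_iters=500, r0=1e-4)
    x = copy(x0)
    rmax, gsum = r0, 0.0
    for _ in 1:n_iters
        g = grad(x)
        gsum += sum(abs2, g)               # accumulate squared gradient norms
        x .-= (rmax / sqrt(gsum)) .* g     # "distance over gradients" step
        rmax = max(rmax, norm(x - x0))     # track distance from the start
    end
    return x
end

# Minimise f(x) = ‖x - [1, 2]‖² / 2 without hand-picking a step size.
x = dog_descent(x -> x - [1.0, 2.0], zeros(2))
```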
For instance, let’s try Optimisers.Adam[20], which is a popular choice. Since AdvancedVI does not implement a proximal operator for Optimisers.Adam, we must use the AdvancedVI.ClipScale() projection operator, which ensures that the scale matrix of the variational approximation remains positive definite. (See the paper by J. Domke (2020)[21] for more detail on the use of a projection operator.)
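The idea behind such a projection can be sketched as follows (an illustration of the concept only, not AdvancedVI's actual implementation): after each gradient step, the diagonal entries of the scale matrix are pushed back above a small floor ε > 0 so that the scale stays positive definite.

```julia
using LinearAlgebra

# Project a diagonal scale matrix back onto {D : D[i, i] ≥ ε}.
clip_scale(D::Diagonal, ε=1e-4) = Diagonal(max.(D.diag, ε))

D = clip_scale(Diagonal([0.5, -0.2, 0.0]))  # negative/zero entries get floored
```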
using Optimisers
result = vi(
m, q_meanfield_gaussian, n_iters;
show_progress=false,
callback=callback,
algorithm=KLMinRepGradDescent(AutoForwardDiff(); optimizer=Optimisers.Adam(3e-3), operator=ClipScale())
);
info_adam = result.info
iters = 1:10:length(info_adam)
elbo_adam = [i.elbo_avg for i in info_adam[iters]]
Plots.plot(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="DoWG")
Plots.plot!(iters, elbo_adam, xlabel="Iterations", ylabel="ELBO", label="Adam")

Compared to the default option AdvancedVI.DoWG(), we can see that Optimisers.Adam(3e-3) converges more slowly. With more step-size tuning, Optimisers.Adam could perform as well or better. That is, most common optimisers require some degree of tuning to perform comparably to AdvancedVI.DoWG() or AdvancedVI.DoG(), which require almost no tuning at all; for this reason, the latter are referred to as parameter-free optimisers.
Using Full-Rank Variational Families
So far, we have only used the mean-field Gaussian family. This, however, approximates the posterior covariance with a diagonal matrix. To model the full covariance matrix, we can use the full-rank Gaussian family[22][23]:
@doc(Variational.q_fullrank_gaussian)

q_fullrank_gaussian(
[rng::Random.AbstractRNG,]
ldf::DynamicPPL.LogDensityFunction;
location::Union{Nothing,<:AbstractVector} = nothing,
scale::Union{Nothing,<:LowerTriangular} = nothing,
kwargs...
)
Find a numerically non-degenerate Gaussian q with a scale with full-rank factors (traditionally referred to as a "full-rank family") for approximating the target ldf::LogDensityFunction.
If scale is set to nothing, the default is a zero-mean Gaussian with a LowerTriangular scale matrix (resulting in a covariance with "full-rank" factors) no larger than 0.6*I (covariance of 0.6^2*I). This guarantees that samples from the initial variational approximation fall in the range (-2, 2) with 99.9% probability, which mimics the behaviour of the Turing.InitFromUniform() strategy. Whether or not the default is used, the scale may be adjusted via q_initialize_scale so that the log-densities of the model are finite over samples from q.
Arguments
ldf: The target log-density function.
Keyword Arguments
location: The location parameter of the initialization. If nothing, a vector of zeros is used.
scale: The scale parameter of the initialization. If nothing, an identity matrix is used.
The remaining keyword arguments are passed to q_locationscale.
Returns
An AdvancedVI.LocationScale distribution matching the support of ldf.
The term full-rank might seem peculiar, since the covariance matrix of a mean-field Gaussian is full-rank too. Traditionally, the term refers to the fact that full-rank families use full-rank factors in addition to the diagonal of the covariance.
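The practical difference is the parameter count: for a d-dimensional posterior, a mean-field Gaussian has 2d parameters, while a full-rank Gaussian with a lower-triangular scale has d + d(d+1)/2. For the 12 parameters of our model:

```julia
d = 12                                   # σ², intercept, and 10 coefficients
meanfield_params = 2d                    # location vector + diagonal scale
fullrank_params = d + d * (d + 1) ÷ 2    # location vector + lower-triangular scale
```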
In contrast to the mean-field family, the full-rank family will often result in more computation per optimisation step and slower convergence, especially in high dimensions:
result_fr = vi(m, q_fullrank_gaussian, n_iters; show_progress=false, callback)
Plots.plot(elbo_mf, xlabel="Iterations", ylabel="ELBO", label="Mean-Field", ylims=(-200, Inf))
elbo_fr = [i.elbo_avg for i in result_fr.info[iters]]
Plots.plot!(elbo_fr, xlabel="Iterations", ylabel="ELBO", label="Full-Rank", ylims=(-200, Inf))
However, we can see that the full-rank family achieves a higher ELBO in the end. Due to the relationship between the ELBO and the Kullback–Leibler divergence, this indicates that the full-rank approximation is more accurate. This trade-off between statistical accuracy and optimisation speed is often referred to as the statistical-computational trade-off. The fact that we can control this trade-off through the choice of variational family is a strength, rather than a limitation, of variational inference.
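The ELBO gap has a direct interpretation: ELBO(q) = log Z - KL(q ‖ p), so a higher ELBO means a smaller exclusive KL. The effect of the family choice can be seen in closed form for a toy two-dimensional Gaussian "posterior" with strong correlation (illustrative only; this uses the standard closed-form KL between Gaussians):

```julia
using LinearAlgebra

# KL(N(μ0, Σ0) ‖ N(μ1, Σ1)) in closed form.
function gauss_kl(μ0, Σ0, μ1, Σ1)
    d = length(μ0)
    return 0.5 * (tr(Σ1 \ Σ0) + dot(μ1 - μ0, Σ1 \ (μ1 - μ0)) - d +
                  logdet(Σ1) - logdet(Σ0))
end

μ = zeros(2)
Σ = [1.0 0.9; 0.9 1.0]                   # strongly correlated target
# The best mean-field Gaussian under the exclusive KL matches the precision
# diagonal, giving variances 1 ./ diag(inv(Σ)).
kl_meanfield = gauss_kl(μ, Diagonal(1 ./ diag(inv(Σ))), μ, Σ)
kl_fullrank = gauss_kl(μ, Σ, μ, Σ)       # the family contains the target
```

The full-rank family can match the target exactly (zero KL), while the best mean-field approximation pays an irreducible KL penalty for ignoring the correlation.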
Obtaining Summary Statistics
Let’s inspect the resulting variational approximation in more detail and compare it against MCMC. To obtain summary statistics from VI, we can draw samples from the resulting variational approximation by calling rand on the VIResult. Note that even though VI is often performed in transformed space, rand(result) will return samples in the original space of the parameters.
rand(result_fr)

VarNamedTuple
├─ σ² => 0.25200558565903575
├─ intercept => -0.09840830441493338
└─ coefficients => [1.398319273658256, -0.9150343171917702, 0.04933566000406464, 0.8062034705455388, 0.39327430346864284, 0.6709130436782058, -0.6412192737889927, 0.1533691597033213, 0.5355775921267156, -1.1553516071072405]
If you want to obtain samples in transformed space, you can access result_fr.ldf to obtain the LogDensityFunction used for inference. Calling rand(result_fr.ldf) will return a vector of samples in the transformed space. However, note that these vectors are harder to interpret and work with.
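The transformation involved can be illustrated with the σ² parameter of our model: it is constrained to be positive, so inference runs over an unconstrained value η = log(σ²) and maps back with exp (with a Jacobian correction to the log-density handled internally). A minimal sketch of the back-transformation, with hypothetical draws:

```julia
using Random

Random.seed!(0)
η = randn(5)      # draws in the unconstrained (transformed) space
σ² = exp.(η)      # back-transformed draws are always positive
```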
Each sample is a DynamicPPL.VarNamedTuple, and a matrix of these can be converted into a Chains object.
z = rand(result_fr, 100_000, 1)
using AbstractMCMC: AbstractMCMC
vi_chain = AbstractMCMC.from_samples(MCMCChains.Chains, z)

Chains MCMC chain (100000×12×1 Array{Float64, 3}):
Iterations = 1:1:100000
Number of chains = 1
Samples per chain = 100000
parameters = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]
internals =
Use `describe(chains)` for summary statistics and quantiles.
Now, we can, for example, look at expectations:
describe(vi_chain)

Chains MCMC chain (100000×12×1 Array{Float64, 3}):
Iterations = 1:1:100000
Number of chains = 1
Samples per chain = 100000
parameters = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]
internals =
Summary Statistics
parameters mean std mcse ess_bulk ess_tail ⋯
Symbol Float64 Float64 Float64 Float64 Float64 ⋯
σ² 0.2979 0.0944 0.0003 99377.6822 96582.0412 ⋯
intercept -0.0016 0.1125 0.0004 100156.7778 98715.2905 ⋯
coefficients[1] 0.3672 0.4919 0.0016 99095.4982 99338.7695 ⋯
coefficients[2] -0.0848 0.5018 0.0016 98983.5046 98888.4122 ⋯
coefficients[3] -0.0768 0.3936 0.0012 99138.3295 99291.6470 ⋯
coefficients[4] 0.6037 0.3461 0.0011 99701.6805 99584.8467 ⋯
coefficients[5] -0.0276 0.5159 0.0016 99629.0600 97011.6786 ⋯
coefficients[6] 0.0958 0.3022 0.0010 100088.3743 99418.5105 ⋯
coefficients[7] -0.0633 0.3078 0.0010 99329.5193 99842.4190 ⋯
coefficients[8] 0.1278 0.2663 0.0008 99707.9570 99166.1583 ⋯
coefficients[9] 0.1865 0.3432 0.0011 99439.5319 97843.6277 ⋯
coefficients[10] -0.5910 0.3832 0.0012 99307.2788 97288.4772 ⋯
2 columns omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
σ² 0.1549 0.2306 0.2840 0.3493 0.5218
intercept -0.2221 -0.0778 -0.0016 0.0738 0.2195
coefficients[1] -0.5953 0.0342 0.3669 0.6980 1.3360
coefficients[2] -1.0711 -0.4237 -0.0867 0.2552 0.9033
coefficients[3] -0.8473 -0.3408 -0.0780 0.1889 0.6989
coefficients[4] -0.0748 0.3697 0.6046 0.8365 1.2790
coefficients[5] -1.0342 -0.3762 -0.0269 0.3208 0.9809
coefficients[6] -0.4962 -0.1075 0.0964 0.2992 0.6889
coefficients[7] -0.6632 -0.2728 -0.0634 0.1433 0.5424
coefficients[8] -0.3943 -0.0526 0.1288 0.3071 0.6481
coefficients[9] -0.4886 -0.0454 0.1864 0.4182 0.8606
coefficients[10] -1.3439 -0.8483 -0.5907 -0.3341 0.1617
(Since we’re drawing independent samples, we can simply ignore the ESS and Rhat metrics.)
Let’s compare this against samples from NUTS:
mcmc_chain = sample(m, NUTS(), 10_000; progress=false);
vi_mean = mean(vi_chain)[:, 2]
mcmc_mean = mean(mcmc_chain, names(mcmc_chain, :parameters))[:, 2]
plot(mcmc_mean; xticks=1:1:length(mcmc_mean), label="mean of NUTS")
plot!(vi_mean; label="mean of VI")

┌ Info: Found initial step size
└   ϵ = 0.4
┌ Warning: There were 31 divergent transitions. Consider reparameterising your model or using a smaller step size. For adaptive samplers such as NUTS and HMCDA, consider increasing `target_accept`.
└ @ Turing.Inference ~/.julia/packages/Turing/tbwJL/src/mcmc/hmc.jl:481
That looks pretty good! But let’s see how the predictive distributions look for the two.
Making Predictions
Similarly to the linear regression tutorial, we’re going to compare against multivariate ordinary least squares (OLS) regression using the GLM package:
# Import the GLM package.
using GLM
# Perform multivariate OLS.
ols = lm(
@formula(MPG ~ Cyl + Disp + HP + DRat + WT + QSec + VS + AM + Gear + Carb), train_cut
)
# Store our predictions in the original dataframe.
train_cut.OLSPrediction = unstandardize(GLM.predict(ols), train_unstandardized.MPG)
test_cut.OLSPrediction = unstandardize(GLM.predict(ols, test_cut), train_unstandardized.MPG);

# Make a prediction given an input vector, using mean parameter values from a chain.
function prediction(chain, x)
p = get_params(chain)
α = mean(p.intercept)
β = collect(mean.(p.coefficients))
return α .+ x * β
end

prediction (generic function with 1 method)
# Unstandardize the dependent variable.
train_cut.MPG = unstandardize(train_cut.MPG, train_unstandardized.MPG)
test_cut.MPG = unstandardize(test_cut.MPG, train_unstandardized.MPG);

# Show the first six rows of the modified dataframe.
first(test_cut, 6)

| Row | MPG | Cyl | Disp | HP | DRat | WT | QSec | VS | AM | Gear | Carb | OLSPrediction |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 |
| 1 | 15.2 | 1.04746 | 0.565102 | 0.258882 | -0.652405 | 0.0714991 | -0.716725 | -0.977008 | -0.598293 | -0.891883 | -0.469126 | 19.8583 |
| 2 | 13.3 | 1.04746 | 0.929057 | 1.90345 | 0.380435 | 0.465717 | -1.90403 | -0.977008 | -0.598293 | -0.891883 | 1.11869 | 16.0462 |
| 3 | 19.2 | 1.04746 | 1.32466 | 0.691663 | -0.777058 | 0.470584 | -0.873777 | -0.977008 | -0.598293 | -0.891883 | -0.469126 | 18.5746 |
| 4 | 27.3 | -1.25696 | -1.21511 | -1.19526 | 1.0037 | -1.38857 | 0.288403 | 0.977008 | 1.59545 | 1.07026 | -1.26303 | 29.3233 |
| 5 | 26.0 | -1.25696 | -0.888346 | -0.762482 | 1.62697 | -1.18903 | -1.09365 | -0.977008 | 1.59545 | 3.0324 | -0.469126 | 30.7731 |
| 6 | 30.4 | -1.25696 | -1.08773 | -0.381634 | 0.451665 | -1.79933 | -0.968007 | 0.977008 | 1.59545 | 3.0324 | -0.469126 | 25.2892 |
# Construct the Chains from the Variational Approximations
vi_mf_chain = AbstractMCMC.from_samples(MCMCChains.Chains, rand(result_mf, 10_000, 1))
vi_fr_chain = AbstractMCMC.from_samples(MCMCChains.Chains, rand(result_fr, 10_000, 1))

Chains MCMC chain (10000×12×1 Array{Float64, 3}):
Iterations = 1:1:10000
Number of chains = 1
Samples per chain = 10000
parameters = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]
internals =
Use `describe(chains)` for summary statistics and quantiles.
# Calculate the predictions for the training and testing sets using the samples `z` from variational posterior
train_cut.VIMFPredictions = unstandardize(
prediction(vi_mf_chain, train), train_unstandardized.MPG
)
test_cut.VIMFPredictions = unstandardize(
prediction(vi_mf_chain, test), train_unstandardized.MPG
)
train_cut.VIFRPredictions = unstandardize(
prediction(vi_fr_chain, train), train_unstandardized.MPG
)
test_cut.VIFRPredictions = unstandardize(
prediction(vi_fr_chain, test), train_unstandardized.MPG
)
train_cut.BayesPredictions = unstandardize(
prediction(mcmc_chain, train), train_unstandardized.MPG
)
test_cut.BayesPredictions = unstandardize(
prediction(mcmc_chain, test), train_unstandardized.MPG
);

vi_mf_loss1 = mean((train_cut.VIMFPredictions - train_cut.MPG) .^ 2)
vi_fr_loss1 = mean((train_cut.VIFRPredictions - train_cut.MPG) .^ 2)
bayes_loss1 = mean((train_cut.BayesPredictions - train_cut.MPG) .^ 2)
ols_loss1 = mean((train_cut.OLSPrediction - train_cut.MPG) .^ 2)
vi_mf_loss2 = mean((test_cut.VIMFPredictions - test_cut.MPG) .^ 2)
vi_fr_loss2 = mean((test_cut.VIFRPredictions - test_cut.MPG) .^ 2)
bayes_loss2 = mean((test_cut.BayesPredictions - test_cut.MPG) .^ 2)
ols_loss2 = mean((test_cut.OLSPrediction - test_cut.MPG) .^ 2)
println("Training set:
VI Mean-Field loss: $vi_mf_loss1
VI Full-Rank loss: $vi_fr_loss1
Bayes loss: $bayes_loss1
OLS loss: $ols_loss1
Test set:
VI Mean-Field loss: $vi_mf_loss2
VI Full-Rank loss: $vi_fr_loss2
Bayes loss: $bayes_loss2
OLS loss: $ols_loss2")

Training set:
VI Mean-Field loss: 3.075162701731674
VI Full-Rank loss: 3.079297574755324
Bayes loss: 3.0730877712644475
OLS loss: 3.07092612489301
Test set:
VI Mean-Field loss: 25.969505420207565
VI Full-Rank loss: 24.952749829168333
Bayes loss: 25.93074787116897
OLS loss: 27.094813070760562
Interestingly, the mean-squared error on the test set is actually lower for the full-rank variational posterior than for the “true” posterior obtained by MCMC sampling with NUTS. But, as Bayesians, we know that the mean doesn’t tell the entire story. One quick check is to look at the mean predictions ± standard deviation of the different approaches:
preds_vi_mf = mapreduce(hcat, 1:5:size(vi_mf_chain, 1)) do i
return unstandardize(prediction(vi_mf_chain[i], test), train_unstandardized.MPG)
end
p1 = scatter(
1:size(test, 1),
mean(preds_vi_mf; dims=2);
yerr=std(preds_vi_mf; dims=2),
label="prediction (mean ± std)",
size=(900, 500),
markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("VI Mean-Field")
preds_vi_fr = mapreduce(hcat, 1:5:size(vi_mf_chain, 1)) do i
return unstandardize(prediction(vi_fr_chain[i], test), train_unstandardized.MPG)
end
p2 = scatter(
1:size(test, 1),
mean(preds_vi_fr; dims=2);
yerr=std(preds_vi_fr; dims=2),
label="prediction (mean ± std)",
size=(900, 500),
markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("VI Full-Rank")
preds_mcmc = mapreduce(hcat, 1:5:size(mcmc_chain, 1)) do i
return unstandardize(prediction(mcmc_chain[i], test), train_unstandardized.MPG)
end
p3 = scatter(
1:size(test, 1),
mean(preds_mcmc; dims=2);
yerr=std(preds_mcmc; dims=2),
label="prediction (mean ± std)",
size=(900, 500),
markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("MCMC (NUTS)")
plot(p1, p2, p3; layout=(1, 3), size=(900, 250), label="")
We can see that the full-rank VI approximation is very close to the predictions from the MCMC samples. Also, the coverage of full-rank VI and MCMC is much better than that of the crude mean-field approximation.
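To make the coverage comparison concrete, one rough check is the fraction of test points whose true value falls inside the mean ± one-standard-deviation band of the predictive draws. Here is a minimal sketch; the `interval_coverage` helper is hypothetical (not part of Turing), and it assumes a prediction matrix laid out like the `preds_*` matrices above:

```julia
using Statistics

# Fraction of true values lying within mean ± one standard deviation of the
# predictive draws; `preds` has one row per test point, one column per draw.
function interval_coverage(preds::AbstractMatrix, truth::AbstractVector)
    m = vec(mean(preds; dims=2))
    s = vec(std(preds; dims=2))
    return mean(@. abs(truth - m) <= s)
end
```

Applied to the matrices above, e.g. `interval_coverage(preds_vi_fr, unstandardize(test_label, train_unstandardized.MPG))`, a higher value suggests better-calibrated uncertainty bands, keeping in mind that a one-standard-deviation Gaussian interval only targets roughly 68% coverage.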
Footnotes
Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).↩︎
Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎
Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Proceedings of the International Conference on Artificial Intelligence and Statistics. PMLR.↩︎
Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In Proceedings of the International Conference on Learning Representations.↩︎
Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎
Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. In Advances in Neural Information Processing Systems, 36.↩︎
Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎
Domke, J., Gower, R., & Garrigos, G. (2023). Provable convergence guarantees for black-box variational inference. In Advances in Neural Information Processing Systems, 36.↩︎
Kim, K., Oh, J., Wu, K., Ma, Y., & Gardner, J. (2023). On the convergence of black-box variational inference. In Advances in Neural Information Processing Systems, 36.↩︎
Shamir, O., & Zhang, T. (2013). Stochastic gradient descent for non-smooth optimisation: Convergence results and optimal averaging schemes. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎
Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD’s best friend: A parameter-free dynamic step size schedule. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎
Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimisation. In Proceedings of the International Conference on Learning Representations.↩︎