Variational Inference

This post will look at variational inference (VI), an optimization-based approach to approximate Bayesian inference, and how to use it in Turing.jl as an alternative to other approaches such as MCMC. We will focus on the usage of VI in Turing rather than on the principles and theory underlying VI. If you are interested in understanding the mathematics, you can check out our write-up or any of the many other great resources available online.

Let's start with a minimal example. Consider a Turing.Model, which we denote as model. Approximating the posterior associated with model via VI is as simple as

m = model(data...)               # instantiate model on the data
q_init = q_fullrank_gaussian(m)  # initial variational approximation
vi(m, q_init, 1000)              # perform VI with the default algorithm on `m` for 1000 iterations
Thus, it's no more work than standard MCMC sampling in Turing. The default algorithm uses stochastic gradient descent to minimize the (exclusive) KL divergence. This is commonly referred to as automatic differentiation variational inference [1], stochastic gradient VI [2], and black-box variational inference [3] with the reparameterization gradient [4,5,6].
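For reference, the objective being maximized is the evidence lower bound (ELBO), which differs from the exclusive KL divergence only by a constant that does not depend on the variational approximation $q$:

$$
\mathrm{ELBO}(q) = \mathbb{E}_{z \sim q}\big[\log p(x, z) - \log q(z)\big] = \log p(x) - \mathrm{KL}\big(q(z) \,\|\, p(z \mid x)\big),
$$

so maximizing the ELBO over the variational family is equivalent to minimizing the exclusive KL divergence $\mathrm{KL}(q \,\|\, p(\cdot \mid x))$.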
To get a bit more into what we can do with VI, let’s look at a more concrete example. We will reproduce the tutorial on Bayesian linear regression using VI instead of MCMC. After that, we will discuss how to customize the behavior of vi
for more advanced usage.
Let’s first import the relevant packages:
using Random
using Turing
using Turing: Variational
using AdvancedVI
using Plots
Random.seed!(42);
Bayesian Linear Regression Example
Let’s start by setting up our example. We will re-use the Bayesian linear regression example. As we’ll see, there is really no additional work required to apply variational inference to a more complex Model
.
using FillArrays
using RDatasets
using LinearAlgebra

# Import the mtcars dataset.
data = RDatasets.dataset("datasets", "mtcars");

# Show the first six rows of the dataset.
first(data, 6)
Row | Model | MPG | Cyl | Disp | HP | DRat | WT | QSec | VS | AM | Gear | Carb |
---|---|---|---|---|---|---|---|---|---|---|---|---|
String31 | Float64 | Int64 | Float64 | Int64 | Float64 | Float64 | Float64 | Int64 | Int64 | Int64 | Int64 | |
1 | Mazda RX4 | 21.0 | 6 | 160.0 | 110 | 3.9 | 2.62 | 16.46 | 0 | 1 | 4 | 4 |
2 | Mazda RX4 Wag | 21.0 | 6 | 160.0 | 110 | 3.9 | 2.875 | 17.02 | 0 | 1 | 4 | 4 |
3 | Datsun 710 | 22.8 | 4 | 108.0 | 93 | 3.85 | 2.32 | 18.61 | 1 | 1 | 4 | 1 |
4 | Hornet 4 Drive | 21.4 | 6 | 258.0 | 110 | 3.08 | 3.215 | 19.44 | 1 | 0 | 3 | 1 |
5 | Hornet Sportabout | 18.7 | 8 | 360.0 | 175 | 3.15 | 3.44 | 17.02 | 0 | 0 | 3 | 2 |
6 | Valiant | 18.1 | 6 | 225.0 | 105 | 2.76 | 3.46 | 20.22 | 1 | 0 | 3 | 1 |
# Function to split samples.
function split_data(df, at=0.70)
    r = size(df, 1)
    index = Int(round(r * at))
    train = df[1:index, :]
    test = df[(index + 1):end, :]
    return train, test
end
# A handy helper function to rescale our dataset.
function standardize(x)
return (x .- mean(x; dims=1)) ./ std(x; dims=1)
end
function standardize(x, orig)
return (x .- mean(orig; dims=1)) ./ std(orig; dims=1)
end
# Another helper function to unstandardize our datasets.
function unstandardize(x, orig)
return x .* std(orig; dims=1) .+ mean(orig; dims=1)
end
function unstandardize(x, mean_train, std_train)
return x .* std_train .+ mean_train
end
unstandardize (generic function with 2 methods)
# Remove the model column.
select!(data, Not(:Model))

# Split our dataset 70%/30% into training/test sets.
train, test = split_data(data, 0.7)
train_unstandardized = copy(train)

# Standardize both datasets.
std_train = standardize(Matrix(train))
std_test = standardize(Matrix(test), Matrix(train))

# Save dataframe versions of our dataset.
train_cut = DataFrame(std_train, names(data))
test_cut = DataFrame(std_test, names(data))

# Create our labels. These are the values we are trying to predict.
train_label = train_cut[:, :MPG]
test_label = test_cut[:, :MPG]

# Get the list of columns to keep.
remove_names = filter(x -> !in(x, ["MPG"]), names(data))

# Filter the test and train sets.
train = Matrix(train_cut[:, remove_names]);
test = Matrix(test_cut[:, remove_names]);
# Bayesian linear regression.
@model function linear_regression(x, y, n_obs, n_vars, ::Type{T}=Vector{Float64}) where {T}
    # Set variance prior.
    σ² ~ truncated(Normal(0, 100); lower=0)

    # Set intercept prior.
    intercept ~ Normal(0, 3)

    # Set the priors on our coefficients.
    coefficients ~ MvNormal(Zeros(n_vars), 10.0 * I)

    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    return y ~ MvNormal(mu, σ² * I)
end;

n_obs, n_vars = size(train)
m = linear_regression(train, train_label, n_obs, n_vars);
Basic Usage
To run VI, we must first set a variational family. For instance, the most commonly used family is the mean-field Gaussian family. For this, Turing provides functions that automatically construct the initialization corresponding to the model m
:
q_init = q_meanfield_gaussian(m);
vi will automatically recognize the variational family through the type of q_init. Here is the detailed documentation for the constructor:
@doc(Variational.q_meanfield_gaussian)
q_meanfield_gaussian(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector} = nothing,
scale::Union{Nothing,<:Diagonal} = nothing,
kwargs...
)
Find a numerically non-degenerate mean-field Gaussian q
for approximating the target model
.
Arguments

- model: The target DynamicPPL.Model.

Keyword Arguments

- location: The location parameter of the initialization. If nothing, a vector of zeros is used.
- scale: The scale parameter of the initialization. If nothing, an identity matrix is used.

The remaining keyword arguments are passed to q_locationscale.

Returns

- q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution matching the support of model.
As we can see, the precise initialization can be customized through the keyword arguments.
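For example, the following sketch (not part of the original tutorial) initializes the approximation with a custom location and a smaller diagonal scale. We assume here that the location and scale refer to the underlying Gaussian in the unconstrained space, and use the fact that this model has 12 parameters in total (σ², the intercept, and 10 coefficients):

# A sketch: custom initialization of the mean-field Gaussian family.
# `d` is the total number of parameters of `m` (an assumption based on the model above);
# Diagonal comes from LinearAlgebra, which was loaded earlier.
d = 12
q_init_custom = q_meanfield_gaussian(m; location=zeros(d), scale=Diagonal(fill(0.1, d)));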
Let’s run VI with the default setting:
n_iters = 1000
q_avg, q_last, info, state = vi(m, q_init, n_iters; show_progress=false);
The default setting uses the AdvancedVI.RepGradELBO objective, which corresponds to a variant of what is known as automatic differentiation VI [7], stochastic gradient VI [8], or black-box VI [9] with the reparameterization gradient [10,11,12]. The default optimizer we use is AdvancedVI.DoWG [13] combined with a proximal operator. (The use of proximal operators with VI on a location-scale family is discussed in detail by J. Domke [14,15] and others [16].) We will take a deeper look into the returned values and the keyword arguments in the following subsections. First, here is the full documentation for vi:
@doc(Variational.vi)
vi(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
),
show_progress::Bool = Turing.PROGRESS[],
optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
kwargs...
)
Approximating the target model
via variational inference by optimizing objective
with the initialization q
. This is a thin wrapper around AdvancedVI.optimize
.
Arguments

- model: The target DynamicPPL.Model.
- q: The initial variational approximation.
- n_iterations: Number of optimization steps.

Keyword Arguments

- objective: Variational objective to be optimized.
- show_progress: Whether to show the progress bar.
- optimizer: Optimization algorithm.
- averager: Parameter averaging strategy.
- operator: Operator applied after each optimization step.
- adtype: Automatic differentiation backend.

See the docs of AdvancedVI.optimize for additional keyword arguments.

Returns

- q: Variational distribution formed by the last iterate of the optimization run.
- q_avg: Variational distribution formed by the averaged iterates according to averager.
- state: Collection of states used for optimization. This can be used to resume from a past call to vi.
- info: Information generated during the optimization run.
Values Returned by vi
The main output of the algorithm is q_avg, the average of the parameters generated by the optimization algorithm. For computing q_avg, the default setting uses what is known as polynomial averaging [17]. Usually, q_avg will perform better than the last iterate q_last. For instance, we can compare the ELBO of the two:
@info("Objective of q_avg and q_last",
= estimate_objective(AdvancedVI.RepGradELBO(32), q_avg, Turing.Variational.make_logdensity(m)),
ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), q_last, Turing.Variational.make_logdensity(m))
ELBO_q_last )
┌ Info: Objective of q_avg and q_last
│   ELBO_q_avg = -52.8918031125194
└   ELBO_q_last = -54.027615163157115
We can see that ELBO_q_avg is slightly higher, meaning the averaged iterate is a slightly better approximation.
Now, info
contains information generated during optimization that could be useful for diagnostics. For the default setting, which is RepGradELBO
, it contains the ELBO estimated at each step, which can be plotted as follows:
Plots.plot([i.elbo for i in info], xlabel="Iterations", ylabel="ELBO", label="info")
Since the ELBO is estimated by a small number of samples, it appears noisy. Furthermore, at each step, the ELBO is evaluated on q_last
, not q_avg
, which is the actual output that we care about. To obtain more accurate ELBO estimates evaluated on q_avg
, we have to define a custom callback function.
Custom Callback Functions
To inspect the progress of optimization in more detail, one can define a custom callback function. For example, the following callback function estimates the ELBO on q_avg
every 10 steps with a larger number of samples:
function callback(; stat, averaged_params, restructure, kwargs...)
    if mod(stat.iteration, 10) == 1
        q_avg = restructure(averaged_params)
        obj = AdvancedVI.RepGradELBO(128)
        elbo_avg = estimate_objective(obj, q_avg, Turing.Variational.make_logdensity(m))
        (elbo_avg = elbo_avg,)
    else
        nothing
    end
end;
The NamedTuple returned by callback will be appended to the corresponding entry of info, and it will also be displayed on the progress meter if show_progress is set to true.
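For instance, a slight variant of the callback above (a sketch; callback_with_iteration is a hypothetical name) could also record the iteration number, so that each recorded entry can later be matched to its iteration without assuming a fixed stride:

# A sketch of a callback variant that additionally records the iteration number.
function callback_with_iteration(; stat, averaged_params, restructure, kwargs...)
    if mod(stat.iteration, 10) == 1
        q = restructure(averaged_params)
        elbo = estimate_objective(
            AdvancedVI.RepGradELBO(128), q, Turing.Variational.make_logdensity(m)
        )
        (iteration=stat.iteration, elbo_avg=elbo)
    else
        nothing
    end
end;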
The custom callback can be supplied to vi
as a keyword argument:
q_mf, _, info_mf, _ = vi(m, q_init, n_iters; show_progress=false, callback=callback);
Let’s plot the result:
iters = 1:10:length(info_mf)
elbo_mf = [i.elbo_avg for i in info_mf[iters]]
Plots.plot!(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="callback", ylims=(-200, Inf))
We can see that the ELBO values are less noisy and progress more smoothly due to averaging.
Using Different Optimisers
The default optimiser we use is a proximal variant of DoWG [18]. For Gaussian variational families, this works well as a default option. Sometimes, the step size of AdvancedVI.DoWG could be too large, resulting in unstable behavior. (In this case, we recommend trying AdvancedVI.DoG [19].) Or, for whatever reason, it might be desirable to use a different optimiser. Our implementation supports any optimiser that implements the Optimisers.jl interface.
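Before moving to an optimiser that needs tuning, here is what switching to AdvancedVI.DoG could look like (a sketch, not from the original tutorial; we assume the default proximal operator also covers DoG, since both DoG and DoWG are recommended above; if it does not, pass operator=ClipScale() as done for Adam below):

# A sketch: swap the default DoWG optimizer for DoG, keeping everything else at its default.
_, _, info_dog, _ = vi(m, q_init, n_iters; show_progress=false, callback=callback, optimizer=AdvancedVI.DoG());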
For instance, let's try using Optimisers.Adam [20], which is a popular choice. Since AdvancedVI does not implement a proximal operator for Optimisers.Adam, we must use the AdvancedVI.ClipScale() projection operator, which ensures that the scale matrix of the variational approximation is positive definite. (See the paper by J. Domke (2020) [21] for more detail about the use of a projection operator.)
using Optimisers
_, _, info_adam, _ = vi(m, q_init, n_iters; show_progress=false, callback=callback, optimizer=Optimisers.Adam(3e-3), operator=ClipScale());
iters = 1:10:length(info_mf)
elbo_adam = [i.elbo_avg for i in info_adam[iters]]
Plots.plot(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="DoWG")
Plots.plot!(iters, elbo_adam, xlabel="Iterations", ylabel="ELBO", label="Adam")
Compared to the default option AdvancedVI.DoWG(), we can see that Optimisers.Adam(3e-3) converges more slowly. With more step-size tuning, it is possible that Optimisers.Adam could perform as well or better. That is, most common optimisers require some degree of tuning to perform better than or comparably to AdvancedVI.DoWG() or AdvancedVI.DoG(), which do not require much tuning at all. For this reason, DoWG and DoG are referred to as parameter-free optimizers.
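To get a feel for this tuning burden, one could sweep over a few Adam step sizes and compare the resulting ELBO estimates on the averaged iterate (a sketch; the step-size grid is an arbitrary choice, not from the original tutorial):

# A sketch: compare a few (arbitrary) Adam step sizes by the ELBO of the averaged iterate.
for η in (1e-3, 3e-3, 1e-2)
    q_avg_η, _, _, _ = vi(m, q_init, n_iters; show_progress=false, optimizer=Optimisers.Adam(η), operator=ClipScale())
    elbo_η = estimate_objective(AdvancedVI.RepGradELBO(128), q_avg_η, Turing.Variational.make_logdensity(m))
    @info "Adam step size comparison" η elbo_η
end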
Using Full-Rank Variational Families
So far, we have only used the mean-field Gaussian family. This, however, approximates the posterior covariance with a diagonal matrix. To model the full covariance matrix, we can use the full-rank Gaussian family [22,23]:
q_init_fr = q_fullrank_gaussian(m);
@doc(Variational.q_fullrank_gaussian)
q_fullrank_gaussian(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector} = nothing,
scale::Union{Nothing,<:LowerTriangular} = nothing,
kwargs...
)
Find a numerically non-degenerate Gaussian q
with a scale with full-rank factors (traditionally referred to as a "full-rank family") for approximating the target model
.
Arguments

- model: The target DynamicPPL.Model.

Keyword Arguments

- location: The location parameter of the initialization. If nothing, a vector of zeros is used.
- scale: The scale parameter of the initialization. If nothing, an identity matrix is used.

The remaining keyword arguments are passed to q_locationscale.

Returns

- q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution matching the support of model.
The term full-rank might seem a bit peculiar since covariance matrices are always full-rank. This term, however, traditionally comes from the fact that full-rank families use full-rank factors in addition to the diagonal of the covariance.
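To make this concrete, a location-scale Gaussian draws $z = \mu + C u$ with $u \sim \mathcal{N}(0, I)$; the mean-field family restricts the scale $C$ to a diagonal matrix, while the full-rank family uses a dense lower-triangular factor, so the covariance $\Sigma = C C^{\top}$ can represent arbitrary correlations:

$$
C_{\text{mean-field}} = \operatorname{diag}(\sigma_1, \ldots, \sigma_d),
\qquad
C_{\text{full-rank}} =
\begin{pmatrix}
L_{11} & 0 & \cdots & 0 \\
L_{21} & L_{22} & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
L_{d1} & L_{d2} & \cdots & L_{dd}
\end{pmatrix},
\qquad
\Sigma = C C^{\top}.
$$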
In contrast to the mean-field family, the full-rank family will often result in more computation per optimization step and slower convergence, especially in high dimensions:
q_fr, _, info_fr, _ = vi(m, q_init_fr, n_iters; show_progress=false, callback)

Plots.plot(elbo_mf, xlabel="Iterations", ylabel="ELBO", label="Mean-Field", ylims=(-200, Inf))

elbo_fr = [i.elbo_avg for i in info_fr[iters]]
Plots.plot!(elbo_fr, xlabel="Iterations", ylabel="ELBO", label="Full-Rank", ylims=(-200, Inf))
However, we can see that the full-rank family achieves a higher ELBO in the end. Due to the relationship between the ELBO and the Kullback-Leibler divergence, this indicates that the full-rank approximation captures the posterior more accurately. This trade-off between statistical accuracy and optimization speed is often referred to as the statistical-computational trade-off. The fact that we can control this trade-off through the choice of variational family is a strength, rather than a limitation, of variational inference.
We can also visualize the covariance matrix.
heatmap(cov(rand(q_fr, 100_000), dims=2))
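For comparison, the covariance of the mean-field approximation is diagonal by construction, which a quick check makes visible (a sketch using the q_mf approximation obtained earlier):

# A sketch: the mean-field approximation only captures the diagonal of the covariance.
heatmap(cov(rand(q_mf, 100_000), dims=2))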
Obtaining Summary Statistics
Let’s inspect the resulting variational approximation in more detail and compare it against MCMC. To obtain summary statistics from VI, we can draw samples from the resulting variational approximation:
z = rand(q_fr, 100_000);
Now, we can, for example, look at expectations:
avg = vec(mean(z; dims=2))
12-element Vector{Float64}:
0.381370039029503
-0.002712666930547315
0.35993913738443195
-0.07407809361373849
-0.09185665913131266
0.5861630697109698
-0.03587845396233794
0.08657968704192678
-0.0748945529831772
0.118774737727532
0.19056418649105789
-0.5957207566979493
The vector has the same ordering as the parameters in the model; e.g., in this case σ² has index 1, intercept has index 2, and coefficients has indices 3:12. If you forget the ordering, or want to do something with it programmatically, you can obtain the sym → indices mapping as follows:
using Bijectors: bijector

_, sym2range = bijector(m, Val(true));
sym2range
(intercept = UnitRange{Int64}[2:2], σ² = UnitRange{Int64}[1:1], coefficients = UnitRange{Int64}[3:12])
For example, we can check the sample distribution and mean value of σ²
:
histogram(z[1, :])
avg[union(sym2range[:σ²]...)]
1-element Vector{Float64}:
0.381370039029503
avg[union(sym2range[:intercept]...)]
1-element Vector{Float64}:
-0.002712666930547315
avg[union(sym2range[:coefficients]...)]
10-element Vector{Float64}:
0.35993913738443195
-0.07407809361373849
-0.09185665913131266
0.5861630697109698
-0.03587845396233794
0.08657968704192678
-0.0748945529831772
0.118774737727532
0.19056418649105789
-0.5957207566979493
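If you need these block-wise summaries repeatedly, a small helper along the following lines can package the lookup (a sketch; param_block_summary is a hypothetical name, not part of Turing):

# A sketch of a hypothetical helper: summarize one parameter block by name,
# using the sym → indices mapping and the matrix of posterior samples `z`.
function param_block_summary(z, sym2range, sym)
    idx = union(sym2range[sym]...)
    return (mean=vec(mean(z[idx, :]; dims=2)), std=vec(std(z[idx, :]; dims=2)))
end

param_block_summary(z, sym2range, :coefficients)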
For further convenience, we can wrap the samples into a Chains
object to summarize the results.
varinf = Turing.DynamicPPL.VarInfo(m)
vns_and_values = Turing.DynamicPPL.varname_and_value_leaves(
    Turing.DynamicPPL.values_as(varinf, OrderedDict)
)
varnames = map(first, vns_and_values)
vi_chain = Chains(reshape(z', (size(z, 2), size(z, 1), 1)), varnames)
Chains MCMC chain (100000×12×1 reshape(adjoint(::Matrix{Float64}), 100000, 12, 1) with eltype Float64):

Iterations        = 1:1:100000
Number of chains  = 1
Samples per chain = 100000
parameters        = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]

Summary Statistics
        parameters      mean       std      mcse      ess_bulk      ess_tail   ⋯
            Symbol   Float64   Float64   Float64       Float64       Float64   ⋯

                σ²    0.3814    0.1228    0.0004   100503.1171    99052.7657   ⋯
         intercept   -0.0027    0.1258    0.0004    97521.0892   100384.5892   ⋯
   coefficients[1]    0.3599    0.5459    0.0017    99608.2093    99087.2832   ⋯
   coefficients[2]   -0.0741    0.5732    0.0018   101491.7677    99635.6033   ⋯
   coefficients[3]   -0.0919    0.4413    0.0014    99361.0380    99530.5701   ⋯
   coefficients[4]    0.5862    0.3884    0.0012   100263.9999    98920.2938   ⋯
   coefficients[5]   -0.0359    0.5759    0.0018   100576.2245   100011.1552   ⋯
   coefficients[6]    0.0866    0.3394    0.0011   100945.4015    99134.7829   ⋯
   coefficients[7]   -0.0749    0.3511    0.0011    98955.0004   100048.3451   ⋯
   coefficients[8]    0.1188    0.3008    0.0010    99629.3612    97844.1731   ⋯
   coefficients[9]    0.1906    0.3904    0.0012    99034.8858    99383.9414   ⋯
  coefficients[10]   -0.5957    0.4265    0.0014    98214.9128    99085.9496   ⋯
                                                             2 columns omitted

Quantiles
        parameters      2.5%     25.0%     50.0%     75.0%     97.5%
            Symbol   Float64   Float64   Float64   Float64   Float64

                σ²    0.1963    0.2938    0.3630    0.4482    0.6728
         intercept   -0.2498   -0.0875   -0.0027    0.0820    0.2434
   coefficients[1]   -0.7092   -0.0075    0.3605    0.7285    1.4319
   coefficients[2]   -1.1957   -0.4617   -0.0740    0.3114    1.0545
   coefficients[3]   -0.9509   -0.3928   -0.0924    0.2074    0.7723
   coefficients[4]   -0.1709    0.3241    0.5851    0.8481    1.3485
   coefficients[5]   -1.1703   -0.4242   -0.0332    0.3553    1.0852
   coefficients[6]   -0.5791   -0.1415    0.0859    0.3163    0.7509
   coefficients[7]   -0.7637   -0.3121   -0.0750    0.1604    0.6154
   coefficients[8]   -0.4695   -0.0830    0.1178    0.3221    0.7105
   coefficients[9]   -0.5736   -0.0731    0.1904    0.4525    0.9573
  coefficients[10]   -1.4319   -0.8837   -0.5950   -0.3082    0.2415
(Since we’re drawing independent samples, we can simply ignore the ESS and Rhat metrics.) Unfortunately, extracting varnames
is a bit verbose at the moment, but hopefully will become simpler in the near future.
Let’s compare this against samples from NUTS
:
mcmc_chain = sample(m, NUTS(), 10_000, drop_warmup=true, progress=false);
vi_mean = mean(vi_chain)[:, 2]
mcmc_mean = mean(mcmc_chain, names(mcmc_chain, :parameters))[:, 2]
plot(mcmc_mean; xticks=1:1:length(mcmc_mean), label="mean of NUTS")
plot!(vi_mean; label="mean of VI")
┌ Info: Found initial step size
└   ϵ = 0.05
That looks pretty good! But let's see how the predictive distributions look for the two.
Making Predictions
Similarly to the linear regression tutorial, we're going to compare against multivariate ordinary linear regression using the GLM package:
# Import the GLM package.
using GLM
# Perform multivariate OLS.
ols = lm(
    @formula(MPG ~ Cyl + Disp + HP + DRat + WT + QSec + VS + AM + Gear + Carb), train_cut
)

# Store our predictions in the original dataframe.
train_cut.OLSPrediction = unstandardize(GLM.predict(ols), train_unstandardized.MPG)
test_cut.OLSPrediction = unstandardize(GLM.predict(ols, test_cut), train_unstandardized.MPG);
# Make a prediction given an input vector, using mean parameter values from a chain.
function prediction(chain, x)
    p = get_params(chain)
    α = mean(p.intercept)
    β = collect(mean.(p.coefficients))
    return α .+ x * β
end
prediction (generic function with 1 method)
# Unstandardize the dependent variable.
train_cut.MPG = unstandardize(train_cut.MPG, train_unstandardized.MPG)
test_cut.MPG = unstandardize(test_cut.MPG, train_unstandardized.MPG);
# Show the first six rows of the modified dataframe.
first(test_cut, 6)
Row | MPG | Cyl | Disp | HP | DRat | WT | QSec | VS | AM | Gear | Carb | OLSPrediction |
---|---|---|---|---|---|---|---|---|---|---|---|---|
Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | |
1 | 15.2 | 1.04746 | 0.565102 | 0.258882 | -0.652405 | 0.0714991 | -0.716725 | -0.977008 | -0.598293 | -0.891883 | -0.469126 | 19.8583 |
2 | 13.3 | 1.04746 | 0.929057 | 1.90345 | 0.380435 | 0.465717 | -1.90403 | -0.977008 | -0.598293 | -0.891883 | 1.11869 | 16.0462 |
3 | 19.2 | 1.04746 | 1.32466 | 0.691663 | -0.777058 | 0.470584 | -0.873777 | -0.977008 | -0.598293 | -0.891883 | -0.469126 | 18.5746 |
4 | 27.3 | -1.25696 | -1.21511 | -1.19526 | 1.0037 | -1.38857 | 0.288403 | 0.977008 | 1.59545 | 1.07026 | -1.26303 | 29.3233 |
5 | 26.0 | -1.25696 | -0.888346 | -0.762482 | 1.62697 | -1.18903 | -1.09365 | -0.977008 | 1.59545 | 3.0324 | -0.469126 | 30.7731 |
6 | 30.4 | -1.25696 | -1.08773 | -0.381634 | 0.451665 | -1.79933 | -0.968007 | 0.977008 | 1.59545 | 3.0324 | -0.469126 | 25.2892 |
# Construct the Chains from the Variational Approximations
z_mf = rand(q_mf, 10_000);
z_fr = rand(q_fr, 10_000);

vi_mf_chain = Chains(reshape(z_mf', (size(z_mf, 2), size(z_mf, 1), 1)), varnames);
vi_fr_chain = Chains(reshape(z_fr', (size(z_fr, 2), size(z_fr, 1), 1)), varnames);
# Calculate the predictions for the training and test sets using the samples from the variational and MCMC posteriors.
train_cut.VIMFPredictions = unstandardize(
    prediction(vi_mf_chain, train), train_unstandardized.MPG
)
test_cut.VIMFPredictions = unstandardize(
    prediction(vi_mf_chain, test), train_unstandardized.MPG
)

train_cut.VIFRPredictions = unstandardize(
    prediction(vi_fr_chain, train), train_unstandardized.MPG
)
test_cut.VIFRPredictions = unstandardize(
    prediction(vi_fr_chain, test), train_unstandardized.MPG
)

train_cut.BayesPredictions = unstandardize(
    prediction(mcmc_chain, train), train_unstandardized.MPG
)
test_cut.BayesPredictions = unstandardize(
    prediction(mcmc_chain, test), train_unstandardized.MPG
);
vi_mf_loss1 = mean((train_cut.VIMFPredictions - train_cut.MPG) .^ 2)
vi_fr_loss1 = mean((train_cut.VIFRPredictions - train_cut.MPG) .^ 2)
bayes_loss1 = mean((train_cut.BayesPredictions - train_cut.MPG) .^ 2)
ols_loss1 = mean((train_cut.OLSPrediction - train_cut.MPG) .^ 2)

vi_mf_loss2 = mean((test_cut.VIMFPredictions - test_cut.MPG) .^ 2)
vi_fr_loss2 = mean((test_cut.VIFRPredictions - test_cut.MPG) .^ 2)
bayes_loss2 = mean((test_cut.BayesPredictions - test_cut.MPG) .^ 2)
ols_loss2 = mean((test_cut.OLSPrediction - test_cut.MPG) .^ 2)

println("Training set:
VI Mean-Field loss: $vi_mf_loss1
VI Full-Rank loss: $vi_fr_loss1
Bayes loss: $bayes_loss1
OLS loss: $ols_loss1
Test set:
VI Mean-Field loss: $vi_mf_loss2
VI Full-Rank loss: $vi_fr_loss2
Bayes loss: $bayes_loss2
OLS loss: $ols_loss2")
Training set:
VI Mean-Field loss: 3.073938403902387
VI Full-Rank loss: 3.081065170821681
Bayes loss: 3.072377656174794
OLS loss: 3.0709261248930093
Test set:
VI Mean-Field loss: 26.07318997207172
VI Full-Rank loss: 25.76208878027834
Bayes loss: 26.12801985061055
OLS loss: 27.09481307076057
Interestingly, the squared error between the true values and the mean predictions on the test set is actually lower for the full-rank variational posterior than for the “true” posterior obtained by MCMC sampling using NUTS. But, as Bayesians, we know that the mean doesn't tell the entire story. One quick check is to look at the mean predictions ± standard deviation of the different approaches:
preds_vi_mf = mapreduce(hcat, 1:5:size(vi_mf_chain, 1)) do i
    return unstandardize(prediction(vi_mf_chain[i], test), train_unstandardized.MPG)
end

p1 = scatter(
    1:size(test, 1),
    mean(preds_vi_mf; dims=2);
    yerr=std(preds_vi_mf; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("VI Mean-Field")
preds_vi_fr = mapreduce(hcat, 1:5:size(vi_mf_chain, 1)) do i
    return unstandardize(prediction(vi_fr_chain[i], test), train_unstandardized.MPG)
end

p2 = scatter(
    1:size(test, 1),
    mean(preds_vi_fr; dims=2);
    yerr=std(preds_vi_fr; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("VI Full-Rank")
preds_mcmc = mapreduce(hcat, 1:5:size(mcmc_chain, 1)) do i
    return unstandardize(prediction(mcmc_chain[i], test), train_unstandardized.MPG)
end

p3 = scatter(
    1:size(test, 1),
    mean(preds_mcmc; dims=2);
    yerr=std(preds_mcmc; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("MCMC (NUTS)")

plot(p1, p2, p3; layout=(1, 3), size=(900, 250), label="")
We can see that the full-rank VI approximation is very close to the predictions from the MCMC samples. Also, the coverage of full-rank VI and MCMC is much better than that of the crude mean-field approximation.
Footnotes
1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).
2. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.
3. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Proceedings of the International Conference on Artificial Intelligence and Statistics. PMLR.
4. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In Proceedings of the International Conference on Learning Representations.
5. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the International Conference on Machine Learning. PMLR.
6. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.
7. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).
8. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.
9. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Proceedings of the International Conference on Artificial Intelligence and Statistics. PMLR.
10. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In Proceedings of the International Conference on Learning Representations.
11. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the International Conference on Machine Learning. PMLR.
12. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.
13. Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. In Advances in Neural Information Processing Systems, 36.
14. Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In Proceedings of the International Conference on Machine Learning. PMLR.
15. Domke, J., Gower, R., & Garrigos, G. (2023). Provable convergence guarantees for black-box variational inference. In Advances in Neural Information Processing Systems, 36.
16. Kim, K., Oh, J., Wu, K., Ma, Y., & Gardner, J. (2023). On the convergence of black-box variational inference. In Advances in Neural Information Processing Systems, 36.
17. Shamir, O., & Zhang, T. (2013). Stochastic gradient descent for non-smooth optimization: Convergence results and optimal averaging schemes. In Proceedings of the International Conference on Machine Learning. PMLR.
18. Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. In Advances in Neural Information Processing Systems, 36.
19. Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD's best friend: A parameter-free dynamic step size schedule. In Proceedings of the International Conference on Machine Learning. PMLR.
20. Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. In Proceedings of the International Conference on Learning Representations.
21. Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In Proceedings of the International Conference on Machine Learning. PMLR.
22. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.
23. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).