Variational Inference

This post looks at variational inference (VI), an optimization approach to approximate Bayesian inference, and how to use it in Turing.jl as an alternative to other approaches such as MCMC. The focus is on the usage of VI in Turing rather than on the principles and theory underlying VI. If you are interested in the mathematics, you can check out our write-up or any of the many other great resources available online.

Let’s start with a minimal example. Consider a Turing.Model, which we denote as model. Approximating the posterior associated with model via VI is as simple as

m = model(data...)               # instantiate model on the data
q_init = q_fullrank_gaussian(m)  # initial variational approximation
vi(m, q_init, 1000) # perform VI with the default algorithm on `m` for 1000 iterations

Thus, it’s no more work than standard MCMC sampling in Turing. The default algorithm uses stochastic gradient descent to minimize the (exclusive) KL divergence. This is commonly referred to as automatic differentiation variational inference[1], stochastic gradient VI[2], or black-box variational inference[3] with the reparameterization gradient[4][5][6].
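To make this fully concrete, here is a minimal, self-contained sketch using a hypothetical toy model (toy_normal below is made up purely for illustration; the call pattern mirrors the snippet above):

using Turing
using LinearAlgebra

# A hypothetical toy model: unknown mean with unit observation noise.
@model function toy_normal(x)
    μ ~ Normal(0, 1)
    return x ~ MvNormal(fill(μ, length(x)), 1.0 * I)
end

data = randn(10) .+ 2                   # simulated observations
m = toy_normal(data)                    # instantiate the model on the data
q_init = q_fullrank_gaussian(m)         # initial variational approximation
q_avg, q_last, info, state = vi(m, q_init, 1000; show_progress=false)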

To get a bit more into what we can do with VI, let’s look at a more concrete example. We will reproduce the tutorial on Bayesian linear regression using VI instead of MCMC. After that, we will discuss how to customize the behavior of vi for more advanced usage.

Let’s first import the relevant packages:

using Random
using Turing
using Turing: Variational
using AdvancedVI
using Plots

Random.seed!(42);

Bayesian Linear Regression Example

Let’s start by setting up our example: we will re-use the model from the Bayesian linear regression tutorial. As we’ll see, there is really no additional work required to apply variational inference to a more complex Model.

using FillArrays
using RDatasets

using LinearAlgebra

# Import the "Default" dataset.
data = RDatasets.dataset("datasets", "mtcars");

# Show the first six rows of the dataset.
first(data, 6)
6×12 DataFrame
 Row  Model              MPG      Cyl    Disp     HP     DRat     WT       QSec     VS     AM     Gear   Carb
      String31           Float64  Int64  Float64  Int64  Float64  Float64  Float64  Int64  Int64  Int64  Int64
   1  Mazda RX4          21.0     6      160.0    110    3.9      2.62     16.46    0      1      4      4
   2  Mazda RX4 Wag      21.0     6      160.0    110    3.9      2.875    17.02    0      1      4      4
   3  Datsun 710         22.8     4      108.0    93     3.85     2.32     18.61    1      1      4      1
   4  Hornet 4 Drive     21.4     6      258.0    110    3.08     3.215    19.44    1      0      3      1
   5  Hornet Sportabout  18.7     8      360.0    175    3.15     3.44     17.02    0      0      3      2
   6  Valiant            18.1     6      225.0    105    2.76     3.46     20.22    1      0      3      1
# Function to split samples.
function split_data(df, at=0.70)
    r = size(df, 1)
    index = Int(round(r * at))
    train = df[1:index, :]
    test = df[(index + 1):end, :]
    return train, test
end

# A handy helper function to rescale our dataset.
function standardize(x)
    return (x .- mean(x; dims=1)) ./ std(x; dims=1)
end

function standardize(x, orig)
    return (x .- mean(orig; dims=1)) ./ std(orig; dims=1)
end

# Another helper function to unstandardize our datasets.
function unstandardize(x, orig)
    return x .* std(orig; dims=1) .+ mean(orig; dims=1)
end

function unstandardize(x, mean_train, std_train)
    return x .* std_train .+ mean_train
end
unstandardize (generic function with 2 methods)
# Remove the model column.
select!(data, Not(:Model))

# Split our dataset 70%/30% into training/test sets.
train, test = split_data(data, 0.7)
train_unstandardized = copy(train)

# Standardize both datasets.
std_train = standardize(Matrix(train))
std_test = standardize(Matrix(test), Matrix(train))

# Save dataframe versions of our dataset.
train_cut = DataFrame(std_train, names(data))
test_cut = DataFrame(std_test, names(data))

# Create our labels. These are the values we are trying to predict.
train_label = train_cut[:, :MPG]
test_label = test_cut[:, :MPG]

# Get the list of columns to keep.
remove_names = filter(x -> !in(x, ["MPG"]), names(data))

# Filter the test and train sets.
train = Matrix(train_cut[:, remove_names]);
test = Matrix(test_cut[:, remove_names]);
# Bayesian linear regression.
@model function linear_regression(x, y, n_obs, n_vars, ::Type{T}=Vector{Float64}) where {T}
    # Set variance prior.
    σ² ~ truncated(Normal(0, 100); lower=0)

    # Set intercept prior.
    intercept ~ Normal(0, 3)

    # Set the priors on our coefficients.
    coefficients ~ MvNormal(Zeros(n_vars), 10.0 * I)

    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    return y ~ MvNormal(mu, σ² * I)
end;
n_obs, n_vars = size(train)
m = linear_regression(train, train_label, n_obs, n_vars);

Basic Usage

To run VI, we must first choose a variational family. The most commonly used family is the mean-field Gaussian family. For this, Turing provides functions that automatically construct a suitable initialization for the model m:

q_init = q_meanfield_gaussian(m);

vi automatically recognizes the variational family through the type of q_init. Here is the full documentation for the constructor:

@doc(Variational.q_meanfield_gaussian)
q_meanfield_gaussian(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector} = nothing,
    scale::Union{Nothing,<:Diagonal} = nothing,
    kwargs...
)

Find a numerically non-degenerate mean-field Gaussian q for approximating the target model.

Arguments

  • model: The target DynamicPPL.Model.

Keyword Arguments

  • location: The location parameter of the initialization. If nothing, a vector of zeros is used.

  • scale: The scale parameter of the initialization. If nothing, an identity matrix is used.

The remaining keyword arguments are passed to q_locationscale.

Returns

  • q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution matching the support of model.

As we can see, the precise initialization can be customized through the keyword arguments.
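For example, here is a sketch of customizing the initialization of the mean-field family; the dimension 12 matches this model’s parameters (σ², the intercept, and the 10 coefficients), and the particular numbers are arbitrary illustrations:

# A sketch: start the mean-field Gaussian at a custom location with a smaller
# initial scale. `Diagonal` comes from LinearAlgebra (imported above); the
# dimension 12 and the value 0.1 are illustrative choices.
q_init_custom = q_meanfield_gaussian(
    m;
    location=zeros(12),
    scale=Diagonal(fill(0.1, 12)),
);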

Let’s run VI with the default setting:

n_iters = 1000
q_avg, q_last, info, state = vi(m, q_init, n_iters; show_progress=false);

The default setting uses the AdvancedVI.RepGradELBO objective, which corresponds to a variant of what is known as automatic differentiation VI[7], stochastic gradient VI[8], or black-box VI[9] with the reparameterization gradient[10][11][12]. The default optimizer we use is AdvancedVI.DoWG[13] combined with a proximal operator. (The use of proximal operators with VI on a location-scale family is discussed in detail by Domke[14][15] and others[16].) We will take a deeper look into the returned values and the keyword arguments in the following subsections. First, here is the full documentation for vi:

@doc(Variational.vi)
vi(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model,
    q,
    n_iterations::Int;
    objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
        10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
    ),
    show_progress::Bool = Turing.PROGRESS[],
    optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
    averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
    operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
    adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
    kwargs...
)

Approximate the target model via variational inference by optimizing objective with the initialization q. This is a thin wrapper around AdvancedVI.optimize.

Arguments

  • model: The target DynamicPPL.Model.

  • q: The initial variational approximation.

  • n_iterations: Number of optimization steps.

Keyword Arguments

  • objective: Variational objective to be optimized.

  • show_progress: Whether to show the progress bar.

  • optimizer: Optimization algorithm.

  • averager: Parameter averaging strategy.

  • operator: Operator applied after each optimization step.

  • adtype: Automatic differentiation backend.

See the docs of AdvancedVI.optimize for additional keyword arguments.

Returns

  • q: Variational distribution formed by the last iterate of the optimization run.

  • q_avg: Variational distribution formed by the averaged iterates according to averager.

  • state: Collection of states used for optimization. This can be used to resume from a past call to vi.

  • info: Information generated during the optimization run.
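For reference, here is a sketch of the earlier call with the documented defaults written out explicitly; the keyword names and default values are taken directly from the docstring above:

# Equivalent to `vi(m, q_init, n_iters; show_progress=false)`, with the
# defaults from the docstring above made explicit.
q_avg, q_last, info, state = vi(
    m, q_init, n_iters;
    objective=AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
    optimizer=AdvancedVI.DoWG(),
    averager=AdvancedVI.PolynomialAveraging(),
    operator=AdvancedVI.ProximalLocationScaleEntropy(),
    show_progress=false,
);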

Values Returned by vi

The main output of the algorithm is q_avg, the average of the parameters generated by the optimization algorithm. For computing q_avg, the default setting uses what is known as polynomial averaging[17]. Usually, q_avg will perform better than the last iterate q_last. For instance, we can compare the ELBO of the two:

@info("Objective of q_avg and q_last",
    ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), q_avg, Turing.Variational.make_logdensity(m)),
    ELBO_q_last = estimate_objective(AdvancedVI.RepGradELBO(32), q_last, Turing.Variational.make_logdensity(m)) 
)
Info: Objective of q_avg and q_last
  ELBO_q_avg = -52.8918031125194
  ELBO_q_last = -54.027615163157115

We can see that ELBO_q_avg is slightly higher.

Now, info contains information generated during optimization that can be useful for diagnostics. With the default RepGradELBO objective, it contains the ELBO estimate at each step, which can be plotted as follows:

Plots.plot([i.elbo for i in info], xlabel="Iterations", ylabel="ELBO", label="info")

Since the ELBO is estimated from only a small number of samples, the estimates appear noisy. Furthermore, at each step the ELBO is evaluated on q_last rather than q_avg, which is the output we actually care about. To obtain more accurate ELBO estimates evaluated on q_avg, we have to define a custom callback function.

Custom Callback Functions

To inspect the progress of optimization in more detail, one can define a custom callback function. For example, the following callback function estimates the ELBO on q_avg every 10 steps with a larger number of samples:

function callback(; stat, averaged_params, restructure, kwargs...)
    if mod(stat.iteration, 10) == 1
        q_avg    = restructure(averaged_params)
        obj      = AdvancedVI.RepGradELBO(128)
        elbo_avg = estimate_objective(obj, q_avg, Turing.Variational.make_logdensity(m))
        (elbo_avg = elbo_avg,)
    else
        nothing
    end
end;

The NamedTuple returned by callback will be appended to the corresponding entry of info, and it will also be displayed on the progress meter if show_progress is set to true.

The custom callback can be supplied to vi as a keyword argument:

q_mf, _, info_mf, _ = vi(m, q_init, n_iters; show_progress=false, callback=callback);

Let’s plot the result:

iters   = 1:10:length(info_mf)
elbo_mf = [i.elbo_avg for i in info_mf[iters]]
Plots.plot!(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="callback", ylims=(-200,Inf))

We can see that the ELBO values are less noisy and progress more smoothly due to averaging.

Using Different Optimisers

The default optimiser we use is a proximal variant of DoWG[18]. For Gaussian variational families, this works well as a default option. Sometimes, however, the step size of AdvancedVI.DoWG can be too large, resulting in unstable behavior. (In this case, we recommend trying AdvancedVI.DoG[19].) Or, for whatever reason, it might be desirable to use a different optimiser. Our implementation supports any optimiser that implements the Optimisers.jl interface.
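As a sketch, switching to AdvancedVI.DoG only requires changing the optimizer keyword. We pair it with the ClipScale projection here, since we have not verified that a proximal operator is available for DoG:

# A sketch: try the (also parameter-free) DoG optimizer instead of DoWG.
# ClipScale() keeps the scale of the variational approximation positive definite.
_, _, info_dog, _ = vi(
    m, q_init, n_iters;
    show_progress=false, callback=callback,
    optimizer=AdvancedVI.DoG(), operator=AdvancedVI.ClipScale(),
);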

For instance, let’s try using Optimisers.Adam[20], which is a popular choice. Since AdvancedVI does not implement a proximal operator for Optimisers.Adam, we must use the AdvancedVI.ClipScale() projection operator, which ensures that the scale matrix of the variational approximation remains positive definite. (See the paper by Domke (2020)[21] for more detail on the use of a projection operator.)

using Optimisers

_, _, info_adam, _ = vi(m, q_init, n_iters; show_progress=false, callback=callback, optimizer=Optimisers.Adam(3e-3), operator=ClipScale());
iters     = 1:10:length(info_mf)
elbo_adam = [i.elbo_avg for i in info_adam[iters]]
Plots.plot(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="DoWG")
Plots.plot!(iters, elbo_adam, xlabel="Iterations", ylabel="ELBO", label="Adam")

Compared to the default option AdvancedVI.DoWG(), we can see that Optimisers.Adam(3e-3) converges more slowly. With more step-size tuning, Optimisers.Adam could well perform comparably or better. That is, most common optimisers require some degree of tuning to match AdvancedVI.DoWG() or AdvancedVI.DoG(), which require almost no tuning at all; for this reason, the latter are referred to as parameter-free optimizers.
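For instance, a crude sketch of such a step-size sweep for Optimisers.Adam might look as follows (the candidate step sizes are arbitrary choices):

# A sketch: compare a few Adam step sizes by their final ELBO estimates.
for lr in (1e-3, 3e-3, 1e-2)
    q_lr, _, _, _ = vi(
        m, q_init, n_iters;
        show_progress=false, optimizer=Optimisers.Adam(lr), operator=ClipScale(),
    )
    elbo = estimate_objective(
        AdvancedVI.RepGradELBO(128), q_lr, Turing.Variational.make_logdensity(m)
    )
    @info "Adam step-size sweep" lr elbo
end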

Using Full-Rank Variational Families

So far, we have only used the mean-field Gaussian family. This, however, approximates the posterior covariance with a diagonal matrix. To model the full covariance matrix, we can use the full-rank Gaussian family[22][23]:

q_init_fr = q_fullrank_gaussian(m);
@doc(Variational.q_fullrank_gaussian)
q_fullrank_gaussian(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector} = nothing,
    scale::Union{Nothing,<:LowerTriangular} = nothing,
    kwargs...
)

Find a numerically non-degenerate Gaussian q with a scale with full-rank factors (traditionally referred to as a "full-rank family") for approximating the target model.

Arguments

  • model: The target DynamicPPL.Model.

Keyword Arguments

  • location: The location parameter of the initialization. If nothing, a vector of zeros is used.

  • scale: The scale parameter of the initialization. If nothing, an identity matrix is used.

The remaining keyword arguments are passed to q_locationscale.

Returns

  • q::Bijectors.TransformedDistribution: An AdvancedVI.LocationScale distribution matching the support of model.

The term full-rank might seem a bit peculiar since covariance matrices are always full-rank. This term, however, traditionally comes from the fact that full-rank families use full-rank factors in addition to the diagonal of the covariance.
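To get a sense of the added cost, here is a quick count of the variational parameters for this model:

# Number of variational parameters for each family (d = 12 for this model:
# σ², the intercept, and the 10 coefficients).
d = 12
meanfield_params = d + d                # location + diagonal scale        = 24
fullrank_params  = d + d * (d + 1) ÷ 2  # location + lower-triangular scale = 90

The full-rank family thus carries 90 parameters instead of 24 for this 12-dimensional model, and the gap grows quadratically with the dimension.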

In contrast to the mean-field family, the full-rank family will often result in more computation per optimization step and slower convergence, especially in high dimensions:

q_fr, _, info_fr, _ = vi(m, q_init_fr, n_iters; show_progress=false, callback)

Plots.plot(elbo_mf, xlabel="Iterations", ylabel="ELBO", label="Mean-Field", ylims=(-200, Inf))

elbo_fr = [i.elbo_avg for i in info_fr[iters]]
Plots.plot!(elbo_fr, xlabel="Iterations", ylabel="ELBO", label="Full-Rank", ylims=(-200, Inf))

However, we can see that the full-rank family achieves a higher ELBO in the end. Due to the relationship between the ELBO and the Kullback-Leibler divergence, this indicates that the full-rank approximation is more accurate. This trade-off between statistical accuracy and optimization speed is often referred to as the statistical-computational trade-off. The fact that we can control this trade-off through the choice of variational family is a strength, rather than a limitation, of variational inference.

We can also visualize the covariance matrix.

heatmap(cov(rand(q_fr, 100_000), dims=2))

Obtaining Summary Statistics

Let’s inspect the resulting variational approximation in more detail and compare it against MCMC. To obtain summary statistics from VI, we can draw samples from the resulting variational approximation:

z = rand(q_fr, 100_000);

Now, we can, for example, look at expectations:

avg = vec(mean(z; dims=2))
12-element Vector{Float64}:
  0.381370039029503
 -0.002712666930547315
  0.35993913738443195
 -0.07407809361373849
 -0.09185665913131266
  0.5861630697109698
 -0.03587845396233794
  0.08657968704192678
 -0.0748945529831772
  0.118774737727532
  0.19056418649105789
 -0.5957207566979493

The vector has the same ordering as the parameters in the model: in this case σ² has index 1, intercept has index 2, and coefficients has indices 3:12. If you forget the ordering, or want to work with it programmatically, you can obtain the sym → indices mapping as follows:

using Bijectors: bijector

_, sym2range = bijector(m, Val(true));
sym2range
(intercept = UnitRange{Int64}[2:2], σ² = UnitRange{Int64}[1:1], coefficients = UnitRange{Int64}[3:12])

For example, we can check the sample distribution and mean value of σ²:

histogram(z[1, :])
avg[union(sym2range[:σ²]...)]
1-element Vector{Float64}:
 0.381370039029503
avg[union(sym2range[:intercept]...)]
1-element Vector{Float64}:
 -0.002712666930547315
avg[union(sym2range[:coefficients]...)]
10-element Vector{Float64}:
  0.35993913738443195
 -0.07407809361373849
 -0.09185665913131266
  0.5861630697109698
 -0.03587845396233794
  0.08657968704192678
 -0.0748945529831772
  0.118774737727532
  0.19056418649105789
 -0.5957207566979493

For further convenience, we can wrap the samples into a Chains object to summarize the results.

varinf = Turing.DynamicPPL.VarInfo(m)
vns_and_values = Turing.DynamicPPL.varname_and_value_leaves(Turing.DynamicPPL.values_as(varinf, OrderedDict))
varnames = map(first, vns_and_values)
vi_chain = Chains(reshape(z', (size(z,2), size(z,1), 1)), varnames)
Chains MCMC chain (100000×12×1 reshape(adjoint(::Matrix{Float64}), 100000, 12, 1) with eltype Float64):

Iterations        = 1:1:100000
Number of chains  = 1
Samples per chain = 100000
parameters        = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]

Summary Statistics
        parameters      mean       std      mcse      ess_bulk      ess_tail   ⋯
            Symbol   Float64   Float64   Float64       Float64       Float64   ⋯

                σ²    0.3814    0.1228    0.0004   100503.1171    99052.7657   ⋯
         intercept   -0.0027    0.1258    0.0004    97521.0892   100384.5892   ⋯
   coefficients[1]    0.3599    0.5459    0.0017    99608.2093    99087.2832   ⋯
   coefficients[2]   -0.0741    0.5732    0.0018   101491.7677    99635.6033   ⋯
   coefficients[3]   -0.0919    0.4413    0.0014    99361.0380    99530.5701   ⋯
   coefficients[4]    0.5862    0.3884    0.0012   100263.9999    98920.2938   ⋯
   coefficients[5]   -0.0359    0.5759    0.0018   100576.2245   100011.1552   ⋯
   coefficients[6]    0.0866    0.3394    0.0011   100945.4015    99134.7829   ⋯
   coefficients[7]   -0.0749    0.3511    0.0011    98955.0004   100048.3451   ⋯
   coefficients[8]    0.1188    0.3008    0.0010    99629.3612    97844.1731   ⋯
   coefficients[9]    0.1906    0.3904    0.0012    99034.8858    99383.9414   ⋯
  coefficients[10]   -0.5957    0.4265    0.0014    98214.9128    99085.9496   ⋯
                                                               2 columns omitted

Quantiles
        parameters      2.5%     25.0%     50.0%     75.0%     97.5%
            Symbol   Float64   Float64   Float64   Float64   Float64

                σ²    0.1963    0.2938    0.3630    0.4482    0.6728
         intercept   -0.2498   -0.0875   -0.0027    0.0820    0.2434
   coefficients[1]   -0.7092   -0.0075    0.3605    0.7285    1.4319
   coefficients[2]   -1.1957   -0.4617   -0.0740    0.3114    1.0545
   coefficients[3]   -0.9509   -0.3928   -0.0924    0.2074    0.7723
   coefficients[4]   -0.1709    0.3241    0.5851    0.8481    1.3485
   coefficients[5]   -1.1703   -0.4242   -0.0332    0.3553    1.0852
   coefficients[6]   -0.5791   -0.1415    0.0859    0.3163    0.7509
   coefficients[7]   -0.7637   -0.3121   -0.0750    0.1604    0.6154
   coefficients[8]   -0.4695   -0.0830    0.1178    0.3221    0.7105
   coefficients[9]   -0.5736   -0.0731    0.1904    0.4525    0.9573
  coefficients[10]   -1.4319   -0.8837   -0.5950   -0.3082    0.2415

(Since we’re drawing independent samples, we can simply ignore the ESS and Rhat metrics.) Unfortunately, extracting varnames is a bit verbose at the moment, but it will hopefully become simpler in the near future.

Let’s compare this against samples from NUTS:

mcmc_chain = sample(m, NUTS(), 10_000, drop_warmup=true, progress=false);

vi_mean = mean(vi_chain)[:, 2]
mcmc_mean = mean(mcmc_chain, names(mcmc_chain, :parameters))[:, 2]

plot(mcmc_mean; xticks=1:1:length(mcmc_mean), label="mean of NUTS")
plot!(vi_mean; label="mean of VI")
Info: Found initial step size
  ϵ = 0.05

That looks pretty good! But let’s see how the predictive distributions look for the two.

Making Predictions

Similarly to the linear regression tutorial, we’re going to compare against multivariate ordinary linear regression using the GLM package:

# Import the GLM package.
using GLM

# Perform multivariate OLS.
ols = lm(
    @formula(MPG ~ Cyl + Disp + HP + DRat + WT + QSec + VS + AM + Gear + Carb), train_cut
)

# Store our predictions in the original dataframe.
train_cut.OLSPrediction = unstandardize(GLM.predict(ols), train_unstandardized.MPG)
test_cut.OLSPrediction = unstandardize(GLM.predict(ols, test_cut), train_unstandardized.MPG);
# Make a prediction given an input vector, using mean parameter values from a chain.
function prediction(chain, x)
    p = get_params(chain)
    α = mean(p.intercept)
    β = collect(mean.(p.coefficients))
    return α .+ x * β
end
prediction (generic function with 1 method)
# Unstandardize the dependent variable.
train_cut.MPG = unstandardize(train_cut.MPG, train_unstandardized.MPG)
test_cut.MPG = unstandardize(test_cut.MPG, train_unstandardized.MPG);
# Show the first six rows of the modified dataframe.
first(test_cut, 6)
6×12 DataFrame
 Row  MPG      Cyl       Disp       HP         DRat       WT          QSec       VS         AM         Gear       Carb       OLSPrediction
      Float64  Float64   Float64    Float64    Float64    Float64     Float64    Float64    Float64    Float64    Float64    Float64
   1  15.2      1.04746   0.565102   0.258882  -0.652405   0.0714991  -0.716725  -0.977008  -0.598293  -0.891883  -0.469126  19.8583
   2  13.3      1.04746   0.929057   1.90345    0.380435   0.465717   -1.90403   -0.977008  -0.598293  -0.891883   1.11869   16.0462
   3  19.2      1.04746   1.32466    0.691663  -0.777058   0.470584   -0.873777  -0.977008  -0.598293  -0.891883  -0.469126  18.5746
   4  27.3     -1.25696  -1.21511   -1.19526    1.0037    -1.38857     0.288403   0.977008   1.59545    1.07026   -1.26303   29.3233
   5  26.0     -1.25696  -0.888346  -0.762482   1.62697   -1.18903    -1.09365   -0.977008   1.59545    3.0324    -0.469126  30.7731
   6  30.4     -1.25696  -1.08773   -0.381634   0.451665  -1.79933    -0.968007   0.977008   1.59545    3.0324    -0.469126  25.2892
# Construct the Chains from the Variational Approximations
z_mf = rand(q_mf, 10_000);
z_fr = rand(q_fr, 10_000);

vi_mf_chain = Chains(reshape(z_mf', (size(z_mf,2), size(z_mf,1), 1)), varnames);
vi_fr_chain = Chains(reshape(z_fr', (size(z_fr,2), size(z_fr,1), 1)), varnames);
# Calculate the predictions for the training and test sets using samples from the variational posteriors.
train_cut.VIMFPredictions = unstandardize(
    prediction(vi_mf_chain, train), train_unstandardized.MPG
)
test_cut.VIMFPredictions = unstandardize(
    prediction(vi_mf_chain, test), train_unstandardized.MPG
)

train_cut.VIFRPredictions = unstandardize(
    prediction(vi_fr_chain, train), train_unstandardized.MPG
)
test_cut.VIFRPredictions = unstandardize(
    prediction(vi_fr_chain, test), train_unstandardized.MPG
)

train_cut.BayesPredictions = unstandardize(
    prediction(mcmc_chain, train), train_unstandardized.MPG
)
test_cut.BayesPredictions = unstandardize(
    prediction(mcmc_chain, test), train_unstandardized.MPG
);
vi_mf_loss1 = mean((train_cut.VIMFPredictions - train_cut.MPG) .^ 2)
vi_fr_loss1 = mean((train_cut.VIFRPredictions - train_cut.MPG) .^ 2)
bayes_loss1 = mean((train_cut.BayesPredictions - train_cut.MPG) .^ 2)
ols_loss1 = mean((train_cut.OLSPrediction - train_cut.MPG) .^ 2)

vi_mf_loss2 = mean((test_cut.VIMFPredictions - test_cut.MPG) .^ 2)
vi_fr_loss2 = mean((test_cut.VIFRPredictions - test_cut.MPG) .^ 2)
bayes_loss2 = mean((test_cut.BayesPredictions - test_cut.MPG) .^ 2)
ols_loss2 = mean((test_cut.OLSPrediction - test_cut.MPG) .^ 2)

println("Training set:
    VI Mean-Field loss: $vi_mf_loss1
    VI Full-Rank loss: $vi_fr_loss1
    Bayes loss: $bayes_loss1
    OLS loss: $ols_loss1
Test set:
    VI Mean-Field loss: $vi_mf_loss2
    VI Full-Rank loss: $vi_fr_loss2
    Bayes loss: $bayes_loss2
    OLS loss: $ols_loss2")
Training set:
    VI Mean-Field loss: 3.073938403902387
    VI Full-Rank loss: 3.081065170821681
    Bayes loss: 3.072377656174794
    OLS loss: 3.0709261248930093
Test set:
    VI Mean-Field loss: 26.07318997207172
    VI Full-Rank loss: 25.76208878027834
    Bayes loss: 26.12801985061055
    OLS loss: 27.09481307076057

Interestingly, the mean squared error on the test set is actually lower for the full-rank variational posterior than for the “true” posterior obtained by MCMC sampling using NUTS. But, as Bayesians, we know that the mean doesn’t tell the entire story. One quick check is to look at the mean predictions ± standard deviation for the different approaches:

preds_vi_mf = mapreduce(hcat, 1:5:size(vi_mf_chain, 1)) do i
    return unstandardize(prediction(vi_mf_chain[i], test), train_unstandardized.MPG)
end

p1 = scatter(
    1:size(test, 1),
    mean(preds_vi_mf; dims=2);
    yerr=std(preds_vi_mf; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("VI Mean-Field")

preds_vi_fr = mapreduce(hcat, 1:5:size(vi_fr_chain, 1)) do i
    return unstandardize(prediction(vi_fr_chain[i], test), train_unstandardized.MPG)
end

p2 = scatter(
    1:size(test, 1),
    mean(preds_vi_fr; dims=2);
    yerr=std(preds_vi_fr; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("VI Full-Rank")

preds_mcmc = mapreduce(hcat, 1:5:size(mcmc_chain, 1)) do i
    return unstandardize(prediction(mcmc_chain[i], test), train_unstandardized.MPG)
end

p3 = scatter(
    1:size(test, 1),
    mean(preds_mcmc; dims=2);
    yerr=std(preds_mcmc; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("MCMC (NUTS)")

plot(p1, p2, p3; layout=(1, 3), size=(900, 250), label="")

We can see that the full-rank VI approximation is very close to the predictions from the MCMC samples. Also, the coverage of full-rank VI and MCMC is much better than that of the crude mean-field approximation.


Footnotes

  1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).

  2. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.

  3. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Proceedings of the International Conference on Artificial Intelligence and Statistics. PMLR.

  4. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In Proceedings of the International Conference on Learning Representations.

  5. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the International Conference on Machine Learning. PMLR.

  6. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.

  7. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).

  8. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.

  9. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Proceedings of the International Conference on Artificial Intelligence and Statistics. PMLR.

  10. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In Proceedings of the International Conference on Learning Representations.

  11. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the International Conference on Machine Learning. PMLR.

  12. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.

  13. Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. In Advances in Neural Information Processing Systems, 36.

  14. Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In Proceedings of the International Conference on Machine Learning. PMLR.

  15. Domke, J., Gower, R., & Garrigos, G. (2023). Provable convergence guarantees for black-box variational inference. In Advances in Neural Information Processing Systems, 36.

  16. Kim, K., Oh, J., Wu, K., Ma, Y., & Gardner, J. (2023). On the convergence of black-box variational inference. In Advances in Neural Information Processing Systems, 36.

  17. Shamir, O., & Zhang, T. (2013). Stochastic gradient descent for non-smooth optimization: Convergence results and optimal averaging schemes. In Proceedings of the International Conference on Machine Learning. PMLR.

  18. Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. In Advances in Neural Information Processing Systems, 36.

  19. Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD’s best friend: A parameter-free dynamic step size schedule. In Proceedings of the International Conference on Machine Learning. PMLR.

  20. Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. In Proceedings of the International Conference on Learning Representations.

  21. Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In Proceedings of the International Conference on Machine Learning. PMLR.

  22. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.

  23. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).