Variational Inference

This post looks at variational inference (VI), an optimisation approach to approximate Bayesian inference, and at how to use it in Turing.jl as an alternative to other approaches such as MCMC. The focus is on the usage of VI in Turing rather than on the principles and theory underlying VI. If you are interested in the mathematics, you can check out our write-up or any of the many other good resources available online.

Let’s start with a minimal example. Consider a Turing.Model, which we denote as model. Approximating the posterior associated with model via VI is as simple as

m = model(data...)               # instantiate model on the data
q_init = q_fullrank_gaussian     # how to generate initial variational approximation
result = vi(m, q_init, 1000)     # perform VI with the default algorithm on `m` for 1000 iterations

Thus, it’s no more work than standard MCMC sampling in Turing. The default algorithm uses stochastic gradient descent to minimise the (exclusive) KL divergence. This approach is commonly referred to as automatic differentiation variational inference (ADVI)[1], stochastic gradient VI[2], or black-box variational inference[3] with the reparameterization gradient[4][5][6].
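For orientation only (the theory is covered in the write-up linked above): the exclusive KL divergence is minimised indirectly by maximising the evidence lower bound (ELBO), since the two differ only by the fixed log marginal likelihood,

$$
\mathrm{ELBO}(q) \;=\; \mathbb{E}_{\theta \sim q}\big[\log p(\theta, \mathcal{D}) - \log q(\theta)\big] \;=\; \log p(\mathcal{D}) \;-\; \mathrm{KL}\big(q \,\|\, p(\theta \mid \mathcal{D})\big),
$$

so maximising the ELBO over the variational family is equivalent to minimising the exclusive KL divergence to the posterior.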

To get a bit more into what we can do with VI, let’s look at a more concrete example. We will reproduce the tutorial on Bayesian linear regression using VI instead of MCMC. After that, we will discuss how to customise the behaviour of vi for more advanced usage.

Let’s first import the relevant packages:

using Random
using Turing
using Turing: Variational
using AdvancedVI
using Plots

Random.seed!(42);

Bayesian Linear Regression Example

Let’s start by setting up our example, re-using the model from the Bayesian linear regression tutorial. As we’ll see, there is really no additional work required to apply variational inference to a more complex Model.

using FillArrays
using RDatasets

using LinearAlgebra

# Import the "mtcars" dataset.
data = RDatasets.dataset("datasets", "mtcars");

# Show the first six rows of the dataset.
first(data, 6)
6×12 DataFrame
 Row │ Model              MPG      Cyl    Disp     HP     DRat     WT       QSec     VS     AM     Gear   Carb
     │ String31           Float64  Int64  Float64  Int64  Float64  Float64  Float64  Int64  Int64  Int64  Int64
─────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ Mazda RX4          21.0     6      160.0    110    3.9      2.62     16.46    0      1      4      4
   2 │ Mazda RX4 Wag      21.0     6      160.0    110    3.9      2.875    17.02    0      1      4      4
   3 │ Datsun 710         22.8     4      108.0    93     3.85     2.32     18.61    1      1      4      1
   4 │ Hornet 4 Drive     21.4     6      258.0    110    3.08     3.215    19.44    1      0      3      1
   5 │ Hornet Sportabout  18.7     8      360.0    175    3.15     3.44     17.02    0      0      3      2
   6 │ Valiant            18.1     6      225.0    105    2.76     3.46     20.22    1      0      3      1
# Function to split samples.
function split_data(df, at=0.70)
    r = size(df, 1)
    index = Int(round(r * at))
    train = df[1:index, :]
    test = df[(index + 1):end, :]
    return train, test
end

# A handy helper function to rescale our dataset.
function standardise(x)
    return (x .- mean(x; dims=1)) ./ std(x; dims=1)
end

function standardise(x, orig)
    return (x .- mean(orig; dims=1)) ./ std(orig; dims=1)
end

# Another helper function to unstandardize our datasets.
function unstandardize(x, orig)
    return x .* std(orig; dims=1) .+ mean(orig; dims=1)
end

function unstandardize(x, mean_train, std_train)
    return x .* std_train .+ mean_train
end
unstandardize (generic function with 2 methods)
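As a quick aside (not part of the original tutorial), the two helpers are inverses of one another when given the same reference data, which is worth sanity-checking:

# Round-trip check: standardising and then unstandardising recovers the input.
x = [1.0 2.0; 3.0 4.0; 5.0 6.0]
@assert unstandardize(standardise(x), x) ≈ x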
# Remove the model column.
select!(data, Not(:Model))

# Split our dataset 70%/30% into training/test sets.
train, test = split_data(data, 0.7)
train_unstandardized = copy(train)

# Standardise both datasets.
std_train = standardise(Matrix(train))
std_test = standardise(Matrix(test), Matrix(train))

# Save dataframe versions of our dataset.
train_cut = DataFrame(std_train, names(data))
test_cut = DataFrame(std_test, names(data))

# Create our labels. These are the values we are trying to predict.
train_label = train_cut[:, :MPG]
test_label = test_cut[:, :MPG]

# Get the list of columns to keep.
remove_names = filter(x -> !in(x, ["MPG"]), names(data))

# Filter the test and train sets.
train = Matrix(train_cut[:, remove_names]);
test = Matrix(test_cut[:, remove_names]);
# Bayesian linear regression.
@model function linear_regression(x, y, n_obs, n_vars, ::Type{T}=Vector{Float64}) where {T}
    # Set variance prior.
    σ² ~ truncated(Normal(0, 100); lower=0)

    # Set intercept prior.
    intercept ~ Normal(0, 3)

    # Set the priors on our coefficients.
    coefficients ~ MvNormal(Zeros(n_vars), 10.0 * I)

    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    return y ~ MvNormal(mu, σ² * I)
end;

n_obs, n_vars = size(train)
m = linear_regression(train, train_label, n_obs, n_vars)
DynamicPPL.Model{typeof(linear_regression), (:x, :y, :n_obs, :n_vars, Symbol("##arg#232")), (), (), Tuple{Matrix{Float64}, Vector{Float64}, Int64, Int64, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext, false}(linear_regression, (x = [-0.10474629747086901 -0.5742368526626253 … 1.0702591020190138 1.1186855116463725; -0.10474629747086901 -0.5742368526626253 … 1.0702591020190138 1.1186855116463725; … ; -1.256955569650429 -0.8899286603079263 … -0.8918825850158448 -1.2630320292781625; 1.0474629747086912 0.6758710573112484 … -0.8918825850158448 -0.46912618230331743], y = [0.16337978215187793, 0.16337978215187793, 0.452211897027519, 0.2275646965686868, -0.20568347574477455, -0.30196084736998785, -0.9117175343296745, 0.7089515546947551, 0.452211897027519, -0.12545233272376316  …  -0.4303306762036062, -0.7673014768918542, -1.5375204498935633, -1.5375204498935633, -0.8475326199128657, 1.9926498430309372, 1.6717252709468917, 2.2333432720939714, 0.2436109251728893, -0.7191627910792473], n_obs = 22, n_vars = 10, var"##arg#232" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DynamicPPL.DefaultContext())
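As an optional sanity check before running VI (an aside, not required for what follows), we can draw a single set of parameters from the model's prior; this assumes DynamicPPL's rand(model), which returns a NamedTuple of the latent variables:

# Draw one set of (σ², intercept, coefficients) from the prior; the observed y is not resampled.
rand(m)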

Basic Usage

To run VI, we must first choose a variational family; the most commonly used is the mean-field Gaussian family. Turing provides functions that automatically construct an initial approximation in a given family for the model m. For the mean-field Gaussian family, we can use:

@doc(Variational.q_meanfield_gaussian)
q_meanfield_gaussian(
    [rng::Random.AbstractRNG,]
    ldf::DynamicPPL.LogDensityFunction;
    location::Union{Nothing,<:AbstractVector} = nothing,
    scale::Union{Nothing,<:Diagonal} = nothing,
    kwargs...
)

Find a numerically non-degenerate mean-field Gaussian q for approximating the target ldf::LogDensityFunction.

If scale is nothing, the default is a zero-mean Gaussian with a Diagonal scale matrix (the "mean-field" approximation) no larger than 0.6*I (covariance of 0.6^2*I). This guarantees that samples from the initial variational approximation fall in the range (-2, 2) with 99.9% probability, which mimics the behavior of the Turing.InitFromUniform() strategy. Whether the default choice is used or not, the scale may be adjusted via q_initialize_scale so that the log-density of the model is finite over the samples from q.

Arguments

  • ldf: The target log-density function.

Keyword Arguments

  • location: The location parameter of the initialization. If nothing, a vector of zeros is used.

  • scale: The scale parameter of the initialization. If nothing, an identity matrix is used.

The remaining keyword arguments are passed to q_locationscale.

Returns

  • An AdvancedVI.LocationScale distribution matching the support of ldf.

As we can see, the precise initialisation can be customised through the keyword arguments.
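For example, here is a sketch of how one might supply a custom initial location and scale. Since vi expects family to be a function that builds the initial approximation from the model's LogDensityFunction (see the vi docstring below), we wrap q_meanfield_gaussian in a closure; the dimension 12 (σ², intercept, and 10 coefficients) and the value 0.1 are illustrative choices for this particular model:

# Hypothetical sketch: customise the initial mean-field Gaussian approximation.
q_init_custom = ldf -> q_meanfield_gaussian(
    ldf;
    location=zeros(12),             # initial mean
    scale=Diagonal(fill(0.1, 12)),  # smaller initial standard deviations
)
# result_custom = vi(m, q_init_custom, 1000; show_progress=false)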

Let’s run VI with the default setting:

n_iters = 1000
result_avg = vi(m, q_meanfield_gaussian, n_iters; show_progress=false)
[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.
VIResult
  ├ q    : MvLocationScale
  ├ info : 1000-element Vector{@NamedTuple{elbo::Float64, iteration::Int64}}
  │        final iteration:
  │         ├ elbo = -55.039471145948
  │         └ iteration = 1000
  └ (2 more fields: state, ldf)

The default setting uses the AdvancedVI.RepGradELBO objective, which corresponds to a variant of what is known as automatic differentiation VI[7], stochastic gradient VI[8], or black-box VI[9] with the reparameterization gradient[10][11][12]. The default optimiser we use is AdvancedVI.DoWG[13] combined with a proximal operator. (The use of proximal operators with VI on a location-scale family is discussed in detail by Domke[14][15] and others[16].) We will take a deeper look into the returned values and the keyword arguments in the following subsections. First, here is the full documentation for vi:

@doc(Variational.vi)
vi(
    [rng::Random.AbstractRNG,]
    model::DynamicPPL.Model,
    family,
    max_iter::Int;
    adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
    algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(
        adtype; n_samples=10
    ),
    unconstrained::Bool=requires_unconstrained_space(algorithm),
    fix_transforms::Bool=false,
    show_progress::Bool = Turing.PROGRESS[],
    kwargs...
)

Approximate the target model via the variational inference algorithm specified by algorithm, using a variational family specified by family. This is a thin wrapper around AdvancedVI.optimize.

The default algorithm, KLMinRepGradProxDescent (relevant docs), assumes family returns an AdvancedVI.MvLocationScale, which is true if family is q_fullrank_gaussian or q_meanfield_gaussian. For other variational families, refer to the documentation of AdvancedVI to determine the best algorithm and other options.

Arguments

  • model: The target DynamicPPL.Model.

  • family: A function which is used to generate an initial variational approximation. Existing choices in Turing are q_locationscale, q_meanfield_gaussian, and q_fullrank_gaussian.

  • max_iter: Maximum number of steps.

  • Any additional arguments are passed on to AdvancedVI.optimize.

Keyword Arguments

  • adtype: Automatic differentiation backend to be applied to the log-density. The default value for algorithm also uses this backend for differentiating the variational objective.

  • algorithm: Variational inference algorithm. The default is KLMinRepGradProxDescent, please refer to AdvancedVI docs for all the options.

  • show_progress: Whether to show the progress bar.

  • unconstrained: Whether to transform the posterior to be unconstrained for running the variational inference algorithm. The default value depends on the chosen algorithm (most algorithms require unconstrained space).

  • fix_transforms: Whether to precompute the transforms needed to convert model parameters to (possibly unconstrained) vectors. This can lead to performance improvements, but if any transforms depend on model parameters, setting fix_transforms=true can silently yield incorrect results.

  • Any additional keyword arguments are passed on both to the function initial_approx, and also to AdvancedVI.optimize.

See the docs of AdvancedVI.optimize for additional keyword arguments.

Returns

A VIResult object: please see its docstring for information.

Values Returned by vi

vi returns a single object, which is a VIResult.
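Its fields can be accessed directly; q and info are discussed below, while the brief descriptions of state and ldf here are our reading of the printed summary above (ldf is the LogDensityFunction used for inference, which we reuse later with estimate_objective):

result_avg.q      # the variational approximation
result_avg.info   # per-iteration diagnostics
result_avg.state  # internal optimisation state
result_avg.ldf    # the LogDensityFunction used for inference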

The main output of the algorithm is the q field, which is the variational approximation itself. This represents the average of the parameters generated by the optimisation algorithm, and is computed using what is known as polynomial averaging[17].
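Schematically, the returned parameters are a weighted average of the optimisation iterates λ₁, …, λ_T,

$$
\bar{\lambda}_T \;=\; \sum_{t=1}^{T} w_t \, \lambda_t, \qquad \sum_{t=1}^{T} w_t = 1,
$$

where the weights increase polynomially in t so that later iterates contribute more; see Shamir and Zhang[17] for the exact weighting scheme.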

result_avg.q
AdvancedVI.MvLocationScale{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Distributions.Normal{Float64}, Vector{Float64}}(
location: [-1.370577232223292, -0.002238039832517606, 0.36984333017795934, -0.10267782351028602, -0.0809317588605051, 0.6051904509755113, -0.0035549406962756976, 0.08379502327038048, -0.06794334563789456, 0.12488444636147826, 0.1889591491241175, -0.6093667705331203]
scale: [0.3261788171235718 0.0 … 0.0 0.0; 0.0 0.10439317655945271 … 0.0 0.0; … ; 0.0 0.0 … 0.11152949244728395 0.0; 0.0 0.0 … 0.0 0.11076139211388376]
dist: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
)
Note: Averaging

Usually, averaging will lead to better performance than simply using the last iterate. If you want the last iterate, you can disable averaging in the algorithm:

result_last = vi(
    m,
    q_meanfield_gaussian,
    n_iters;
    show_progress=false,
    algorithm=KLMinRepGradDescent(
        AutoForwardDiff();
        operator=AdvancedVI.ClipScale(),
        averager=AdvancedVI.NoAveraging()
    ),
);
result_last.q
[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.
AdvancedVI.MvLocationScale{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Distributions.Normal{Float64}, Vector{Float64}}(
location: [-0.21504162166797902, 0.012446511297543199, 0.13903845336976778, -0.027392498457864435, -0.06570618926664787, 0.4449812540648866, -0.14239638987492198, 0.10732677783503532, -0.009381506573922577, 0.13017362091806114, 0.10816612142257093, -0.4218323658540159]
scale: [0.25703090224020275 0.0 … 0.0 0.0; 0.0 0.17533886392073122 … 0.0 0.0; … ; 0.0 0.0 … 0.13098781526771058 0.0; 0.0 0.0 … 0.0 0.15777128688699263]
dist: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
)

To see the difference, we can compare the ELBO of the two:

@info("Objective of q_avg and q_last",
    ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), result_avg.q, result_avg.ldf),
    ELBO_q_last = estimate_objective(AdvancedVI.RepGradELBO(32), result_last.q, result_last.ldf)
)
Info: Objective of q_avg and q_last
  ELBO_q_avg = 51.454598200575475
  ELBO_q_last = 68.6237892021251

We can see that the objective value attained by q_avg is slightly better than that of q_last.

The info field (here result_avg.info) also contains information generated during optimisation that can be useful for diagnostics. For the default setting, which uses the RepGradELBO objective, it contains the ELBO estimate at each step, which can be plotted as follows:

Plots.plot([i.elbo for i in result_avg.info], xlabel="Iterations", ylabel="ELBO", label="info")

Since the ELBO is estimated by a small number of samples, it appears noisy. Furthermore, at each step, the ELBO is evaluated on q_last, not q_avg, which is the actual output that we care about. To obtain more accurate ELBO estimates evaluated on q_avg, we have to define a custom callback function.

Custom Callback Functions

To inspect the progress of optimisation in more detail, one can define a custom callback function. For example, the following callback function estimates the ELBO on q_avg every 10 steps with a larger number of samples:

function callback(; iteration, averaged_params, restructure, state, kwargs...)
    if mod(iteration, 10) == 1
        q_avg = restructure(averaged_params)
        obj = AdvancedVI.RepGradELBO(128) # 128 samples for ELBO estimation
        elbo_avg = -estimate_objective(obj, q_avg, state.prob)
        (elbo_avg = elbo_avg,)
    else
        nothing
    end
end;

The NamedTuple returned by callback will be appended to the corresponding entry of info, and it will also be displayed on the progress meter if show_progress is set to true.

The custom callback can be supplied to vi as a keyword argument:

result_mf = vi(m, q_meanfield_gaussian, n_iters; show_progress=false, callback=callback);
[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.

Let’s plot the result:

info_mf = result_mf.info
iters = 1:10:length(info_mf)
elbo_mf = [i.elbo_avg for i in info_mf[iters]]
Plots.plot([i.elbo for i in info_mf], xlabel="Iterations", ylabel="ELBO", label="info", linewidth=0.4)
Plots.plot!(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="callback", ylims=(-200,Inf), linewidth=2)

We can see that the ELBO values are less noisy and progress more smoothly due to averaging.

Using Different Optimisers

The default optimiser we use is a proximal variant of DoWG[18]. For Gaussian variational families, this works well as a default option. Sometimes, however, the step size of AdvancedVI.DoWG can be too large, resulting in unstable behaviour (in this case, we recommend trying AdvancedVI.DoG[19]). Or, for whatever reason, it might be desirable to use a different optimiser. Our implementation supports any optimiser that implements the Optimisers.jl interface.
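As a hedged sketch (mirroring the Adam example below), swapping in AdvancedVI.DoG might look as follows; this assumes KLMinRepGradDescent accepts any Optimisers.jl-compatible rule through its optimizer keyword and that DoG's default hyperparameters are acceptable:

# Sketch: try DoG if DoWG's step sizes turn out to be unstable for a given model.
result_dog = vi(
    m, q_meanfield_gaussian, n_iters;
    show_progress=false,
    algorithm=KLMinRepGradDescent(
        AutoForwardDiff();
        optimizer=AdvancedVI.DoG(),
        operator=AdvancedVI.ClipScale(),
    ),
);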

For instance, let’s try using Optimisers.Adam[20], which is a popular choice. Since AdvancedVI does not implement a proximal operator for Optimisers.Adam, we must use the AdvancedVI.ClipScale() projection operator, which ensures that the scale matrix of the variational approximation stays positive definite. (See Domke (2020)[21] for more detail on the use of a projection operator.)

using Optimisers

result = vi(
    m, q_meanfield_gaussian, n_iters;
    show_progress=false,
    callback=callback,
    algorithm=KLMinRepGradDescent(AutoForwardDiff(); optimizer=Optimisers.Adam(3e-3), operator=ClipScale())
);
[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.
info_adam = result.info
iters = 1:10:length(info_adam)
elbo_adam = [i.elbo_avg for i in info_adam[iters]]
Plots.plot(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="DoWG")
Plots.plot!(iters, elbo_adam, xlabel="Iterations", ylabel="ELBO", label="Adam")

Compared to the default option AdvancedVI.DoWG(), we can see that Optimisers.Adam(3e-3) converges more slowly. With more step-size tuning, Optimisers.Adam could perform comparably or better. In general, most common optimisers require some degree of tuning to match AdvancedVI.DoWG() or AdvancedVI.DoG(), which require very little tuning; for this reason, the latter are referred to as parameter-free optimisers.

Using Full-Rank Variational Families

So far, we have only used the mean-field Gaussian family, which approximates the posterior covariance with a diagonal matrix. To model the full covariance matrix, we can use the full-rank Gaussian family[22][23]:

@doc(Variational.q_fullrank_gaussian)
q_fullrank_gaussian(
    [rng::Random.AbstractRNG,]
    ldf::DynamicPPL.LogDensityFunction;
    location::Union{Nothing,<:AbstractVector} = nothing,
    scale::Union{Nothing,<:LowerTriangular} = nothing,
    kwargs...
)

Find a numerically non-degenerate Gaussian q with a scale with full-rank factors (traditionally referred to as a "full-rank family") for approximating the target ldf::LogDensityFunction.

If scale is nothing, the default is a zero-mean Gaussian with a LowerTriangular scale matrix (resulting in a covariance with "full-rank" factors) no larger than 0.6*I (covariance of 0.6^2*I). This guarantees that samples from the initial variational approximation fall in the range (-2, 2) with 99.9% probability, which mimics the behavior of the Turing.InitFromUniform() strategy. Whether the default choice is used or not, the scale may be adjusted via q_initialize_scale so that the log-density of the model is finite over the samples from q.

Arguments

  • ldf: The target log-density function.

Keyword Arguments

  • location: The location parameter of the initialization. If nothing, a vector of zeros is used.

  • scale: The scale parameter of the initialization. If nothing, an identity matrix is used.

The remaining keyword arguments are passed to q_locationscale.

Returns

  • An AdvancedVI.LocationScale distribution matching the support of ldf.

The term full-rank might seem a bit peculiar, since the covariance of any non-degenerate Gaussian is full-rank. The name traditionally comes from the fact that full-rank families parameterise the covariance with a full (dense) scale factor, rather than only its diagonal.
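Concretely, both are location-scale Gaussian families; as the docstrings above indicate (Diagonal versus LowerTriangular scale), they differ only in how the scale is parameterised:

$$
q_{\text{mean-field}}(\theta) = \mathcal{N}\big(\mu,\ \mathrm{diag}(\sigma)^2\big),
\qquad
q_{\text{full-rank}}(\theta) = \mathcal{N}\big(\mu,\ L L^\top\big), \quad L \text{ lower triangular},
$$

so the full-rank family can represent correlations between parameters that the mean-field family ignores.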

In contrast to the mean-field family, the full-rank family will often result in more computation per optimisation step and slower convergence, especially in high dimensions:

result_fr = vi(m, q_fullrank_gaussian, n_iters; show_progress=false, callback)

Plots.plot(elbo_mf, xlabel="Iterations", ylabel="ELBO", label="Mean-Field", ylims=(-200, Inf))

elbo_fr = [i.elbo_avg for i in result_fr.info[iters]]
Plots.plot!(elbo_fr, xlabel="Iterations", ylabel="ELBO", label="Full-Rank", ylims=(-200, Inf))
[ Info: The capability of the supplied target `LogDensityProblem` LogDensityProblems.LogDensityOrder{1}() is >= `LogDensityProblems.LogDensityOrder{1}()`. To make use of this, the `adtype` argument for AdvancedVI must be one of `AutoReverseDiff`, `AutoZygote`, `AutoMooncake`, or `AutoEnzyme` in reverse mode.

However, we can see that the full-rank family achieves a higher ELBO in the end. Given the relationship between the ELBO and the (exclusive) Kullback–Leibler divergence, this indicates that the full-rank approximation is more accurate. This trade-off between statistical accuracy and optimisation speed is often referred to as the statistical-computational trade-off; the fact that we can control it through the choice of variational family is a strength, rather than a limitation, of variational inference.

Obtaining Summary Statistics

Let’s inspect the resulting variational approximation in more detail and compare it against MCMC. To obtain summary statistics from VI, we can draw samples from the resulting variational approximation by calling rand on the VIResult. Note that even though VI is often performed in transformed space, rand(result) will return samples in the original space of the parameters.

rand(result_fr)
VarNamedTuple
├─ σ² => 0.25200558565903575
├─ intercept => -0.09840830441493338
└─ coefficients => [1.398319273658256, -0.9150343171917702, 0.04933566000406464, 0.8062034705455388, 0.39327430346864284, 0.6709130436782058, -0.6412192737889927, 0.1533691597033213, 0.5355775921267156, -1.1553516071072405]
Note: Samples in transformed space

If you want to obtain samples in transformed space, you can access result_fr.ldf to obtain the LogDensityFunction used for inference. Calling rand(result_fr.ldf) will return a vector of samples in the transformed space. However, note that these vectors are harder to interpret and work with.

Each sample is a DynamicPPL.VarNamedTuple, and a matrix of these can be converted into a Chains object.

z = rand(result_fr, 100_000, 1)
using AbstractMCMC: AbstractMCMC
vi_chain = AbstractMCMC.from_samples(MCMCChains.Chains, z)
Chains MCMC chain (100000×12×1 Array{Float64, 3}):

Iterations        = 1:1:100000
Number of chains  = 1
Samples per chain = 100000
parameters        = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]
internals         = 

Use `describe(chains)` for summary statistics and quantiles.

Now, we can, for example, look at expectations:

describe(vi_chain)
Chains MCMC chain (100000×12×1 Array{Float64, 3}):

Iterations        = 1:1:100000
Number of chains  = 1
Samples per chain = 100000
parameters        = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]
internals         = 

Summary Statistics

        parameters      mean       std      mcse      ess_bulk     ess_tail    ⋯
            Symbol   Float64   Float64   Float64       Float64      Float64    ⋯

                σ²    0.2979    0.0944    0.0003    99377.6822   96582.0412    ⋯
         intercept   -0.0016    0.1125    0.0004   100156.7778   98715.2905    ⋯
   coefficients[1]    0.3672    0.4919    0.0016    99095.4982   99338.7695    ⋯
   coefficients[2]   -0.0848    0.5018    0.0016    98983.5046   98888.4122    ⋯
   coefficients[3]   -0.0768    0.3936    0.0012    99138.3295   99291.6470    ⋯
   coefficients[4]    0.6037    0.3461    0.0011    99701.6805   99584.8467    ⋯
   coefficients[5]   -0.0276    0.5159    0.0016    99629.0600   97011.6786    ⋯
   coefficients[6]    0.0958    0.3022    0.0010   100088.3743   99418.5105    ⋯
   coefficients[7]   -0.0633    0.3078    0.0010    99329.5193   99842.4190    ⋯
   coefficients[8]    0.1278    0.2663    0.0008    99707.9570   99166.1583    ⋯
   coefficients[9]    0.1865    0.3432    0.0011    99439.5319   97843.6277    ⋯
  coefficients[10]   -0.5910    0.3832    0.0012    99307.2788   97288.4772    ⋯

                                                               2 columns omitted

Quantiles

        parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
            Symbol   Float64   Float64   Float64   Float64   Float64 

                σ²    0.1549    0.2306    0.2840    0.3493    0.5218
         intercept   -0.2221   -0.0778   -0.0016    0.0738    0.2195
   coefficients[1]   -0.5953    0.0342    0.3669    0.6980    1.3360
   coefficients[2]   -1.0711   -0.4237   -0.0867    0.2552    0.9033
   coefficients[3]   -0.8473   -0.3408   -0.0780    0.1889    0.6989
   coefficients[4]   -0.0748    0.3697    0.6046    0.8365    1.2790
   coefficients[5]   -1.0342   -0.3762   -0.0269    0.3208    0.9809
   coefficients[6]   -0.4962   -0.1075    0.0964    0.2992    0.6889
   coefficients[7]   -0.6632   -0.2728   -0.0634    0.1433    0.5424
   coefficients[8]   -0.3943   -0.0526    0.1288    0.3071    0.6481
   coefficients[9]   -0.4886   -0.0454    0.1864    0.4182    0.8606
  coefficients[10]   -1.3439   -0.8483   -0.5907   -0.3341    0.1617

(Since we’re drawing independent samples, we can simply ignore the ESS and Rhat metrics.)

Let’s compare this against samples from NUTS:

mcmc_chain = sample(m, NUTS(), 10_000; progress=false);

vi_mean = mean(vi_chain)[:, 2]
mcmc_mean = mean(mcmc_chain, names(mcmc_chain, :parameters))[:, 2]

plot(mcmc_mean; xticks=1:1:length(mcmc_mean), label="mean of NUTS")
plot!(vi_mean; label="mean of VI")
Info: Found initial step size
  ϵ = 0.4
Warning: There were 31 divergent transitions. Consider reparameterising your model or using a smaller step size. For adaptive samplers such as NUTS and HMCDA, consider increasing `target_accept`.
@ Turing.Inference ~/.julia/packages/Turing/tbwJL/src/mcmc/hmc.jl:481

That looks pretty good! But let’s see how the predictive distributions look for the two.

Making Predictions

Similarly to the linear regression tutorial, we’re going to compare against multivariate ordinary least squares (OLS) regression using the GLM package:

# Import the GLM package.
using GLM

# Perform multivariate OLS.
ols = lm(
    @formula(MPG ~ Cyl + Disp + HP + DRat + WT + QSec + VS + AM + Gear + Carb), train_cut
)

# Store our predictions in the original dataframe.
train_cut.OLSPrediction = unstandardize(GLM.predict(ols), train_unstandardized.MPG)
test_cut.OLSPrediction = unstandardize(GLM.predict(ols, test_cut), train_unstandardized.MPG);
# Make a prediction given an input vector, using mean parameter values from a chain.
function prediction(chain, x)
    p = get_params(chain)
    α = mean(p.intercept)
    β = collect(mean.(p.coefficients))
    return α .+ x * β
end
prediction (generic function with 1 method)
# Unstandardize the dependent variable.
train_cut.MPG = unstandardize(train_cut.MPG, train_unstandardized.MPG)
test_cut.MPG = unstandardize(test_cut.MPG, train_unstandardized.MPG);
# Show the first six rows of the modified dataframe.
first(test_cut, 6)
6×12 DataFrame
 Row │ MPG      Cyl       Disp       HP         DRat       WT          QSec       VS         AM         Gear       Carb       OLSPrediction
     │ Float64  Float64   Float64    Float64    Float64    Float64     Float64    Float64    Float64    Float64    Float64    Float64
─────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ 15.2      1.04746   0.565102   0.258882  -0.652405   0.0714991  -0.716725  -0.977008  -0.598293  -0.891883  -0.469126  19.8583
   2 │ 13.3      1.04746   0.929057   1.90345    0.380435   0.465717   -1.90403   -0.977008  -0.598293  -0.891883   1.11869   16.0462
   3 │ 19.2      1.04746   1.32466    0.691663  -0.777058   0.470584   -0.873777  -0.977008  -0.598293  -0.891883  -0.469126  18.5746
   4 │ 27.3     -1.25696  -1.21511   -1.19526    1.0037    -1.38857     0.288403   0.977008   1.59545    1.07026   -1.26303   29.3233
   5 │ 26.0     -1.25696  -0.888346  -0.762482   1.62697   -1.18903    -1.09365   -0.977008   1.59545    3.0324    -0.469126  30.7731
   6 │ 30.4     -1.25696  -1.08773   -0.381634   0.451665  -1.79933    -0.968007   0.977008   1.59545    3.0324    -0.469126  25.2892
# Construct the Chains from the Variational Approximations
vi_mf_chain = AbstractMCMC.from_samples(MCMCChains.Chains, rand(result_mf, 10_000, 1))
vi_fr_chain = AbstractMCMC.from_samples(MCMCChains.Chains, rand(result_fr, 10_000, 1))
Chains MCMC chain (10000×12×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
parameters        = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]
internals         = 

Use `describe(chains)` for summary statistics and quantiles.
# Calculate the predictions for the training and testing sets using samples from the variational posteriors and the MCMC chain.
train_cut.VIMFPredictions = unstandardize(
    prediction(vi_mf_chain, train), train_unstandardized.MPG
)
test_cut.VIMFPredictions = unstandardize(
    prediction(vi_mf_chain, test), train_unstandardized.MPG
)

train_cut.VIFRPredictions = unstandardize(
    prediction(vi_fr_chain, train), train_unstandardized.MPG
)
test_cut.VIFRPredictions = unstandardize(
    prediction(vi_fr_chain, test), train_unstandardized.MPG
)

train_cut.BayesPredictions = unstandardize(
    prediction(mcmc_chain, train), train_unstandardized.MPG
)
test_cut.BayesPredictions = unstandardize(
    prediction(mcmc_chain, test), train_unstandardized.MPG
);
vi_mf_loss1 = mean((train_cut.VIMFPredictions - train_cut.MPG) .^ 2)
vi_fr_loss1 = mean((train_cut.VIFRPredictions - train_cut.MPG) .^ 2)
bayes_loss1 = mean((train_cut.BayesPredictions - train_cut.MPG) .^ 2)
ols_loss1 = mean((train_cut.OLSPrediction - train_cut.MPG) .^ 2)

vi_mf_loss2 = mean((test_cut.VIMFPredictions - test_cut.MPG) .^ 2)
vi_fr_loss2 = mean((test_cut.VIFRPredictions - test_cut.MPG) .^ 2)
bayes_loss2 = mean((test_cut.BayesPredictions - test_cut.MPG) .^ 2)
ols_loss2 = mean((test_cut.OLSPrediction - test_cut.MPG) .^ 2)

println("Training set:
    VI Mean-Field loss: $vi_mf_loss1
    VI Full-Rank loss: $vi_fr_loss1
    Bayes loss: $bayes_loss1
    OLS loss: $ols_loss1
Test set:
    VI Mean-Field loss: $vi_mf_loss2
    VI Full-Rank loss: $vi_fr_loss2
    Bayes loss: $bayes_loss2
    OLS loss: $ols_loss2")
Training set:
    VI Mean-Field loss: 3.075162701731674
    VI Full-Rank loss: 3.079297574755324
    Bayes loss: 3.0730877712644475
    OLS loss: 3.07092612489301
Test set:
    VI Mean-Field loss: 25.969505420207565
    VI Full-Rank loss: 24.952749829168333
    Bayes loss: 25.93074787116897
    OLS loss: 27.094813070760562

Interestingly, the squared error between the true values and the mean predictions on the test set is actually lower for the full-rank variational posterior than for the “true” posterior obtained by MCMC sampling with NUTS. But, as Bayesians, we know that the mean doesn’t tell the entire story. One quick check is to look at the mean predictions ± standard deviation for the different approaches:

preds_vi_mf = mapreduce(hcat, 1:5:size(vi_mf_chain, 1)) do i
    return unstandardize(prediction(vi_mf_chain[i], test), train_unstandardized.MPG)
end

p1 = scatter(
    1:size(test, 1),
    mean(preds_vi_mf; dims=2);
    yerr=std(preds_vi_mf; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("VI Mean-Field")

preds_vi_fr = mapreduce(hcat, 1:5:size(vi_fr_chain, 1)) do i
    return unstandardize(prediction(vi_fr_chain[i], test), train_unstandardized.MPG)
end

p2 = scatter(
    1:size(test, 1),
    mean(preds_vi_fr; dims=2);
    yerr=std(preds_vi_fr; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("VI Full-Rank")

preds_mcmc = mapreduce(hcat, 1:5:size(mcmc_chain, 1)) do i
    return unstandardize(prediction(mcmc_chain[i], test), train_unstandardized.MPG)
end

p3 = scatter(
    1:size(test, 1),
    mean(preds_mcmc; dims=2);
    yerr=std(preds_mcmc; dims=2),
    label="prediction (mean ± std)",
    size=(900, 500),
    markersize=8,
)
scatter!(1:size(test, 1), unstandardize(test_label, train_unstandardized.MPG); label="true")
xaxis!(1:size(test, 1))
ylims!(10, 40)
title!("MCMC (NUTS)")

plot(p1, p2, p3; layout=(1, 3), size=(900, 250), label="")

We can see that the full-rank VI approximation is very close to the predictions from the MCMC samples. Also, the coverage of full-rank VI and MCMC is much better than that of the cruder mean-field approximation.


Footnotes

  1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).↩︎

  2. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  3. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Proceedings of the International Conference on Artificial intelligence and statistics. PMLR.↩︎

  4. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In Proceedings of the International Conference on Learning Representations.↩︎

  5. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  6. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  7. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).↩︎

  8. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  9. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Proceedings of the International Conference on Artificial intelligence and statistics. PMLR.↩︎

  10. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In Proceedings of the International Conference on Learning Representations.↩︎

  11. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  12. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  13. Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. In Advances in Neural Information Processing Systems, 36.↩︎

  14. Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  15. Domke, J., Gower, R., & Garrigos, G. (2023). Provable convergence guarantees for black-box variational inference. In Advances in Neural Information Processing Systems, 36.↩︎

  16. Kim, K., Oh, J., Wu, K., Ma, Y., & Gardner, J. (2023). On the convergence of black-box variational inference. In Advances in Neural Information Processing Systems, 36.↩︎

  17. Shamir, O., & Zhang, T. (2013). Stochastic gradient descent for non-smooth optimisation: Convergence results and optimal averaging schemes. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  18. Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. In Advances in Neural Information Processing Systems, 36.↩︎

  19. Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD’s best friend: A parameter-free dynamic step size schedule. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  20. Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimisation. In Proceedings of the International Conference on Learning Representations.↩︎

  21. Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  22. Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the International Conference on Machine Learning. PMLR.↩︎

  23. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14).↩︎