Basic Example

In this tutorial, we will demonstrate the basic usage of AdvancedVI with the LogDensityProblems interface.

Problem Setup

Let's consider a basic logistic regression example with a hierarchical prior. For a dataset $(X, y)$ with the design matrix $X \in \mathbb{R}^{n \times d}$ and the response variables $y \in \{0, 1\}^n$, we assume the following data generating process:

\[\begin{aligned} \sigma &\sim \text{LogNormal}(0, 3) \\ \beta &\sim \text{Normal}\left(0_d, \sigma^2 \mathrm{I}_d\right) \\ y &\sim \mathrm{BernoulliLogit}\left(X \beta\right) \end{aligned}\]
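For intuition, the generative process above can be simulated directly. The following is a minimal sketch using only the Julia standard library; the names `sigmoid`, `simulate_logreg`, and `X_toy` are hypothetical helpers introduced for illustration and are not part of AdvancedVI.

```julia
using Random

# Plain logistic function, defined here to avoid extra dependencies.
sigmoid(x) = 1 / (1 + exp(-x))

# Draw (σ, β, y) from the generative process for a given design matrix X.
function simulate_logreg(rng::AbstractRNG, X::AbstractMatrix)
    n, d = size(X)
    σ = exp(3 * randn(rng))     # σ ~ LogNormal(0, 3): log σ is Normal with std 3
    β = σ .* randn(rng, d)      # β ~ Normal(0, σ² I)
    p = sigmoid.(X * β)         # success probabilities under the logit link
    y = rand(rng, n) .< p       # y ~ BernoulliLogit(Xβ)
    return (; σ, β, y)
end

rng = MersenneTwister(1)
X_toy = randn(rng, 5, 2)        # hypothetical toy design matrix
draw = simulate_logreg(rng, X_toy)
```

Inference then amounts to recovering the posterior over `β` and `σ` given only `X` and `y`.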

The LogDensityProblem corresponding to this model can be constructed as follows:

using LogDensityProblems: LogDensityProblems
using Distributions
using FillArrays

struct LogReg{XType,YType}
    X::XType
    y::YType
end

function LogDensityProblems.logdensity(model::LogReg, θ)
    (; X, y) = model
    d = size(X, 2)
    β, σ = θ[1:d], θ[end]

    logprior_β = logpdf(MvNormal(Zeros(d), σ), β)
    logprior_σ = logpdf(LogNormal(0, 3), σ)

    logit = X*β
    loglike_y = mapreduce((li, yi) -> logpdf(BernoulliLogit(li), yi), +, logit, y)
    return loglike_y + logprior_β + logprior_σ
end

function LogDensityProblems.dimension(model::LogReg)
    return size(model.X, 2) + 1
end

function LogDensityProblems.capabilities(::Type{<:LogReg})
    return LogDensityProblems.LogDensityOrder{0}()
end
nothing
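As a cross-check of what logdensity computes, the three terms of the log joint density can be written out by hand. This is a sketch under the assumption that LogNormal(0, 3) is parametrized by the mean and standard deviation of log σ, and that β has covariance σ² I as in the model above; `softplus` and `logjoint_by_hand` are hypothetical names used only here.

```julia
# log(1 + exp(x)), computed without overflow for large |x|.
softplus(x) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

# Log joint density of (β, σ) and y, written out term by term.
function logjoint_by_hand(X, y, β, σ)
    d = length(β)
    # log Normal(β; 0, σ² I)
    logprior_β = -d / 2 * log(2π) - d * log(σ) - sum(abs2, β) / (2σ^2)
    # log LogNormal(σ; 0, 3)
    logprior_σ = -log(σ) - log(3) - log(2π) / 2 - log(σ)^2 / (2 * 3^2)
    # Bernoulli log-likelihood with logits Xβ
    logits = X * β
    loglike = sum(yi ? -softplus(-li) : -softplus(li) for (li, yi) in zip(logits, y))
    return loglike + logprior_β + logprior_σ
end

X_toy = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y_toy = [true, false, true]
lp = logjoint_by_hand(X_toy, y_toy, zeros(2), 1.0)
```

With β = 0 and σ = 1, every logit is zero, so each observation contributes log(1/2) to the likelihood, which makes the result easy to verify by hand.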

Since the support of σ is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a bijector to transform θ. We will use Bijectors for this purpose. This corresponds to the automatic differentiation variational inference (ADVI) formulation[KTRGB2017].

In our case, we need a bijector that applies an identity map to the first size(X,2) coordinates and maps the last coordinate to the support of LogNormal(0, 3). This can be done as follows:

using Bijectors: Bijectors

function Bijectors.bijector(model::LogReg)
    d = size(model.X, 2)
    return Bijectors.Stacked(
        Bijectors.bijector.([MvNormal(Zeros(d), 1.0), LogNormal(0, 3)]),
        [1:d, (d + 1):(d + 1)],
    )
end
nothing

For more details, please refer to the documentation of Bijectors.
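To make the transformation concrete, here is a hand-written sketch of the same idea without Bijectors: the identity on the β coordinates and a log/exp pair on the last coordinate. The function names are hypothetical, and the sketch assumes the positive coordinate is handled with a log transform.

```julia
# Constrained θ = (β..., σ) to unconstrained space: identity on β, log on σ.
to_unconstrained(θ) = vcat(θ[1:(end - 1)], log(θ[end]))

# Inverse map: identity on the first coordinates, exp on the last.
to_constrained(η) = vcat(η[1:(end - 1)], exp(η[end]))

# The Jacobian of the inverse map is diagonal with entries (1, ..., 1, exp(η[end])),
# so its log absolute determinant is simply η[end].
logabsdetjac_inverse(η) = η[end]

θ = [0.5, -1.2, 2.0]
η = to_unconstrained(θ)
```

A density over the unconstrained vector `η` is then obtained by evaluating the target at `to_constrained(η)` and adding `logabsdetjac_inverse(η)`.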

For the dataset, we will use the popular sonar classification dataset from the UCI repository. This can be automatically downloaded using OpenML. The sonar dataset corresponds to the dataset id 40.

using OpenML: OpenML
using DataFrames: DataFrames
data = Array(DataFrames.DataFrame(OpenML.load(40)))
X = Matrix{Float64}(data[:, 1:(end - 1)])
y = Vector{Bool}(data[:, end] .== "Mine")
nothing

Let's apply some basic pre-processing and add an intercept column:

using Statistics

X = (X .- mean(X; dims=2)) ./ std(X; dims=2)
X = hcat(X, ones(size(X, 1)))
nothing
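The effect of the `dims` keyword is easy to check on a small matrix. The sketch below uses a hypothetical 2×3 matrix `A`: `dims=1` reduces over rows and yields one statistic per column, while `dims=2`, as used in the pre-processing above, reduces over columns and yields one statistic per row. If per-feature standardization is what you want, `dims=1` would be the choice.

```julia
using Statistics

A = [1.0 2.0 3.0;
     4.0 5.0 6.0]

col_means = mean(A; dims=1)   # 1×3 matrix: one mean per column
row_means = mean(A; dims=2)   # 2×1 matrix: one mean per row

# Standardizing each column gives zero mean and unit standard deviation per column.
A_std = (A .- mean(A; dims=1)) ./ std(A; dims=1)
```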

The model can now be instantiated as follows:

model = LogReg(X, y)
nothing

Basic Usage

For the VI algorithm, we will use KLMinRepGradDescent:

using ADTypes, ReverseDiff
using AdvancedVI

alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff())
nothing

This algorithm minimizes the exclusive/reverse KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation with the reparametrization gradient[TL2014][RMW2014][KW2014]. This is also commonly referred to as automatic differentiation VI, black-box VI, stochastic gradient VI, and so on. KLMinRepGradDescent, in particular, assumes that the target LogDensityProblem is differentiable. If the LogDensityProblem has a differentiation capability of at least first order, the algorithm can take advantage of it.
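The reparametrization gradient can be illustrated on a one-dimensional toy problem. The following sketch is not how KLMinRepGradDescent is implemented; it assumes a standard normal target, a Gaussian approximation with mean μ and standard deviation s, and plain gradient ascent on the ELBO with a fixed step size. All names are hypothetical.

```julia
using Random

# Toy target: standard normal, log p(z) = -z²/2 + const, so ∇log p(z) = -z.
grad_logp(z) = -z

# Fit q = Normal(μ, s²) by stochastic gradient ascent on the ELBO,
# writing samples as z = μ + s * ε with ε ~ Normal(0, 1).
function fit_gaussian(rng; μ=2.0, s=0.37, stepsize=0.01, n_samples=10, n_iters=5_000)
    for _ in 1:n_iters
        ε = randn(rng, n_samples)
        g = grad_logp.(μ .+ s .* ε)                 # ∇log p at the samples
        grad_μ = sum(g) / n_samples                 # chain rule with ∂z/∂μ = 1
        grad_s = sum(g .* ε) / n_samples + 1 / s    # ∂z/∂s = ε, plus entropy term d/ds log(s)
        μ += stepsize * grad_μ
        s += stepsize * grad_s
    end
    return μ, s
end

μ_hat, s_hat = fit_gaussian(MersenneTwister(1))
```

Because the target is itself Gaussian, the optimum is μ = 0 and s = 1, and the iterates should fluctuate around these values.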

For this example, we will use LogDensityProblemsAD to equip our problem with a first-order capability:

using DifferentiationInterface: DifferentiationInterface
using LogDensityProblemsAD: LogDensityProblemsAD

model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model)
nothing

For the variational family, we will consider a FullRankGaussian approximation:

using LinearAlgebra

d = LogDensityProblems.dimension(model_ad)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d)))
nothing
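A location-scale Gaussian of this kind can be pictured as an affine transformation of standard normal noise. The sketch below assumes the family is parametrized by a location vector and a lower-triangular scale factor whose product with its transpose gives the covariance; `sample_location_scale` and the `*0`-suffixed variables are hypothetical names for illustration.

```julia
using LinearAlgebra, Random

d_toy = 3
μ0 = zeros(d_toy)
L0 = LowerTriangular(Matrix{Float64}(0.37 * I, d_toy, d_toy))

# A draw from Normal(μ0, L0 * L0') is an affine map of standard normal noise.
sample_location_scale(rng, μ, L) = μ + L * randn(rng, length(μ))

Σ0 = L0 * L0'    # implied covariance: 0.37² on the diagonal, zero elsewhere
z = sample_location_scale(MersenneTwister(1), μ0, L0)
```

Starting from a small isotropic scale such as 0.37 keeps the initial approximation concentrated around the initial location.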

The bijector can now be applied to q to match the support of the target problem.

b = Bijectors.bijector(model)
binv = Bijectors.inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)
nothing

We can now run VI:

max_iter = 10^4
q_out, info, _ = AdvancedVI.optimize(
    alg, max_iter, model_ad, q_transformed; show_progress=false
)
nothing

Let's verify that the optimization procedure converged. For this, we will visually check that the objective maximized by KLMinRepGradDescent, the "evidence lower bound" (ELBO), increased. Since KLMinRepGradDescent stores the ELBO estimate at each iteration in info, we can visualize this as follows:

using Plots

plot(
    [i.iteration for i in info],
    [i.elbo for i in info];
    xlabel="Iteration",
    ylabel="ELBO",
    label=nothing,
)
savefig("basic_example_elbo.svg")
nothing

Custom Callback

The ELBO estimates above, however, use only a handful of Monte Carlo samples. Furthermore, the ELBO is evaluated on the iterates of the optimization procedure, which may not coincide with the actual output of the algorithm (for instance, if parameter averaging is used). Therefore, we may want to occasionally compute higher-fidelity ELBO estimates. Also, depending on the problem, we may want to track problem-specific diagnostics to monitor progress.
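The effect of the sample size on the noise of a Monte Carlo estimate can be seen on a toy integrand. This sketch does not compute an ELBO; it estimates E[ε²] = 1 for standard normal ε with 10 and with 256 samples and compares the spread of the estimates over repeated runs. The names are hypothetical.

```julia
using Random, Statistics

# Toy integrand with a known expectation: E[ε²] = 1 for ε ~ Normal(0, 1).
mc_estimate(rng, n_samples) = mean(abs2, randn(rng, n_samples))

rng = MersenneTwister(1)
n_reps = 500
spread_small = std([mc_estimate(rng, 10) for _ in 1:n_reps])    # few samples
spread_large = std([mc_estimate(rng, 256) for _ in 1:n_reps])   # many samples
```

The standard deviation of the estimator shrinks roughly with the square root of the sample size, so the 256-sample estimates are expected to be about five times less noisy.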

For both use cases above, defining a custom callback function can be useful. In this example, we will compute a more accurate estimate of the ELBO and the classification accuracy every logging_interval = 100 iterations.

using StatsFuns: StatsFuns

"""
    logistic_prediction(X, μ_β, Σ_β)

Approximate the posterior predictive probability for a logistic link function using MacKay's approximation (Bishop p. 220).
"""
function logistic_prediction(X, μ_β, Σ_β)
    # Per-row quadratic form xᵀΣx, using the X argument rather than the global model
    xtΣx = sum((X*Σ_β) .* X; dims=2)[:, 1]
    κ = @. 1/sqrt(1 + π/8*xtΣx)
    return StatsFuns.logistic.(κ .* X*μ_β)
end

logging_interval = 100
function callback(; iteration, averaged_params, restructure, kwargs...)
    if mod(iteration, logging_interval) == 1

        # Use the averaged parameters (the eventual output of the algorithm)
        q_avg = restructure(averaged_params)

        # Compute predictions
        μ_β = mean(q_avg.dist)[1:(end - 1)] # posterior mean of β
        Σ_β = cov(q_avg.dist)[1:(end - 1), 1:(end - 1)] # marginal posterior covariance of β
        y_pred = logistic_prediction(X, μ_β, Σ_β) .> 0.5

        # Prediction accuracy
        acc = mean(y_pred .== model.y)

        # Higher fidelity estimate of the ELBO on the averaged parameters
        n_samples = 256
        obj = AdvancedVI.RepGradELBO(n_samples)
        elbo_callback = estimate_objective(obj, q_avg, model)

        (elbo_callback=elbo_callback, accuracy=acc)
    else
        nothing
    end
end
nothing
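The approximation used in logistic_prediction can be checked numerically in one dimension. The sketch below compares MacKay's approximation of the expected logistic function under a Gaussian with a reference value obtained by trapezoidal quadrature. The helper names and the chosen mean and variance are hypothetical.

```julia
sigmoid(x) = 1 / (1 + exp(-x))

# MacKay's approximation to E[sigmoid(a)] for a ~ Normal(m, v).
mackay(m, v) = sigmoid(m / sqrt(1 + π / 8 * v))

# Reference value by trapezoidal quadrature over m ± 10 standard deviations.
function logistic_normal_mean(m, v; n_grid=20_001)
    s = sqrt(v)
    a = range(m - 10s, m + 10s; length=n_grid)
    h = step(a)
    f = @. sigmoid(a) * exp(-(a - m)^2 / (2v)) / sqrt(2π * v)
    return h * (sum(f) - (first(f) + last(f)) / 2)
end

approx = mackay(1.0, 4.0)
reference = logistic_normal_mean(1.0, 4.0)
```

The two values should agree to within a few hundredths, which is usually sufficient for thresholding predictions at 0.5.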

Note that the interface for the callback function will depend on the VI algorithm being used. Therefore, please refer to the documentation of each VI algorithm.

The callback can be supplied to optimize:

max_iter = 10^4
q_out, info, _ = AdvancedVI.optimize(
    alg, max_iter, model_ad, q_transformed; show_progress=false, callback=callback
)
nothing

First, let's compare the default estimate of the ELBO, which uses a small number of samples and is evaluated on the current iterate, with the ELBO computed in the callback, which uses a large number of samples and is evaluated on the averaged iterate.

t = 1:max_iter
elbo = [i.elbo for i in info[t]]

t_callback = 1:logging_interval:max_iter
elbo_callback = [i.elbo_callback for i in info[t_callback]]

plot(t, elbo; xlabel="Iteration", ylabel="ELBO", label="Default")
plot!(t_callback, elbo_callback; label="Callback", ylims=(-300, Inf), linewidth=2)

savefig("basic_example_elbo_callback.svg")
nothing

We can see that the default ELBO estimates are noisy compared to the higher fidelity estimates from the callback. After a few thousand iterations, it is difficult to judge from the default estimates whether we are still making progress. In contrast, the estimates from the callback show that the objective is increasing smoothly.

Similarly, we can monitor the evolution of the prediction accuracy.

acc_callback = [i.accuracy for i in info[t_callback]]
plot(
    t_callback,
    acc_callback;
    xlabel="Iteration",
    ylabel="Prediction Accuracy",
    label=nothing,
)
savefig("basic_example_acc.svg")
nothing

Clearly, the accuracy is improving over time.

  • [KTRGB2017] Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research.
  • [TL2014] Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International Conference on Machine Learning. PMLR.
  • [RMW2014] Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning. PMLR.
  • [KW2014] Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In International Conference on Learning Representations.