Stan Models

Since AdvancedVI supports the LogDensityProblem interface, it can also be used with Stan models through the StanLogDensityProblems interface. Specifically, StanLogDensityProblems wraps any Stan model into a LogDensityProblem using BridgeStan.

Problem Setup

Recall the hierarchical logistic regression example in the Basic Example. Here, we will define the same model in Stan.

model_src = """
data {
  int<lower=0> N;
  int<lower=0> D;
  matrix[N,D] X;
  array[N] int<lower=0, upper=1> y;
}
parameters {
  vector[D] beta;
  real<lower=1e-4> sigma;
}
model {
  sigma ~ lognormal(0, 1);
  beta ~ normal(0, sigma);
  y ~ bernoulli_logit(X * beta);
}
"""
nothing

We also need to prepare the data.

using DataFrames: DataFrames
using OpenML: OpenML
using Statistics

data = Array(DataFrames.DataFrame(OpenML.load(40)))

X = Matrix{Float64}(data[:, 1:(end - 1)])
X = (X .- mean(X; dims=2)) ./ std(X; dims=2)
X = hcat(X, ones(size(X, 1)))
y = Vector{Int}(data[:, end] .== "Mine")

stan_data = (X=transpose(X), y=y, N=size(X, 1), D=size(X, 2))
nothing

Since StanLogDensityProblems expects files for both the model and the data, we need to store both on the file system.

using JSON: JSON

open("logistic_model.stan", "w") do io
    println(io, model_src)
end
open("logistic_data.json", "w") do io
    println(io, JSON.json(stan_data))
end
nothing

Inference via AdvancedVI

We can now call StanLogDensityProblems to receive a LogDensityProblem.

using StanLogDensityProblems: StanLogDensityProblems

model = StanLogDensityProblems.StanProblem("logistic_model.stan", "logistic_data.json")
nothing
BridgeStan not found at location specified by $BRIDGESTAN environment variable, downloading version 2.6.2 to /home/runner/.bridgestan/bridgestan-2.6.2
Done!

The rest is the same as for any other LogDensityProblem, except for how constrained variables are handled: since StanLogDensityProblems automatically transforms the support of the target problem to be unconstrained, we do not need to involve Bijectors.

using ADTypes, ReverseDiff
using AdvancedVI
using LinearAlgebra
using LogDensityProblems
using Plots

alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff())

d = LogDensityProblems.dimension(model)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d)))

max_iter = 10^4
q_out, info, _ = AdvancedVI.optimize(alg, max_iter, model, q; show_progress=false)

plot(
    [i.iteration for i in info],
    [i.elbo for i in info];
    xlabel="Iteration",
    ylabel="ELBO",
    label=nothing,
)
savefig("stan_example_elbo.svg")
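
Because the target was unconstrained automatically, the variational approximation lives on the unconstrained parameter space. As an illustration of what this means for sigma, the following is a minimal sketch of the lower-bound transform described in the Stan reference manual, written in plain Julia. The helper names (constrain, unconstrain, logabsjac, lognormal_logpdf, prior_logdensity_unconstrained) are hypothetical and exist only for this example; that BridgeStan applies exactly this map internally is an assumption based on Stan's documented behavior, and only the prior term on sigma is shown.

```julia
# Hypothetical helper names; only the formulas follow Stan's documented lower-bound transform.
const SIGMA_LOWER = 1e-4  # matches `real<lower=1e-4> sigma` in the model above

# Map an unconstrained real y to the constrained value sigma > SIGMA_LOWER.
constrain(y; a=SIGMA_LOWER) = exp(y) + a

# Inverse map: constrained sigma back to the real line.
unconstrain(sigma; a=SIGMA_LOWER) = log(sigma - a)

# Log absolute Jacobian of `constrain`: d/dy (exp(y) + a) = exp(y), so the log is y.
logabsjac(y) = y

# Log density of lognormal(0, 1), the prior placed on sigma.
lognormal_logpdf(sigma) = -log(sigma) - 0.5 * log(2π) - 0.5 * log(sigma)^2

# Prior log density expressed on the unconstrained axis (change of variables).
prior_logdensity_unconstrained(y) = lognormal_logpdf(constrain(y)) + logabsjac(y)

y = unconstrain(0.5)
println("unconstrained value: ", y)
println("round trip:          ", constrain(y))
println("log density at y:    ", prior_logdensity_unconstrained(y))
```

Any real y maps to a valid sigma, which is why a Gaussian variational family can be used directly without a separate bijector.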

From the variational posterior q_out, we can draw samples on the unconstrained support of the model. To convert the samples back to the original (constrained) support of the model, it suffices to call BridgeStan.param_constrain.
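
The conversion could look like the sketch below, which continues from the variables model and q_out defined above and is not runnable on its own. It rests on two assumptions: that the wrapped BridgeStan.StanModel is stored in the model field of the StanProblem, and that BridgeStan is available in the active environment. The names stan_model, samples_unc, samples_con, and constrained_names are introduced here for illustration only.

```julia
using BridgeStan: BridgeStan
using Statistics

# Assumption: the StanProblem keeps the underlying BridgeStan.StanModel in its `model` field.
stan_model = model.model

# Draw samples from the variational approximation; each column is one unconstrained draw.
n_samples = 1000
samples_unc = rand(q_out, n_samples)

# Map every draw back to the constrained support, one column at a time.
samples_con = reduce(
    hcat, [BridgeStan.param_constrain(stan_model, collect(θ)) for θ in eachcol(samples_unc)]
)

# Parameter names in the same order as the rows of `samples_con`.
constrained_names = BridgeStan.param_names(stan_model)

# Posterior mean estimate of each constrained parameter.
post_mean = vec(mean(samples_con; dims=2))
for (name, m) in zip(constrained_names, post_mean)
    println(name, " ≈ ", m)
end
```

Summaries such as posterior means should be computed after the conversion, since the mean of the unconstrained draws does not in general map to the mean on the constrained support.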