The sampling interface
Turing implements a sampling interface (hosted at AbstractMCMC) that is intended to provide a common framework for Markov chain Monte Carlo samplers. The interface presents several structures and functions that one needs to overload in order to implement an interface-compatible sampler.
This guide will demonstrate how to implement the interface without Turing.
Interface overview
Any implementation of an inference method that uses the AbstractMCMC interface should implement a subset of the following types and functions:
- A subtype of `AbstractSampler`, defined as a mutable struct containing state information or sampler parameters.
- A function `sample_init!` which performs any necessary set-up (default: do not perform any set-up).
- A function `step!` which returns a transition that represents a single draw from the sampler.
- A function `transitions_init` which returns a container for the transitions obtained from the sampler (default: return a `Vector{T}` of length `N`, where `T` is the type of the transition obtained in the first step and `N` is the number of requested samples).
- A function `transitions_save!` which saves transitions to the container (default: save the transition of iteration `i` at position `i` in the vector of transitions).
- A function `sample_end!` which handles any sampler wrap-up (default: do not perform any wrap-up).
- A function `bundle_samples` which accepts the container of transitions and returns a collection of samples (default: return the vector of transitions).
The interface methods with exclamation points are those that are intended to allow for state mutation. Any mutating function is meant to allow mutation where needed -- you might use:

- `sample_init!` to run some kind of sampler preparation before sampling begins. This could mutate a sampler's state.
- `step!` to mutate a sampler flag after each sample.
- `sample_end!` for any wrap-up you might need to do. If you were sampling in a transformed space, this might be where you convert everything back to a constrained space.
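To make the shape of such an extension concrete before diving in, here is a rough, non-authoritative outline. `MySampler` is a placeholder name; only the type declaration is real code, and the interface functions are listed as comments because their concrete signatures (shown later in this guide for `step!` and `bundle_samples`) are defined by AbstractMCMC:

```julia
using AbstractMCMC: AbstractMCMC

# A placeholder sampler type; the interface methods dispatch on it.
struct MySampler <: AbstractMCMC.AbstractSampler
    # sampler parameters and/or state go here
end

# Functions such a package may then overload (see AbstractMCMC's docs for signatures):
#   AbstractMCMC.sample_init!       -- optional set-up before sampling
#   AbstractMCMC.step!              -- produce one transition per call
#   AbstractMCMC.transitions_init   -- container for the transitions
#   AbstractMCMC.transitions_save!  -- save a transition into the container
#   AbstractMCMC.sample_end!        -- optional wrap-up after sampling
#   AbstractMCMC.bundle_samples     -- convert transitions into the returned samples
```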
Why do you have an interface?
The motivation for the interface is to allow Julia's fantastic probabilistic programming community to have a set of standards and common implementations so we can all thrive together. Markov chain Monte Carlo methods tend to share a very similar framework, so a common interface should help great inference methods built in single-purpose packages see wider use across the community.
Implementing Metropolis-Hastings without Turing
Metropolis-Hastings is often the first sampling method that people are exposed to. It is a very straightforward algorithm and is accordingly the easiest to implement, so it makes for a good example. In this section, you will learn how to use the types and functions listed above to implement the Metropolis-Hastings sampler using the MCMC interface.
The full code for this implementation is housed in AdvancedMH.jl.
Imports
Let's begin by importing the relevant libraries. We'll import `AbstractMCMC`, which contains the interface framework we'll fill out. We also need `Distributions` and `Random`, along with `LinearAlgebra` (for the identity matrix `I` used in one of the proposal constructors below) and `MCMCChains` (for the `Chains` output type used at the end).
# Import the relevant libraries.
using AbstractMCMC: AbstractMCMC
using Distributions
using Random
using LinearAlgebra: I        # identity matrix used in the MvNormal proposal below
using MCMCChains: Chains      # output container used in bundle_samples below
An interface extension (like the one we're writing right now) typically requires that you overload or implement several functions. Specifically, you should bring into scope the functions and types you intend to overload or dispatch on; a minimal sketch of such an import block is shown below. From `Distributions`, we need `Sampleable`, `VariateForm`, and `ValueSupport`, three abstract types that define a distribution. Models in the interface are assumed to be subtypes of `Sampleable{VariateForm, ValueSupport}`. In this section our model is going to be extremely simple, so we will not end up using these except to make sure that the inference functions are dispatching correctly.
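Since the code in this guide qualifies the interface functions directly (e.g. `AbstractMCMC.step!`), this is only a small sketch of what such an import block might look like, not the exact one used by AdvancedMH.jl:

```julia
# Bring the Distributions abstract types mentioned above into scope.
using Distributions: Sampleable, VariateForm, ValueSupport
```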
Sampler
Let's begin our sampler definition by defining a sampler called `MetropolisHastings` which is a subtype of `AbstractSampler`. Correct typing is very important for proper interface implementation -- if you are missing a subtype, your method may not be dispatched to when you call `sample`.
# Define a sampler type.
struct MetropolisHastings{T,D} <: AbstractMCMC.AbstractSampler
    init_θ::T
    proposal::D
end

# Default constructors.
MetropolisHastings(init_θ::Real) = MetropolisHastings(init_θ, Normal(0, 1))
function MetropolisHastings(init_θ::Vector{<:Real})
    return MetropolisHastings(init_θ, MvNormal(zero(init_θ), I))
end
Above, we have defined a sampler that stores the initial parameter values and a distribution object from which proposals are drawn. You can have a struct that has no fields and simply use it for dispatching onto the relevant functions, or you can store a large amount of state information in your sampler.

The general intuition for what to store in your sampler struct is that anything you may need to perform inference between samples, but don't want to store in a transition, should go into the sampler struct. It's the only way you can carry non-sample-related state information between `step!` calls.
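For example (a purely illustrative sketch -- `StatelessSampler` is not part of this guide's implementation), a dispatch-only sampler with no stored state can be as simple as:

```julia
# A sampler type with no fields at all; it exists only so that the
# interface methods have something to dispatch on.
struct StatelessSampler <: AbstractMCMC.AbstractSampler end
```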
Model
Next, we need to have a model of some kind. A model is a struct that's a subtype of `AbstractModel` and contains whatever information is necessary to perform inference on your problem. In our case we want to know the mean and standard deviation parameters of a Normal distribution, so we can reduce our model to just the log density of a Normal.

Note that we only have to do this because we are not yet integrating the sampler with Turing -- Turing has a very sophisticated modelling engine that removes the need to define custom model structs.
# Define a model type. Stores the log density function.
struct DensityModel{F<:Function} <: AbstractMCMC.AbstractModel
    ℓπ::F
end
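As a quick usage sketch (the closure and the name `example_model` are arbitrary, not part of the model we build below), constructing a `DensityModel` just means handing it a log density function:

```julia
# A throwaway DensityModel whose log density is an unnormalised standard Normal.
example_model = DensityModel(x -> -0.5 * sum(abs2, x))
example_model.ℓπ([1.0, 2.0])   # evaluate the stored log density; returns -2.5
```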
Transition
The next step is to define some transition which we will return from each `step!` call. We'll keep it simple by just defining a wrapper struct that contains the parameter draws and the log density of that draw:
# Create a very basic Transition type, only stores the
# parameter draws and the log probability of the draw.
struct Transition{T,L}
    θ::T
    lp::L
end

# Store the new draw and its log density.
Transition(model::DensityModel, θ) = Transition(θ, ℓπ(model, θ))
`Transition` can now store any type of parameter, whether it's a vector of draws from multiple parameters or a single univariate draw.
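For example, constructing the struct directly (with made-up numbers rather than log densities computed from a model):

```julia
Transition(0.5, -1.21)          # a univariate draw and its log density
Transition([0.5, 1.0], -3.47)   # a multivariate draw and its log density
```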
Metropolis-Hastings
Now it's time to get into the actual inference. We've defined all of the core pieces we need, but we need to implement the `step!` function which actually performs inference.
As a refresher, Metropolis-Hastings implements a very basic algorithm:
- Pick some initial state, $\theta_0$.
- For $t$ in $[1, N]$, do
  - Generate a proposal parameterization $\theta'_t \sim q(\theta'_t \mid \theta_{t-1})$.
  - Calculate the acceptance probability, $\alpha = \min\left[1, \frac{\pi(\theta'_t)}{\pi(\theta_{t-1})} \frac{q(\theta_{t-1} \mid \theta'_t)}{q(\theta'_t \mid \theta_{t-1})}\right]$.
  - If $U \le \alpha$, where $U \sim \text{Uniform}[0, 1]$, then $\theta_t = \theta'_t$; otherwise, $\theta_t = \theta_{t-1}$.
Of course, it's much easier to do this in the log space, so the acceptance probability is more commonly written as
$$\log \alpha = \min\left[0, \log \pi(\theta'_t) - \log \pi(\theta_{t-1}) + \log q(\theta_{t-1} \mid \theta'_t) - \log q(\theta'_t \mid \theta_{t-1})\right].$$
In interface terms, we should do the following:
- Make a new transition containing a proposed sample.
- Calculate the acceptance probability.
- If we accept, return the new transition; otherwise, return the old one.
Steps
The `step!` function is the function that performs the bulk of your inference. In our case, we will implement two `step!` functions -- one for the very first iteration, and one for every subsequent iteration.
# Define the first step! function, which is called at the
# beginning of sampling. Return the initial parameter used
# to define the sampler.
function AbstractMCMC.step!(
    rng::AbstractRNG,
    model::DensityModel,
    spl::MetropolisHastings,
    N::Integer,
    ::Nothing;
    kwargs...,
)
    return Transition(model, spl.init_θ)
end
The first `step!` function just packages up the initial parameterization inside the sampler and returns it. We implicitly accept the very first parameterization.

The other `step!` function performs the usual steps from Metropolis-Hastings. Included are several helper functions, `propose` and `q`, which are designed to replicate the functions in the pseudocode above.

- `propose` generates a new proposal in the form of a `Transition`, which can be univariate if the value passed in is univariate, or multivariate if the `Transition` given is multivariate. Proposals use a basic `Normal` or `MvNormal` proposal distribution.
- `q` returns the log density of one parameterization conditional on another, according to the proposal distribution.
- `step!` generates a new proposal, checks the acceptance probability, and then returns either the previous transition or the proposed transition.
# Define a function that makes a basic proposal depending on a univariate
# parameterization or a multivariate parameterization.
function propose(spl::MetropolisHastings, model::DensityModel, θ::Real)
    return Transition(model, θ + rand(spl.proposal))
end
function propose(spl::MetropolisHastings, model::DensityModel, θ::Vector{<:Real})
    return Transition(model, θ + rand(spl.proposal))
end
function propose(spl::MetropolisHastings, model::DensityModel, t::Transition)
    return propose(spl, model, t.θ)
end

# Calculate the log probability `q(θ|θcond)`, using the proposal distribution `spl.proposal`.
q(spl::MetropolisHastings, θ::Real, θcond::Real) = logpdf(spl.proposal, θ - θcond)
function q(spl::MetropolisHastings, θ::Vector{<:Real}, θcond::Vector{<:Real})
    return logpdf(spl.proposal, θ - θcond)
end
q(spl::MetropolisHastings, t1::Transition, t2::Transition) = q(spl, t1.θ, t2.θ)

# Calculate the log density of the model given some parameterization.
ℓπ(model::DensityModel, θ) = model.ℓπ(θ)
ℓπ(model::DensityModel, t::Transition) = t.lp
# Define the other step function. Returns a Transition containing
# either a new proposal (if accepted) or the previous proposal
# (if not accepted).
function AbstractMCMC.step!(
    rng::AbstractRNG,
    model::DensityModel,
    spl::MetropolisHastings,
    ::Integer,
    θ_prev::Transition;
    kwargs...,
)
    # Generate a new proposal.
    θ = propose(spl, model, θ_prev)

    # Calculate the log acceptance probability.
    α = ℓπ(model, θ) - ℓπ(model, θ_prev) + q(spl, θ_prev, θ) - q(spl, θ, θ_prev)

    # Decide whether to return the previous θ or the new one.
    if log(rand(rng)) < min(α, 0.0)
        return θ
    else
        return θ_prev
    end
end
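To sanity-check the two methods, you can call them by hand. This is just an illustrative sketch; `demo_model`, `demo_spl`, and the seed are arbitrary names and values, not part of the interface:

```julia
# Hand-run the two step! methods (AbstractMCMC's sample normally drives this loop).
demo_model = DensityModel(x -> logpdf(Normal(0, 1), x))   # arbitrary target density
demo_spl = MetropolisHastings(0.0)                        # univariate sampler
rng = Random.MersenneTwister(1234)

t0 = AbstractMCMC.step!(rng, demo_model, demo_spl, 1, nothing)  # initial transition
t1 = AbstractMCMC.step!(rng, demo_model, demo_spl, 1, t0)       # one MH update
```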
Chains
In the default implementation, `sample` just returns a vector of all transitions. If instead you would like to obtain a `Chains` object (e.g., to simplify downstream analysis), you have to implement the `bundle_samples` function as well. It accepts the vector of transitions and returns a collection of samples -- here, the `Chains` type from MCMCChains. Fortunately, our `Transition` is incredibly simple, and we only need to build a little bit of functionality to accept custom parameter names passed in by the user.
# A basic chains constructor that works with the Transition struct we defined.
function AbstractMCMC.bundle_samples(
    rng::AbstractRNG,
    ℓ::DensityModel,
    s::MetropolisHastings,
    N::Integer,
    ts::Vector{<:Transition},
    chain_type::Type{Any};
    param_names=missing,
    kwargs...,
)
    # Turn all the transitions into a matrix with one row per sample and
    # one column per parameter, plus a final column for the log density.
    vals = copy(reduce(hcat, [vcat(t.θ, t.lp) for t in ts])')

    # Check if we received any parameter names.
    if ismissing(param_names)
        param_names = ["Parameter $i" for i in 1:(size(vals, 2) - 1)]
    end

    # Add the log density field to the parameter names.
    push!(param_names, "lp")

    # Bundle everything up and return a Chains struct.
    return Chains(vals, param_names, (internals=["lp"],))
end
All done!
You can even implement different output formats by implementing `bundle_samples` for different `chain_type`s, which can be provided as a keyword argument to `sample`. By default, `sample` uses `chain_type = Any`.
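As a hedged sketch of what a second method might look like, here is one that returns a plain vector of named tuples when the caller passes `chain_type = Vector{NamedTuple}`. The choice of `Vector{NamedTuple}` and the field names `θ` and `lp` are illustrative, not prescribed by the interface; the signature simply mirrors the one above:

```julia
# Hypothetical alternative output format: a vector of named tuples.
function AbstractMCMC.bundle_samples(
    rng::AbstractRNG,
    ℓ::DensityModel,
    s::MetropolisHastings,
    N::Integer,
    ts::Vector{<:Transition},
    chain_type::Type{Vector{NamedTuple}};
    kwargs...,
)
    return [(θ = t.θ, lp = t.lp) for t in ts]
end
```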
Testing the implementation
Now that we have all the pieces, we should test the implementation by defining a model for the mean and standard deviation parameters of a Normal distribution. We can do this by constructing a target density function, providing a sample of data, and then running the sampler with `sample`.
# Generate a set of data from the posterior we want to estimate.
data = rand(Normal(5, 3), 30)
# Define the components of a basic model.
insupport(θ) = θ[2] >= 0
dist(θ) = Normal(θ[1], θ[2])
density(θ) = insupport(θ) ? sum(logpdf.(dist(θ), data)) : -Inf
# Construct a DensityModel.
model = DensityModel(density)
# Set up our sampler with initial parameters.
spl = MetropolisHastings([0.0, 0.0])
# Sample from the posterior.
chain = sample(model, spl, 100000; param_names=["μ", "σ"])
If all the interface functions have been extended properly, you should get an output from `display(chain)` that looks something like this:
Object of type Chains, with data of type 100000×3×1 Array{Float64,3}
Iterations = 1:100000
Thinning interval = 1
Chains = 1
Samples per chain = 100000
internals = lp
parameters = μ, σ
2-element Array{ChainDataFrame,1}
Summary Statistics
│ Row │ parameters │ mean │ std │ naive_se │ mcse │ ess │ r_hat │
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Any │ Any │
├─────┼────────────┼─────────┼──────────┼────────────┼────────────┼─────────┼─────────┤
│ 1 │ μ │ 5.33157 │ 0.854193 │ 0.0027012 │ 0.00893069 │ 8344.75 │ 1.00009 │
│ 2 │ σ │ 4.54992 │ 0.632916 │ 0.00200146 │ 0.00534942 │ 14260.8 │ 1.00005 │
Quantiles
│ Row │ parameters │ 2.5% │ 25.0% │ 50.0% │ 75.0% │ 97.5% │
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
├─────┼────────────┼─────────┼─────────┼─────────┼─────────┼─────────┤
│ 1 │ μ │ 3.6595 │ 4.77754 │ 5.33182 │ 5.89509 │ 6.99651 │
│ 2 │ σ │ 3.5097 │ 4.09732 │ 4.47805 │ 4.93094 │ 5.96821 │
It looks like we're reasonably close to our true parameters of `Normal(5, 3)`, though with fairly high variance due to the small data set of only 30 observations.
Conclusion
We've seen how to implement the sampling interface for general projects. Turing's interface methods are ever-evolving, so please open an issue at AbstractMCMC with feature requests or problems.