Implementing samplers

In this tutorial, we’ll go step by step through how to implement a “simple” sampler in AbstractMCMC.jl in such a way that it can easily be applied to Turing.jl models.

In particular, we’re going to implement a version of Metropolis-adjusted Langevin (MALA).

Note that we will implement this sampler in the AbstractMCMC.jl framework, completely “ignoring” Turing.jl until the very end of the tutorial, at which point we’ll use a single line of code to make the resulting sampler available to Turing.jl. This is to drive home the point that you can implement samplers in a way that is accessible to all of Turing.jl’s users without having to use Turing.jl yourself.

Quick overview of MALA

We can view MALA as a single step of the leapfrog integrator with resampling of momentum \(p\) at every step.[1] To make that statement a bit more concrete, we first define the extended target \(\bar{\gamma}(x, p)\) as

\[\begin{equation*} \log \bar{\gamma}(x, p) \propto \log \gamma(x) + \log \gamma_{\mathcal{N}(0, M)}(p) \end{equation*}\]

where \(\gamma_{\mathcal{N}(0, M)}\) denotes the density for a zero-centered Gaussian with covariance matrix \(M\). We then consider targeting this joint distribution over both \(x\) and \(p\) as follows. First we define the map

\[\begin{equation*} \begin{split} L_{\epsilon}: \quad & \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}^d \times \mathbb{R}^d \\ & (x, p) \mapsto (\tilde{x}, \tilde{p}) := L_{\epsilon}(x, p) \end{split} \end{equation*}\]

as

\[\begin{equation*} \begin{split} p_{1 / 2} &:= p + \frac{\epsilon}{2} \nabla \log \gamma(x) \\ \tilde{x} &:= x + \epsilon M^{-1} p_{1 /2 } \\ p_1 &:= p_{1 / 2} + \frac{\epsilon}{2} \nabla \log \gamma(\tilde{x}) \\ \tilde{p} &:= - p_1 \end{split} \end{equation*}\]
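A minimal standalone sketch of \(L_{\epsilon}\) in plain Julia (standard library only) may help make the map concrete. The function name `leapfrog_map`, the example mean, and the example covariance matrix are our own choices for illustration. Because of the final momentum flip, applying \(L_{\epsilon}\) twice returns the starting point, i.e. \(L_{\epsilon}\) is an involution. Together with the fact that it preserves volume, this property is what lets the simple acceptance ratio in the MALA kernel below target the correct distribution.

```julia
using LinearAlgebra

# Hypothetical standalone version of the map L_ϵ defined above.
# `∇logγ` is any function returning the gradient of log γ at a point.
function leapfrog_map(∇logγ, x, p, ϵ, M)
    p_half = p + (ϵ / 2) * ∇logγ(x)    # p_{1/2}
    x̃ = x + ϵ * (M \ p_half)           # x̃
    p_1 = p_half + (ϵ / 2) * ∇logγ(x̃)  # p_1
    return x̃, -p_1                      # (x̃, p̃) with p̃ = -p_1
end

# Example target (our choice): isotropic Gaussian with mean μ, so ∇ log γ(x) = -(x - μ).
μ = [-5.0, 0.0, 5.0]
grad_logγ(x) = -(x - μ)
# Example (diagonal) covariance matrix M for the momentum.
M = Diagonal([1.0, 2.0, 0.5])

x, p = [0.3, -1.2, 2.0], [0.5, 0.1, -0.7]
x̃, p̃ = leapfrog_map(grad_logγ, x, p, 0.4, M)
# Applying the map a second time should recover the starting point (x, p).
x2, p2 = leapfrog_map(grad_logγ, x̃, p̃, 0.4, M)
```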

The map \(L_{\epsilon}\) might be familiar to some readers as a single step of the leapfrog integrator. We then define the MALA kernel as follows: given the current iterate \(x_i\), we sample the next iterate \(x_{i + 1}\) as

\[\begin{equation*} \begin{split} p &\sim \mathcal{N}(0, M) \\ (\tilde{x}, \tilde{p}) &:= L_{\epsilon}(x_i, p) \\ \alpha &:= \min \left\{ 1, \frac{\bar{\gamma}(\tilde{x}, \tilde{p})}{\bar{\gamma}(x_i, p)} \right\} \\ x_{i + 1} &:= \begin{cases} \tilde{x} \quad & \text{ with prob. } \alpha \\ x_i \quad & \text{ with prob. } 1 - \alpha \end{cases} \end{split} \end{equation*}\]

i.e. we accept the proposal \(\tilde{x}\) with probability \(\alpha\) and reject it, thus sticking with our current iterate, with probability \(1 - \alpha\).
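Expanding the Gaussian momentum density makes the acceptance ratio more explicit. Since \(\log \gamma_{\mathcal{N}(0, M)}(p) = -\tfrac{1}{2} p^\top M^{-1} p + \text{const}\), and the constant cancels in the ratio, we get (a short derivation added here for illustration)

\[\begin{equation*} \log \alpha = \min \left\{ 0, \; \log \gamma(\tilde{x}) - \log \gamma(x_i) - \tfrac{1}{2} \tilde{p}^\top M^{-1} \tilde{p} + \tfrac{1}{2} p^\top M^{-1} p \right\} \end{equation*}\]

Working in log space like this avoids overflow and underflow. Below is a minimal standard-library sketch of the accept/reject step. `accept_proposal` is a hypothetical helper name, not part of AbstractMCMC.jl.

```julia
using Random

# Hypothetical helper: Metropolis-Hastings accept/reject in log space, where
# `logp` = log γ̄(xᵢ, p) and `logp̃` = log γ̄(x̃, p̃).
function accept_proposal(rng::Random.AbstractRNG, logp::Real, logp̃::Real)
    logα = min(0.0, logp̃ - logp)   # log of min{1, γ̄(x̃, p̃) / γ̄(xᵢ, p)}
    # u ~ Uniform[0, 1) satisfies log(u) < log(α) with probability α.
    return log(rand(rng)) < logα
end

rng = Random.Xoshiro(42)
# A proposal with higher extended-target density is always accepted.
always_accepted = all(accept_proposal(rng, -3.0, -1.0) for _ in 1:1_000)
# A proposal whose density is lower by a factor of 2 is accepted about half the time.
rate = count(accept_proposal(rng, 0.0, -log(2)) for _ in 1:100_000) / 100_000
# A proposal with zero density (log-density -Inf, e.g. outside the support of
# the target) is never accepted, since no number compares less than -Inf.
never_accepted = !any(accept_proposal(rng, 0.0, -Inf) for _ in 1:1_000)
```

The `step` method implemented later in this tutorial does the same comparison, `log(rand(rng)) < logα`, without the `min`. This does not change the outcome, because \(\log u < 0\) always holds.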

What we need from a model: LogDensityProblems.jl

There are a few things we need from the “target” / “model” / density that we want to sample from:

  1. We need access to log-density evaluations \(\log \gamma(x)\) so we can compute the acceptance ratio involving \(\log \bar{\gamma}(x, p)\).
  2. We need access to log-density gradients \(\nabla \log \gamma(x)\) so we can compute the leapfrog step \(L_{\epsilon}(x, p)\).
  3. We also need access to the “size” of the model so we can determine the size of \(M\).

Luckily for us, there is a package called LogDensityProblems.jl which provides an interface for exactly this!

To demonstrate how one can implement the “LogDensityProblems.jl interface”,[2] we will use a simple Gaussian model as an example:

using LogDensityProblems: LogDensityProblems;

# Let's define some type that represents the model.
struct IsotropicNormalModel{M<:AbstractVector{<:Real}}
    "mean of the isotropic Gaussian"
    mean::M
end

# Specifies what input length the model expects.
LogDensityProblems.dimension(model::IsotropicNormalModel) = length(model.mean)
# Implementation of the log-density evaluation of the model.
function LogDensityProblems.logdensity(model::IsotropicNormalModel, x::AbstractVector{<:Real})
    return - sum(abs2, x .- model.mean) / 2
end

This gives us everything our MALA sampler needs except the gradient \(\nabla \log \gamma(x)\). For that, LogDensityProblems.jl provides the method LogDensityProblems.logdensity_and_gradient. It should return a 2-tuple whose first entry is the log-density \(\log \gamma(x)\) and whose second entry is the gradient \(\nabla \log \gamma(x)\).

There are two ways to “implement” this method: a) we implement it by hand, which is feasible in the case of our IsotropicNormalModel, or b) we defer the implementation to an automatic differentiation backend.

To implement it by hand we can simply do

# Tell LogDensityProblems.jl that first-order, i.e. gradient information, is available.
LogDensityProblems.capabilities(model::IsotropicNormalModel) = LogDensityProblems.LogDensityOrder{1}()

# Implement `logdensity_and_gradient`.
function LogDensityProblems.logdensity_and_gradient(model::IsotropicNormalModel, x)
    logγ_x = LogDensityProblems.logdensity(model, x)
    # For the isotropic Gaussian, ∇ log γ(x) = -(x - μ).
    ∇logγ_x = -(x .- model.mean)
    return logγ_x, ∇logγ_x
end
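Hand-derived gradients are easy to get wrong, so it is worth checking them numerically. Here is a minimal sketch using central finite differences, in plain Julia without any package. The functions below are hypothetical standalone copies of the model's log-density and gradient, not LogDensityProblems.jl methods.

```julia
# Standalone copies (hypothetical names) of the isotropic Gaussian log-density
# and its hand-derived gradient, with mean `μ`.
logγ(μ, x) = -sum(abs2, x .- μ) / 2
∇logγ(μ, x) = -(x .- μ)

# Central finite-difference approximation of the gradient of `f` at `x`.
function finite_difference_gradient(f, x::AbstractVector{<:Real}; h=1e-6)
    g = similar(x, Float64)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

μ = [-5.0, 0.0, 5.0]
x = [0.3, -1.2, 2.0]
g_fd = finite_difference_gradient(x -> logγ(μ, x), x)
g_exact = ∇logγ(μ, x)
```

If the gradient were wrong, for instance `-x .* (x - μ)`, the two vectors would disagree.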

Let’s just try it out:

# Instantiate the problem.
model = IsotropicNormalModel([-5., 0., 5.])
# Create some example input that we can test on.
x_example = randn(LogDensityProblems.dimension(model))
# Evaluate!
LogDensityProblems.logdensity(model, x_example)
-21.786101664128765

To defer it to an automatic differentiation backend, we can do

# Tell LogDensityProblems.jl we only have access to 0-th order information.
LogDensityProblems.capabilities(model::IsotropicNormalModel) = LogDensityProblems.LogDensityOrder{0}()

# Use `LogDensityProblemsAD`'s `ADgradient` in combination with some AD backend to implement `logdensity_and_gradient`.
using LogDensityProblemsAD, ADTypes, ForwardDiff
model_with_grad = ADgradient(AutoForwardDiff(), model)
LogDensityProblems.logdensity(model_with_grad, x_example)
-21.786101664128765

We’ll continue with the second approach in this tutorial, since this is typically what one does in practice: there are better hobbies to spend time on than deriving gradients by hand.

At this point, one might wonder how we’re going to tie this back to Turing.jl in the end. When working with inference methods that only require log-density evaluations and / or higher-order (e.g. gradient) information of the log-density, Turing.jl converts the user-provided Model into an object implementing the above methods for LogDensityProblems.jl. As a result, most samplers provided by Turing.jl are actually implemented to work with LogDensityProblems.jl, which lets them be used both within and outside of Turing.jl! Moreover, similar conversions exist for Stan through BridgeStan and StanLogDensityProblems.jl. This means that a sampler supporting the LogDensityProblems.jl interface can easily be used on both Turing.jl and Stan models (in addition to user-provided models such as our IsotropicNormalModel above)!

Anyways, let’s move on to actually implementing the sampler.

Implementing MALA in AbstractMCMC.jl

Now that we’ve established that a model implementing the LogDensityProblems.jl interface provides us with all the information we need from \(\log \gamma(x)\), we can address the question: given an object that implements the LogDensityProblems.jl interface, how can we define a sampler for it?

We’re going to do this by making our sampler a sub-type of AbstractMCMC.AbstractSampler in addition to implementing a few methods from AbstractMCMC.jl. Why? Because it gets us a lot of functionality for free, as we will see later.

Moreover, AbstractMCMC.jl provides a very natural interface for MCMC algorithms.

First, we’ll define our MALA type:

using AbstractMCMC

struct MALA{T,A} <: AbstractMCMC.AbstractSampler
    "stepsize used in the leapfrog step"
    ϵ_init::T
    "covariance matrix used for the momentum"
    M_init::A
end

Notice how we’ve added the suffix _init to both the stepsize and the covariance matrix. We’ve done this because an AbstractMCMC.AbstractSampler should be immutable. Of course, in many scenarios we might want the stepsize and / or the covariance matrix to vary between iterations. For example, during the burn-in / adaptation phase of the sampling process we might want to adjust these parameters using statistics computed from the initial iterations. But information that can change between iterations should not go in the sampler itself! Instead, it should go in the sampler state.

The sampler state should at the very least contain all the necessary information to perform the next MCMC iteration, but usually contains further information, e.g. quantities and statistics useful for evaluating whether the sampler has converged.

We will use the following sampler state for our MALA sampler:

struct MALAState{A<:AbstractVector{<:Real}}
    "current position"
    x::A
end

This might seem redundant: we’re defining a type MALAState that only contains a simple vector of reals. In this particular case, we could indeed have dropped it and simply used an AbstractVector{<:Real} as our sampler state. Typically, though, one wants to include other quantities in the sampler state. For example, if we wanted to adapt the parameters of our MALA, e.g. alter the stepsize depending on acceptance rates, we would also need to put ϵ in the state. For now, we’ll keep things simple.
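Here is a hypothetical sketch of what such a state could look like. It pairs the state with a simple stochastic-approximation (Robbins-Monro style) update of \(\log \epsilon\) towards a target acceptance rate. The type name `AdaptiveMALAState`, the update rule, its decay exponent `κ`, and the default target of 0.574 (the commonly cited asymptotically optimal acceptance rate for MALA) are our assumptions. They are not part of this tutorial's sampler or of AbstractMCMC.jl.

```julia
# Hypothetical state for an adaptive variant of MALA: the stepsize now lives in
# the state, since it changes between iterations, while the sampler stays immutable.
struct AdaptiveMALAState{A<:AbstractVector{<:Real},T<:Real}
    "current position"
    x::A
    "current stepsize"
    ϵ::T
    "number of adaptation steps taken so far"
    n::Int
end

# One Robbins-Monro style update of log ϵ: increase ϵ after an acceptance and
# decrease it after a rejection, with an adaptation rate γₙ = n^(-κ) that decays.
function adapt_stepsize(state::AdaptiveMALAState, accepted::Bool; target=0.574, κ=0.6)
    n = state.n + 1
    γ = n^(-κ)
    logϵ = log(state.ϵ) + γ * (accepted - target)
    return AdaptiveMALAState(state.x, exp(logϵ), n)
end

s0 = AdaptiveMALAState([0.0, 0.0], 1.0, 0)
s_acc = adapt_stepsize(s0, true)    # stepsize grows after an acceptance
s_rej = adapt_stepsize(s0, false)   # stepsize shrinks after a rejection
```

In practice, such adaptation is usually confined to the burn-in phase and ϵ is then frozen, because adapting indefinitely can compromise the chain's convergence guarantees.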

Moreover, we also want a sample type, which is meant for “public consumption”, i.e. for the end-user. It generally contains a subset of the information present in the state. In a scenario as simple as this one, it likewise only contains an AbstractVector{<:Real}:

struct MALASample{A<:AbstractVector{<:Real}}
    "current position"
    x::A
end

We currently have three things:

  1. An AbstractMCMC.AbstractSampler implementation called MALA.
  2. A state MALAState for our sampler MALA.
  3. A sample MALASample for our sampler MALA.

That means that we’re ready to implement the only thing that really matters: AbstractMCMC.step.

AbstractMCMC.step defines the MCMC iteration of our MALA given the current MALAState. Specifically, the signature of the function is as follows:

function AbstractMCMC.step(
    # The RNG to ensure reproducibility.
    rng::Random.AbstractRNG,
    # The model that defines our target.
    model::AbstractMCMC.AbstractModel,
    # The sampler for which we're taking a `step`.
    sampler::AbstractMCMC.AbstractSampler,
    # The current sampler `state`.
    state;
    # Additional keyword arguments that we may or may not need.
    kwargs...
)

Moreover, there is a specific AbstractMCMC.AbstractModel subtype, AbstractMCMC.LogDensityModel, which is used to indicate that the provided model implements the LogDensityProblems.jl interface.

As discussed earlier, we’re indeed going to work with types that support the LogDensityProblems.jl interface, so we’ll define AbstractMCMC.step for such an AbstractMCMC.LogDensityModel.

Note that AbstractMCMC.LogDensityModel has no other purpose; it has a single field called logdensity, and it does nothing else. But by wrapping the model in AbstractMCMC.LogDensityModel, it allows samplers that want to work with LogDensityProblems.jl to define their AbstractMCMC.step on this type without running into method ambiguities.
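The following self-contained example illustrates the kind of method ambiguity the wrapper avoids. All names below (`AbstractModel`, `AbstractSampler`, `SomePPLModel`, `MySampler`, `Wrapper`, `mystep`, `mystep_naive`) are hypothetical stand-ins, not AbstractMCMC.jl types. Suppose one package defines a method for its own model and any sampler, and another defines a method for any model and its own sampler. Julia then cannot decide which one to call for that particular combination. Dispatching on a dedicated wrapper type sidesteps this.

```julia
abstract type AbstractModel end
abstract type AbstractSampler end

struct SomePPLModel <: AbstractModel end   # a model type owned by some PPL
struct MySampler <: AbstractSampler end    # our sampler
struct Wrapper{L} <: AbstractModel         # plays the role of `LogDensityModel`
    logdensity::L
end

# The PPL implements `mystep` for its model and *any* sampler.
mystep(::SomePPLModel, ::AbstractSampler) = "PPL's method"

# A sampler written against "any model" would clash with such a method:
mystep_naive(::SomePPLModel, ::AbstractSampler) = "PPL's method"
mystep_naive(::AbstractModel, ::MySampler) = "sampler's method"

# Dispatching on the wrapper instead is unambiguous.
mystep(model::Wrapper, ::MySampler) = "sampler's method on $(typeof(model.logdensity))"

# Neither `mystep_naive` method is more specific, so this call is ambiguous.
ambiguous = try
    mystep_naive(SomePPLModel(), MySampler())
    false
catch err
    err isa MethodError
end
```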

All in all, that means that the signature for our AbstractMCMC.step is going to be the following:

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    # `LogDensityModel` so we know we're working with LogDensityProblems.jl model.
    model::AbstractMCMC.LogDensityModel,
    # Our sampler.
    sampler::MALA,
    # Our sampler state.
    state::MALAState;
    kwargs...
)

Great! Now let’s actually implement the full AbstractMCMC.step for our MALA.

Let’s remind ourselves what we’re going to do:

  1. Sample a new momentum \(p\).
  2. Compute the log-density of the extended target \(\log \bar{\gamma}(x, p)\).
  3. Take a single leapfrog step \((\tilde{x}, \tilde{p}) = L_{\epsilon}(x, p)\).
  4. Accept or reject the proposed \((\tilde{x}, \tilde{p})\).

All in all, this results in the following:

using Random: Random
using Distributions  # so we get the `MvNormal`

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model_wrapper::AbstractMCMC.LogDensityModel,
    sampler::MALA,
    state::MALAState;
    kwargs...
)
    # Extract the wrapped model which implements LogDensityProblems.jl.
    model = model_wrapper.logdensity
    # Let's just extract the sampler parameters to make our lives easier.
    ϵ = sampler.ϵ_init
    M = sampler.M_init
    # Extract the current parameters.
    x = state.x
    # Sample the momentum.
    p_dist = MvNormal(zeros(LogDensityProblems.dimension(model)), M)
    p = rand(rng, p_dist)
    # Propose using a single leapfrog step.
    x̃, p̃ = leapfrog_step(model, x, p, ϵ, M)
    # Accept or reject proposal.
    logp = LogDensityProblems.logdensity(model, x) + logpdf(p_dist, p)
    logp̃ = LogDensityProblems.logdensity(model, x̃) + logpdf(p_dist, p̃)
    logα = logp̃ - logp
    state_new = if log(rand(rng)) < logα
        # Accept.
        MALAState(x̃)
    else
        # Reject.
        MALAState(x)
    end
    # Return the "sample" and the sampler state.
    return MALASample(state_new.x), state_new
end

Fairly straightforward.

Of course, we haven’t defined the leapfrog_step method yet, so let’s do that:

function leapfrog_step(model, x, p, ϵ, M)
    # Update momentum `p` using "position" `x`.
    ∇logγ_x = last(LogDensityProblems.logdensity_and_gradient(model, x))
    p1 = p + (ϵ / 2) .* ∇logγ_x
    # Update the "position" `x` using momentum `p1`.
    x̃ = x + ϵ .* (M \ p1)
    # Update momentum `p1` using position `x̃`.
    ∇logγ_x̃ = last(LogDensityProblems.logdensity_and_gradient(model, x̃))
    p2 = p1 + (ϵ / 2) .* ∇logγ_x̃
    # Flip momentum `p2`.
    p̃ = -p2
    return x̃, p̃
end
leapfrog_step (generic function with 1 method)
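To get a feel for why the stepsize matters, consider the simplest possible case: a one-dimensional standard Gaussian target with \(M = 1\) (our choice for illustration). Here \(\log \bar{\gamma}(x, p) = -(x^2 + p^2)/2\) up to a constant. The standalone sketch below mirrors leapfrog_step with scalar inputs (the function names are hypothetical) and computes \(\log \alpha\) for a single step from the same starting point with two stepsizes. In this example, the smaller stepsize yields an acceptance probability much closer to 1, because the leapfrog step then conserves \(\log \bar{\gamma}\) more accurately.

```julia
# Scalar leapfrog step for the standard Gaussian target with M = 1,
# so ∇ log γ(x) = -x.
function leapfrog_step_1d(x, p, ϵ)
    p1 = p + (ϵ / 2) * (-x)
    x̃ = x + ϵ * p1
    p2 = p1 + (ϵ / 2) * (-x̃)
    return x̃, -p2
end

# Log-density of the extended target, up to an additive constant.
logγ̄(x, p) = -(x^2 + p^2) / 2

# log α = min{0, log γ̄(x̃, p̃) - log γ̄(x, p)}
function log_acceptance(x, p, ϵ)
    x̃, p̃ = leapfrog_step_1d(x, p, ϵ)
    return min(0.0, logγ̄(x̃, p̃) - logγ̄(x, p))
end

logα_large = log_acceptance(0.0, 1.0, 1.0)
logα_small = log_acceptance(0.0, 1.0, 0.1)
```

From \((x, p) = (0, 1)\), \(\log \alpha\) is \(-0.125\) for \(\epsilon = 1\) and about \(-1.25 \times 10^{-5}\) for \(\epsilon = 0.1\).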

With all of this, we’re technically ready to sample!

using Random, LinearAlgebra

rng = Random.default_rng()
sampler = MALA(1, I)
state = MALAState(zeros(LogDensityProblems.dimension(model)))

x_next, state_next = AbstractMCMC.step(
    rng,
    AbstractMCMC.LogDensityModel(model),
    sampler,
    state
)
(MALASample{Vector{Float64}}([0.0, 0.0, 0.0]), MALAState{Vector{Float64}}([0.0, 0.0, 0.0]))

Great, it works!

I promised that implementing AbstractMCMC.step would get us quite a lot of functionality for free, and indeed we can now simply call sample to perform standard MCMC sampling:

# Perform 10 000 iterations with our `MALA` sampler.
samples = sample(model_with_grad, sampler, 10_000; initial_state=state, progress=false)
# Concatenate into a matrix.
samples_matrix = stack(sample -> sample.x, samples)
3×10000 Matrix{Float64}:
 -2.37239    -3.81434  -4.58395   -5.13899   …  -4.86871  -4.32474   -3.94319
 -0.0508748   1.25924   0.540633   0.633697      0.50242  -0.699573   0.59933
  2.01533     4.79048   5.43006    5.72653       6.21491   6.2324     4.17817
# Compute the marginal means and standard deviations.
hcat(mean(samples_matrix; dims=2), std(samples_matrix; dims=2))
3×2 Matrix{Float64}:
 -4.98192     1.00236
 -0.00447641  1.00292
  5.00118     0.997667
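A rejected proposal leaves the position unchanged. One rough diagnostic that is not computed above is therefore the empirical acceptance rate: the fraction of consecutive samples whose positions differ. Below is a minimal sketch, assuming the samples are available as a vector of position vectors (e.g. `[s.x for s in samples]`). The helper name is hypothetical.

```julia
# Fraction of transitions in which the position changed, i.e. the proposal was
# accepted. (An accepted proposal identical to the current point would be missed,
# but for continuous targets this essentially never happens.)
function empirical_acceptance_rate(xs::AbstractVector{<:AbstractVector})
    n_moves = count(i -> xs[i] != xs[i - 1], 2:length(xs))
    return n_moves / (length(xs) - 1)
end

# Toy chain: 4 transitions, of which 2 moved.
xs = [[0.0, 0.0], [0.5, 0.1], [0.5, 0.1], [0.2, -0.3], [0.2, -0.3]]
rate = empirical_acceptance_rate(xs)
```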

Let’s visualize the samples:

using StatsPlots
plot(transpose(samples_matrix[:, 1:10:end]), alpha=0.5, legend=false)

Look at that! Things are working; amazin’.

We can also exploit AbstractMCMC.jl’s parallel sampling capabilities:

# Run 4 separate chains for 10 000 iterations, using threads to parallelize.
num_chains = 4
samples = sample(
    model_with_grad,
    sampler,
    MCMCThreads(),
    10_000,
    num_chains;
    # Note we need to provide an initial state for every chain.
    initial_state=fill(state, num_chains),
    progress=false
)
samples_array = stack(map(Base.Fix1(stack, sample -> sample.x), samples))
3×10000×4 Array{Float64, 3}:
[:, :, 1] =
 -2.23671      -4.61949    -4.85619  …  -5.73974  -4.52126    -4.42248
  0.000882482   0.0295248   1.51016     -0.20179  -0.0421941   0.228151
  1.39739       2.59358     3.76588      3.48055   3.85493     3.4583

[:, :, 2] =
 -2.43789   -2.91765   -5.05763  -6.22407  …  -5.11149   -3.85071  -3.03706
  0.285881  -0.524364  -1.32892  -1.04882      0.302992   1.39856  -0.48604
  3.64679    4.89439    3.74543   5.03001      3.28274    4.87009   4.79716

[:, :, 3] =
 -2.0237    -3.45489   -3.97208  -4.0561   …  -2.59549  -4.13747  -4.46199
  0.773427   0.224835   0.15764   2.03776      1.23189   2.61799   2.49639
  3.35692    4.35895    4.64349   5.34479      5.11481   4.88716   6.34721

[:, :, 4] =
 -1.85668   -3.63119   -3.63119   -3.64468   …  -5.29111  -4.4255   -4.41179
 -0.692523  -0.169537  -0.169537  -0.910128     -1.03701   0.80509   0.501866
  3.81451    3.14381    3.14381    3.69604       3.54577   4.18928   5.43837
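With several chains, a quick sanity check is that their per-chain marginal means agree. Below is a sketch using the Statistics standard library. To keep it self-contained, it uses a random stand-in array with the same dimension × iteration × chain layout as samples_array; in practice, use the real samples_array instead.

```julia
using Random, Statistics

# Stand-in for `samples_array` (3 parameters × 1_000 iterations × 4 chains),
# drawn around the true means of the model.
rng = Random.Xoshiro(7)
samples_array = randn(rng, 3, 1_000, 4) .+ [-5.0, 0.0, 5.0]

# Per-chain marginal means: average over iterations only (dims=2) -> 3×4 matrix.
chain_means = dropdims(mean(samples_array; dims=2); dims=2)
# Spread of the per-chain means for each parameter; values that are large relative
# to the within-chain standard deviation would suggest the chains disagree.
between_chain_spread = vec(std(chain_means; dims=2))
```

The R-hat diagnostic (the rhat column in the Turing.jl summaries further below) formalizes this idea of comparing between-chain and within-chain variation.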

But having to provide AbstractMCMC.sample with an initial_state to get started is a bit annoying. We can avoid this by also defining an AbstractMCMC.step method without the state argument:

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model_wrapper::AbstractMCMC.LogDensityModel,
    ::MALA;
    # NOTE: No state provided!
    kwargs...
)
    model = model_wrapper.logdensity
    # Let's just create the initial state by sampling from a standard Gaussian.
    x = randn(rng, LogDensityProblems.dimension(model))

    return MALASample(x), MALAState(x)
end

Equipped with this, we no longer need to provide the initial_state everywhere:

samples = sample(model_with_grad, sampler, 10_000; progress=false)
samples_matrix = stack(sample -> sample.x, samples)
hcat(mean(samples_matrix; dims=2), std(samples_matrix; dims=2))
3×2 Matrix{Float64}:
 -5.01084      1.0006
  0.000195093  0.99597
  4.99267      1.00079
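The state-less step above always starts from a standard Gaussian draw. To let users choose the starting point instead, one could use a helper like the hypothetical `initial_position` below. Its keyword name `initial_params` is our assumption, chosen to match the common keyword argument of that name described in the AbstractMCMC.jl documentation.

```julia
using Random

# Hypothetical helper for the state-less `step`: use `initial_params` if it is
# provided, otherwise draw the starting point from a standard Gaussian as above.
function initial_position(rng::Random.AbstractRNG, dim::Integer; initial_params=nothing)
    initial_params === nothing && return randn(rng, dim)
    if length(initial_params) != dim
        throw(DimensionMismatch("expected $dim initial parameters, got $(length(initial_params))"))
    end
    return collect(float.(initial_params))
end

rng = Random.Xoshiro(1)
x_default = initial_position(rng, 3)                         # random draw
x_user = initial_position(rng, 3; initial_params=[1, 2, 3])  # [1.0, 2.0, 3.0]
```

Inside the state-less AbstractMCMC.step, one would then accept the keyword explicitly, e.g. `...; initial_params=nothing, kwargs...)`, and call `initial_position(rng, LogDensityProblems.dimension(model); initial_params=initial_params)`. Whether and how sample forwards such a keyword to step may depend on the AbstractMCMC.jl version, so treat this wiring as an assumption to check against its documentation.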

Using our sampler with Turing.jl

As we promised, all of this hassle of implementing our MALA sampler in a way that uses LogDensityProblems.jl and AbstractMCMC.jl gets us something more than just an “automatic” implementation of AbstractMCMC.sample.

It also enables use with Turing.jl through the externalsampler function, but we need to do one final thing first: we need to tell Turing.jl how to extract a vector of parameters from the “sample” returned in our implementation of AbstractMCMC.step. In our case, the “sample” is a MALASample, so we just need the following line:

# Load Turing.jl.
using Turing

# Overload the `getparams` method for our "sample" type, which is just a vector.
Turing.Inference.getparams(::Turing.Model, sample::MALASample) = sample.x

And with that, we’re good to go!

# Our previous model defined as a Turing.jl model.
@model mvnormal_model() = x ~ MvNormal([-5., 0., 5.], I)
# Instantiate our model.
turing_model = mvnormal_model()
# Call `sample` but now we're passing in a Turing.jl `model` and wrapping
# our `MALA` sampler in the `externalsampler` to tell Turing.jl that the sampler
# expects something that implements LogDensityProblems.jl.
chain = sample(turing_model, externalsampler(sampler), 10_000; progress=false)
Chains MCMC chain (10000×4×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
Wall duration     = 2.4 seconds
Compute duration  = 2.4 seconds
parameters        = x[1], x[2], x[3]
internals         = lp

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

        x[1]   -5.0092    1.0120    0.0187   2932.5831   4727.9927    1.0001   ⋯
        x[2]   -0.0246    1.0040    0.0174   3331.1056   5228.2384    1.0003   ⋯
        x[3]    4.9856    0.9966    0.0166   3620.4030   5887.4580    1.0000   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        x[1]   -7.0119   -5.6896   -5.0125   -4.3353   -2.9939
        x[2]   -1.9859   -0.6834   -0.0191    0.6528    1.9025
        x[3]    3.0361    4.3124    4.9806    5.6685    6.9305

Pretty neat, eh?

Models with constrained parameters

One thing we’ve sort of glossed over in all of the above is that MALA, at least how we’ve implemented it, requires \(x\) to live in \(\mathbb{R}^d\) for some \(d > 0\). If some of the parameters were in fact constrained, e.g. we were working with a Beta distribution which has support on the interval \((0, 1)\), not on \(\mathbb{R}^d\), we could easily end up outside of the valid range \((0, 1)\).
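The usual remedy is to let the sampler work in an unconstrained space instead. For a parameter on \((0, 1)\), for example, one can sample \(y = \operatorname{logit}(x) \in \mathbb{R}\) and add the log-absolute-Jacobian of the inverse map \(x = \operatorname{logistic}(y)\) to the log-density. Below is a hand-written, standard-library-only sketch for a Beta(3, 3) target. It is our own illustration of the idea with hypothetical helper names, not the code Turing.jl runs internally (Turing.jl delegates such transformations to Bijectors.jl).

```julia
# Unnormalized Beta(3, 3) log-density: 2 log(x) + 2 log(1 - x) on (0, 1), -Inf outside.
logbeta33(x) = 0 < x < 1 ? 2 * log(x) + 2 * log1p(-x) : -Inf

# Bijection between (0, 1) and ℝ, and its inverse.
logit(x) = log(x / (1 - x))
logistic(y) = 1 / (1 + exp(-y))

# Log-density of y = logit(x): log γ(logistic(y)) + log |d logistic(y) / dy|,
# using d logistic(y) / dy = logistic(y) * (1 - logistic(y)).
# (This naive version loses precision for very large |y|; a production
# implementation would work with y directly.)
function logdensity_unconstrained(y)
    x = logistic(y)
    return logbeta33(x) + log(x) + log1p(-x)
end

logbeta33(10.0)                  # -Inf: 10.0 is outside the support
logdensity_unconstrained(10.0)   # finite: logistic(10.0) lies inside (0, 1)
```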

@model beta_model() = x ~ Beta(3, 3)
turing_model = beta_model()
chain = sample(turing_model, externalsampler(sampler), 10_000; progress=false)
Chains MCMC chain (10000×2×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
Wall duration     = 1.41 seconds
Compute duration  = 1.41 seconds
parameters        = x
internals         = lp

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           x    0.4981    0.1902    0.0029   4295.6714   5170.1277    1.0000   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x    0.1414    0.3572    0.4972    0.6391    0.8479

Yep, that still works, but only because Turing.jl transforms the turing_model from constrained to unconstrained space, so that the sampler provided to externalsampler always works in unconstrained space! This is not always desirable, so we can turn this off:

chain = sample(turing_model, externalsampler(sampler; unconstrained=false), 10_000; progress=false)
Chains MCMC chain (10000×2×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
Wall duration     = 0.23 seconds
Compute duration  = 0.23 seconds
parameters        = x
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           x    0.5102    0.1587    0.0183    68.8852    52.8682    1.0419     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x    0.2309    0.3870    0.5116    0.6378    0.7617

The fun thing is that this still sort of works because

logpdf(Beta(3, 3), 10.0)
-Inf

and so the samples that fall outside of the range are always rejected. But do notice how much worse all the diagnostics are, e.g. ess_tail is very poor compared to when we use unconstrained=true. Moreover, in more complex cases this won’t just result in a “nice” -Inf log-density value, but instead will error:

@model function demo()
    σ² ~ truncated(Normal(), lower=0)
    # If we end up with negative values for `σ²`, the `Normal` will error.
    x ~ Normal(0, σ²)
end
sample(demo(), externalsampler(sampler; unconstrained=false), 10_000; progress=false)
DomainError: DomainError(-0.6987129066579123, "Normal: the condition σ >= zero(σ) is not satisfied.")
DomainError with -0.6987129066579123:
Normal: the condition σ >= zero(σ) is not satisfied.

As expected, we run into a DomainError at some point. If we instead set unconstrained=true, letting Turing.jl transform the model to an unconstrained form behind the scenes, sampling runs without error:

sample(demo(), externalsampler(sampler; unconstrained=true), 10_000; progress=false)
Chains MCMC chain (10000×3×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
Wall duration     = 1.58 seconds
Compute duration  = 1.58 seconds
parameters        = σ², x
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          σ²    0.4675    0.0000    0.0000        NaN        NaN       NaN     ⋯
           x    1.7741    0.0000    0.0000        NaN        NaN       NaN     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          σ²    0.4675    0.4675    0.4675    0.4675    0.4675
           x    1.7741    1.7741    1.7741    1.7741    1.7741

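No more errors! Note, however, that in this particular run the chain never left its initial point. The standard deviations are zero and the ESS and R-hat values are NaN, i.e. every proposal was rejected. This suggests that the fixed stepsize ϵ = 1 is too large for this model, at least around the point where this chain started. A smaller stepsize, or adapting it during burn-in, would likely help.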

Similarly, the automatic differentiation backend can be specified through the adtype keyword argument. For example, to use ReverseDiff.jl instead of the default ForwardDiff.jl:

using ReverseDiff: ReverseDiff
# Specify that we want to use `AutoReverseDiff`.
sample(
    demo(),
    externalsampler(sampler; unconstrained=true, adtype=AutoReverseDiff()),
    10_000;
    progress=false
)
Chains MCMC chain (10000×3×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
Wall duration     = 2.84 seconds
Compute duration  = 2.84 seconds
parameters        = σ², x
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          σ²    0.7070    0.5959    0.0730    65.8346    71.8390    1.0325     ⋯
           x    0.3571    1.0278    0.0849   114.0333   168.5240    1.0149     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          σ²    0.0952    0.2558    0.4932    1.0274    2.2595
           x   -1.8003   -0.1313    0.3461    0.7528    2.7833

Double-neat.

Summary

At this point, it’s worth reminding ourselves what we did and why we did it:

  1. We define our models in the LogDensityProblems.jl interface because it makes the sampler agnostic to how the underlying model is implemented.
  2. We implement our sampler in the AbstractMCMC.jl interface, which just means that our sampler is a subtype of AbstractMCMC.AbstractSampler and we implement the MCMC transition in AbstractMCMC.step.
  3. Together, points 1 and 2 make our sampler usable with a wide range of model implementations, among them models implemented in both Turing.jl and Stan. This gives you, the inference implementer, a large collection of models to test your inference method on. It also lets users of Turing.jl and Stan try out your inference method with minimal effort.

Footnotes

  1. We’re going with the leapfrog formulation because in a future version of this tutorial we’ll add a section extending this simple “baseline” MALA sampler to more complex versions. See issue #479 for progress on this.↩︎

  2. There is no such thing as a proper interface in Julia (at least not officially), and so we use the word “interface” here to mean a few minimal methods that need to be implemented by any type that we treat as a target model.↩︎