A Mini Turing Implementation I: Compiler

In this tutorial we develop a very simple probabilistic programming language. The implementation is deliberately similar to DynamicPPL, since we want to demonstrate some key ideas from Turing’s internal implementation.

To make things easy to understand and to implement, we restrict our language to a very simple subset of the language that Turing actually supports. Defining an accurate syntax description is not our goal here; instead, we give a simple example, and all similar programs should work.

Consider a probabilistic model defined by

\[ \begin{aligned} a &\sim \operatorname{Normal}(0.5, 1^2) \\ b &\sim \operatorname{Normal}(a, 2^2) \\ x &\sim \operatorname{Normal}(b, 0.5^2) \end{aligned} \]

We assume that x is data, i.e., an observed variable. In our small language this model will be defined as

@mini_model function m(x)
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end

Specifically, we demand that

  • all observed variables are arguments of the program,
  • the model definition does not contain any control flow,
  • all variables are scalars, and
  • the function returns nothing.
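
For instance, a model like the following (a hypothetical counterexample) falls outside this subset, since it contains control flow and a vector-valued variable:

# Hypothetical counterexample: NOT supported by our mini language,
# because it uses control flow (`for`) and a non-scalar variable `x`.
@mini_model function unsupported(x)
    s ~ InverseGamma(2, 3)
    for i in eachindex(x)
        x[i] ~ Normal(0, sqrt(s))
    end
    return nothing
end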

First, we import some required packages:

using MacroTools, Distributions, Random, AbstractMCMC, MCMCChains

Before getting to the actual “compiler”, we first build the data structure for the program trace. A program trace for a probabilistic programming language needs to record, at the very least, the values of the stochastic variables and their log-probabilities.

struct VarInfo{V,L}
    values::V
    logps::L
end

VarInfo() = VarInfo(Dict{Symbol,Float64}(), Dict{Symbol,Float64}())

function Base.setindex!(varinfo::VarInfo, (value, logp), var_id)
    varinfo.values[var_id] = value
    varinfo.logps[var_id] = logp
    return varinfo
end
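
As a quick illustration of how a trace is filled in (a hypothetical snippet, not part of the pipeline below):

# Store a value and its log-probability for variable `a` (illustrative only):
vi = VarInfo()
vi[:a] = (0.25, logpdf(Normal(0.5, 1), 0.25))
vi.values  # Dict(:a => 0.25)
vi.logps   # Dict(:a => -0.9502...)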

Internally, our probabilistic programming language works with two main functions:

  • assume for sampling unobserved variables and computing their log-probabilities, and
  • observe for computing log-probabilities of observed variables (but not sampling them).

For different inference algorithms we may have to use different sampling procedures and different log-probability computations. For instance, in some cases we might want to sample all variables from their prior distributions, while in other cases we might only want to compute the log-likelihood of the observations based on a given set of values for the unobserved variables. Thus, depending on the inference algorithm, we want to use different assume and observe implementations. We achieve this by passing the context information as a function argument to assume and observe.

Note: Although the context system in this tutorial is inspired by DynamicPPL, it is very simplistic. We expand this mini Turing example in the contexts tutorial with some more complexity, to illustrate how and why contexts are central to Turing’s design. For the full details one still needs to go to the actual source of DynamicPPL though.

Here we can see the implementation of a sampler that draws values of unobserved variables from the prior and computes the log-probability for every variable.

struct SamplingContext{S<:AbstractMCMC.AbstractSampler,R<:Random.AbstractRNG}
    rng::R
    sampler::S
end

struct PriorSampler <: AbstractMCMC.AbstractSampler end

function observe(context::SamplingContext, varinfo, dist, var_id, var_value)
    logp = logpdf(dist, var_value)
    varinfo[var_id] = (var_value, logp)
    return nothing
end

function assume(context::SamplingContext{PriorSampler}, varinfo, dist, var_id)
    sample = Random.rand(context.rng, dist)
    logp = logpdf(dist, sample)
    varinfo[var_id] = (sample, logp)
    return sample
end;
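
Before wiring these into the compiler, we can call them by hand to see what they do (an illustrative sketch; the variable names are arbitrary):

# Manual use of `assume` and `observe` with the prior sampler (illustrative):
vi = VarInfo()
ctx = SamplingContext(Random.default_rng(), PriorSampler())
a = assume(ctx, vi, Normal(0.5, 1), :a)   # draws `a` from its prior
observe(ctx, vi, Normal(a, 2), :x, 3.0)   # records the log-likelihood of the datum
vi.logps                                  # log-probabilities of `a` and `x`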

Next we define the “compiler” for our simple programming language. The term compiler is actually a bit misleading here since its only purpose is to transform the function definition in the @mini_model macro by

  • adding the context information (context) and the tracing data structure (varinfo) as additional arguments, and
  • replacing tildes with calls to assume and observe.

Afterwards, as usual, the Julia compiler will just-in-time compile the model function when it is called.

The manipulation of Julia expressions is an advanced part of the Julia language. The Julia documentation provides an introduction to, and further details on, this so-called metaprogramming.
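
The workhorse below is MacroTools.@capture, which pattern-matches an expression against a template and binds the underscore-suffixed placeholders; a small sketch of how it destructures a tilde expression:

# `@capture` returns `true` on a match and binds `var` and `dist`:
ex = :(a ~ Normal(0.5, 1))
if MacroTools.@capture(ex, var_ ~ dist_)
    @show var dist  # var = :a, dist = :(Normal(0.5, 1))
end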

macro mini_model(expr)
    return esc(mini_model(expr))
end

function mini_model(expr)
    # Split the function definition into a dictionary with its name, arguments, body etc.
    def = MacroTools.splitdef(expr)

    # Replace tildes in the function body with calls to `assume` or `observe`
    def[:body] = MacroTools.postwalk(def[:body]) do sub_expr
        if MacroTools.@capture(sub_expr, var_ ~ dist_)
            if var in def[:args]
                # If the variable is an argument of the model function, it is observed
                return :($(observe)(context, varinfo, $dist, $(Meta.quot(var)), $var))
            else
                # Otherwise it is unobserved
                return :($var = $(assume)(context, varinfo, $dist, $(Meta.quot(var))))
            end
        else
            return sub_expr
        end
    end

    # Add `context` and `varinfo` arguments to the model function
    def[:args] = vcat(:varinfo, :context, def[:args])

    # Reassemble the function definition from its name, arguments, body etc.
    return MacroTools.combinedef(def)
end;
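
We can check what the transformation produces by applying mini_model to a quoted function definition (a sketch; MacroTools.prettify merely cleans up generated line nodes, and the interpolated assume and observe function objects will appear inline):

# Inspect the transformed function definition (illustrative):
transformed = mini_model(:(function demo(x)
    a ~ Normal(0.5, 1)
    x ~ Normal(a, 0.5)
    return nothing
end))
MacroTools.prettify(transformed)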

For inference, we make use of the AbstractMCMC interface. It provides a default implementation of a sample function for sampling a Markov chain. The default implementation already supports, e.g., sampling multiple chains in parallel, thinning of samples, and discarding initial samples.

The AbstractMCMC interface requires us to at least

  • define a model that is a subtype of AbstractMCMC.AbstractModel,
  • define a sampler that is a subtype of AbstractMCMC.AbstractSampler, and
  • implement AbstractMCMC.step for our model and sampler.

Thus here we define a MiniModel type, in which we store the model function and the observed data.

struct MiniModel{F,D} <: AbstractMCMC.AbstractModel
    f::F
    data::D # a NamedTuple of all the data
end

In the Turing compiler, the model-specific DynamicPPL.Model is constructed automatically when calling the model function. But for the sake of simplicity here we construct the model manually.
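
Constructing the model by hand then simply means wrapping the model function and the observations, as we do for inference at the end of this tutorial:

# Wrap the model function `m` (defined below) and the observed data:
model = MiniModel(m, (x=3.0,))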

To illustrate probabilistic inference with our mini language, we implement an extremely simplistic random-walk Metropolis-Hastings sampler. We hard-code the proposal step as part of the sampler: the proposal adds a zero-mean normal increment with fixed standard deviation to the current value. The Metropolis-Hastings sampler in Turing is more flexible.

struct MHSampler{T<:Real} <: AbstractMCMC.AbstractSampler
    sigma::T
end

MHSampler() = MHSampler(1)

function assume(context::SamplingContext{<:MHSampler}, varinfo, dist, var_id)
    sampler = context.sampler
    old_value = varinfo.values[var_id]

    # Propose a random-walk step, i.e., sample a new value from a normal
    # distribution centered at the current value
    value = rand(context.rng, Normal(old_value, sampler.sigma))
    logp = Distributions.logpdf(dist, value)
    varinfo[var_id] = (value, logp)

    return value
end;

We need to define two step functions: one for the first step and one for all subsequent steps. In the first step we sample values from the prior distributions, and in the following steps we sample with the random-walk proposal. The two methods are distinguished by the arguments they take.

# The first step: Sampling from the prior distributions
function AbstractMCMC.step(
    rng::Random.AbstractRNG, model::MiniModel, sampler::MHSampler; kwargs...
)
    vi = VarInfo()
    ctx = SamplingContext(rng, PriorSampler())
    model.f(vi, ctx, values(model.data)...)
    return vi, vi
end

# The following steps: Sampling with random-walk proposal
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::MiniModel,
    sampler::MHSampler,
    prev_state::VarInfo; # is just the old trace
    kwargs...,
)
    vi = prev_state
    new_vi = deepcopy(vi)
    ctx = SamplingContext(rng, sampler)
    model.f(new_vi, ctx, values(model.data)...)

    # Compute log acceptance probability
    # Since the proposal is symmetric the computation can be simplified
    logα = sum(values(new_vi.logps)) - sum(values(vi.logps))

    # Accept proposal with computed acceptance probability
    if -randexp(rng) < logα
        return new_vi, new_vi
    else
        return prev_state, prev_state
    end
end;
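
The comparison -randexp(rng) < logα above is a numerically stable way of writing rand(rng) < α: if u ∼ Uniform(0, 1), then E = -log(u) follows an Exponential(1) distribution, and therefore

\[ -E < \log \alpha \iff \log u < \log \alpha \iff u < \alpha. \]

Working on the log scale avoids exponentiating potentially extreme log-probabilities.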

To make it easier to analyze the samples and to compare them with results from Turing, we additionally define a version of AbstractMCMC.bundle_samples for our model and sampler that returns an MCMCChains.Chains object of samples.

function AbstractMCMC.bundle_samples(
    samples, model::MiniModel, ::MHSampler, ::Any, ::Type{Chains}; kwargs...
)
    # We get a vector of traces
    values = [sample.values for sample in samples]
    params = [key for key in keys(values[1]) if key ∉ keys(model.data)]
    vals = reduce(hcat, [value[p] for value in values] for p in params)
    # Bundle everything up into a `Chains` object, for which MCMCChains provides analysis tooling
    chains = Chains(vals, params)
    return chains
end;

Let us check how our mini probabilistic programming language works. We define the probabilistic model:

@mini_model function m(x)
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end;

We perform inference with data x = 3.0:

sample(MiniModel(m, (x=3.0,)), MHSampler(), 1_000_000; chain_type=Chains, progress=false)
Chains MCMC chain (1000000×2×1 Array{Float64, 3}):

Iterations        = 1:1:1000000
Number of chains  = 1
Samples per chain = 1000000
parameters        = a, b

Summary Statistics
  parameters      mean       std      mcse      ess_bulk      ess_tail      rh ⋯
      Symbol   Float64   Float64   Float64       Float64       Float64   Float ⋯

           a    0.9743    0.8984    0.0032    80019.8038   122371.9588    1.00 ⋯
           b    2.8822    0.4880    0.0012   171497.9637   213751.8569    1.00 ⋯
                                                               2 columns omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           a   -0.7840    0.3672    0.9755    1.5776    2.7415
           b    1.9276    2.5526    2.8820    3.2125    3.8359

We compare these results with Turing.

using Turing
using PDMats

@model function turing_m(x)
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end

sample(turing_m(3.0), MH(ScalMat(2, 1.0)), 1_000_000, progress=false)
Chains MCMC chain (1000000×3×1 Array{Float64, 3}):

Iterations        = 1:1:1000000
Number of chains  = 1
Samples per chain = 1000000
Wall duration     = 31.21 seconds
Compute duration  = 31.21 seconds
parameters        = a, b
internals         = lp

Summary Statistics
  parameters      mean       std      mcse      ess_bulk      ess_tail      rh ⋯
      Symbol   Float64   Float64   Float64       Float64       Float64   Float ⋯

           a    0.9776    0.8985    0.0032    81117.2647   120572.6440    1.00 ⋯
           b    2.8798    0.4874    0.0012   171989.5545   210994.9131    1.00 ⋯
                                                               2 columns omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           a   -0.7854    0.3704    0.9766    1.5832    2.7445
           b    1.9252    2.5515    2.8794    3.2090    3.8343

As you can see, with our simple probabilistic programming language and custom samplers we obtain results similar to those from Turing.
