A Mini Turing Implementation I: Compiler
In this tutorial we develop a very simple probabilistic programming language. The implementation is similar to DynamicPPL. This is intentional, as we want to demonstrate some key ideas from Turing’s internal implementation.
To make things easy to understand and to implement, we restrict our language to a very simple subset of the language that Turing actually supports. Defining an accurate syntax description is not our goal here; instead, we give a simple example, and all similar programs should work.
Consider a probabilistic model defined by
\[ \begin{aligned} a &\sim \operatorname{Normal}(0.5, 1^2) \\ b &\sim \operatorname{Normal}(a, 2^2) \\ x &\sim \operatorname{Normal}(b, 0.5^2) \end{aligned} \]
We assume that x is data, i.e., an observed variable. In our small language this model will be defined as
@mini_model function m(x)
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end
Specifically, we demand that
- all observed variables are arguments of the program,
- the model definition does not contain any control flow,
- all variables are scalars, and
- the function returns nothing.
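Every tilde statement contributes one factor, so the log joint density of the model decomposes into one term per variable. Here \(\mathcal{N}(y; \mu, \sigma^2)\) denotes the normal density with mean \(\mu\) and variance \(\sigma^2\), evaluated at \(y\). The first two terms belong to the unobserved variables a and b, and the last one belongs to the observed variable x. The Metropolis-Hastings sampler later in this tutorial evaluates this sum by adding up the log-probabilities recorded in the program trace.

\[ \log p(a, b, x) = \log \mathcal{N}(a; 0.5, 1^2) + \log \mathcal{N}(b; a, 2^2) + \log \mathcal{N}(x; b, 0.5^2) \]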
First, we import some required packages:
using MacroTools, Distributions, Random, AbstractMCMC, MCMCChains
Before getting to the actual “compiler”, we first build the data structure for the program trace. A program trace for a probabilistic programming language needs to at least record the values of stochastic variables and their log-probabilities.
struct VarInfo{V,L}
    values::V
    logps::L
end

VarInfo() = VarInfo(Dict{Symbol,Float64}(), Dict{Symbol,Float64}())

function Base.setindex!(varinfo::VarInfo, (value, logp), var_id)
    varinfo.values[var_id] = value
    varinfo.logps[var_id] = logp
    return varinfo
end
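As a quick illustration, the sketch below fills a trace by hand for one assignment of the example model's variables. It then sums the recorded log-probabilities, which gives the log joint density of that assignment. The values 1.0, 2.0 and 3.0 are arbitrary. To keep the block runnable without packages, it re-declares VarInfo as above. It also uses a hypothetical helper, normal_logpdf(μ, σ, y), in place of logpdf(Normal(μ, σ), y) from Distributions.jl.

```julia
# Sketch: fill the `VarInfo` trace by hand for one assignment of the
# example model's variables (the values 1.0, 2.0, 3.0 are arbitrary)
struct VarInfo{V,L}
    values::V
    logps::L
end

VarInfo() = VarInfo(Dict{Symbol,Float64}(), Dict{Symbol,Float64}())

function Base.setindex!(varinfo::VarInfo, (value, logp), var_id)
    varinfo.values[var_id] = value
    varinfo.logps[var_id] = logp
    return varinfo
end

# Hypothetical helper standing in for `logpdf(Normal(μ, σ), y)` from Distributions.jl
normal_logpdf(μ, σ, y) = -0.5 * ((y - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

vi = VarInfo()
# `vi[:a] = (value, logp)` calls the `setindex!` method defined above
vi[:a] = (1.0, normal_logpdf(0.5, 1.0, 1.0))
vi[:b] = (2.0, normal_logpdf(vi.values[:a], 2.0, 2.0))
vi[:x] = (3.0, normal_logpdf(vi.values[:b], 0.5, 3.0))

# Summing the recorded log-probabilities gives the log joint density of this trace
logjoint = sum(values(vi.logps))
```

Storing values and log-probabilities in two dictionaries keyed by the same variable name means a sampler can read either one without touching the other.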
Internally, our probabilistic programming language works with two main functions:
- assume for sampling unobserved variables and computing their log-probabilities, and
- observe for computing log-probabilities of observed variables (but not sampling them).
For different inference algorithms we may have to use different sampling procedures and different log-probability computations. For instance, in some cases we might want to sample all variables from their prior distributions. In other cases we might only want to compute the log-likelihood of the observations based on a given set of values for the unobserved variables. Thus, depending on the inference algorithm, we want to use different assume and observe implementations. We achieve this by passing this context information as an additional function argument to assume and observe.
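The following sketch illustrates the idea in isolation. The same hand-written model function behaves differently depending on the type of context it receives, because assume dispatches on that type. The sketch is deliberately simpler than the interface used in this tutorial:
- there is no VarInfo; log-probabilities are returned and summed directly,
- Gauss and logdensity are hypothetical stand-ins for Normal and logpdf from Distributions.jl, and
- the context types PriorDraw and FixedValues are made up for this example.

```julia
using Random

# Hypothetical stand-in for `Normal(μ, σ)` from Distributions.jl
struct Gauss
    μ::Float64
    σ::Float64
end
logdensity(d::Gauss, y) = -0.5 * ((y - d.μ) / d.σ)^2 - log(d.σ) - 0.5 * log(2π)

# Hypothetical contexts: draw unobserved values from the prior,
# or read them from a fixed assignment
struct PriorDraw{R<:AbstractRNG}
    rng::R
end
struct FixedValues
    values::Dict{Symbol,Float64}
end

# `assume` returns the value of an unobserved variable and its log-probability;
# how the value is obtained depends only on the type of the context
function assume(ctx::PriorDraw, dist::Gauss, var_id::Symbol)
    value = dist.μ + dist.σ * randn(ctx.rng)
    return value, logdensity(dist, value)
end
function assume(ctx::FixedValues, dist::Gauss, var_id::Symbol)
    value = ctx.values[var_id]
    return value, logdensity(dist, value)
end

# `observe` only scores the data, whatever the context
observe(ctx, dist::Gauss, value) = logdensity(dist, value)

# The example model written out by hand; returns its log joint density
function model_logjoint(ctx, x)
    a, lp_a = assume(ctx, Gauss(0.5, 1.0), :a)
    b, lp_b = assume(ctx, Gauss(a, 2.0), :b)
    return lp_a + lp_b + observe(ctx, Gauss(b, 0.5), x)
end

# Evaluate the model at a = 1, b = 2 ...
lp_fixed = model_logjoint(FixedValues(Dict(:a => 1.0, :b => 2.0)), 3.0)
# ... or at values drawn from the prior
lp_prior = model_logjoint(PriorDraw(MersenneTwister(42)), 3.0)
```

Supporting a new inference strategy in this design only requires another assume method for a new context type. The model function itself stays unchanged.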
Note: Although the context system in this tutorial is inspired by DynamicPPL, it is very simplistic. We expand this mini Turing example in the contexts tutorial with some more complexity, to illustrate how and why contexts are central to Turing’s design. For the full details one still needs to go to the actual source of DynamicPPL though.
Here we can see the implementation of a sampler that draws values of unobserved variables from the prior and computes the log-probability for every variable.
struct SamplingContext{S<:AbstractMCMC.AbstractSampler,R<:Random.AbstractRNG}
    rng::R
    sampler::S
end

struct PriorSampler <: AbstractMCMC.AbstractSampler end

function observe(context::SamplingContext, varinfo, dist, var_id, var_value)
    logp = logpdf(dist, var_value)
    varinfo[var_id] = (var_value, logp)
    return nothing
end

function assume(context::SamplingContext{PriorSampler}, varinfo, dist, var_id)
    sample = Random.rand(context.rng, dist)
    logp = logpdf(dist, sample)
    varinfo[var_id] = (sample, logp)
    return sample
end;
Next we define the “compiler” for our simple programming language. The term compiler is actually a bit misleading here, since its only purpose is to transform the function definition in the @mini_model macro by
- adding the context information (context) and the tracing data structure (varinfo) as additional arguments, and
- replacing tildes with calls to assume and observe.
Afterwards, as usual, the Julia compiler just-in-time compiles the model function when it is called.
The manipulation of Julia expressions is an advanced part of the Julia language. The Julia documentation provides an introduction to and more details about this so-called metaprogramming.
macro mini_model(expr)
    return esc(mini_model(expr))
end

function mini_model(expr)
    # Split the function definition into a dictionary with its name, arguments, body etc.
    def = MacroTools.splitdef(expr)

    # Replace tildes in the function body with calls to `assume` or `observe`
    def[:body] = MacroTools.postwalk(def[:body]) do sub_expr
        if MacroTools.@capture(sub_expr, var_ ~ dist_)
            if var in def[:args]
                # If the variable is an argument of the model function, it is observed
                return :($(observe)(context, varinfo, $dist, $(Meta.quot(var)), $var))
            else
                # Otherwise it is unobserved
                return :($var = $(assume)(context, varinfo, $dist, $(Meta.quot(var))))
            end
        else
            return sub_expr
        end
    end

    # Add `context` and `varinfo` arguments to the model function
    def[:args] = vcat(:varinfo, :context, def[:args])

    # Reassemble the function definition from its name, arguments, body etc.
    return MacroTools.combinedef(def)
end;
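To see what the transformation does, the sketch below performs the same tilde rewrite without MacroTools. It uses a plain recursive walk over Base's Expr type and applies it to the body of the example model. This is a simplified illustration, not a replacement for mini_model: it only rewrites the body (not the function signature), and rewrite is a hypothetical helper name.

```julia
# Hypothetical helper: rewrite `lhs ~ rhs` into `assume`/`observe` calls,
# mirroring the `postwalk` in `mini_model` but using only Base
rewrite(ex, observed) = ex  # symbols, literals and line number nodes stay unchanged

function rewrite(ex::Expr, observed)
    # `a ~ Normal(0.5, 1)` parses as a call to the binary operator `~`
    if ex.head === :call && length(ex.args) == 3 && ex.args[1] === :~
        var, dist = ex.args[2], ex.args[3]
        if var in observed
            return :(observe(context, varinfo, $dist, $(QuoteNode(var)), $var))
        else
            return :($var = assume(context, varinfo, $dist, $(QuoteNode(var))))
        end
    end
    # Otherwise rebuild the expression from its rewritten sub-expressions
    return Expr(ex.head, map(arg -> rewrite(arg, observed), ex.args)...)
end

body = quote
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end

# `x` is the only argument of the model, hence the only observed variable
rewritten = rewrite(body, [:x])
println(rewritten)
```

The printed body should contain a = assume(...) and b = assume(...) for the unobserved variables and an observe(...) call for x.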
For inference, we make use of the AbstractMCMC interface. It provides a default implementation of a sample function for sampling a Markov chain. The default implementation already supports, for example, sampling multiple chains in parallel, thinning of samples, and discarding initial samples.
The AbstractMCMC interface requires us to at least
- define a model that is a subtype of AbstractMCMC.AbstractModel,
- define a sampler that is a subtype of AbstractMCMC.AbstractSampler, and
- implement AbstractMCMC.step for our model and sampler.
Thus, here we define a MiniModel model, which stores the model function and the observed data.
struct MiniModel{F,D} <: AbstractMCMC.AbstractModel
    f::F
    data::D # a NamedTuple of all the data
end
In the Turing compiler, the model-specific DynamicPPL.Model is constructed automatically when calling the model function. For the sake of simplicity, here we construct the model manually.
To illustrate probabilistic inference with our mini language we implement an extremely simplistic Random-Walk Metropolis-Hastings sampler. We hard-code the proposal step as part of the sampler and only allow normal distributions with zero mean and fixed standard deviation. The Metropolis-Hastings sampler in Turing is more flexible.
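For reference, below is the Metropolis-Hastings acceptance probability for a proposed value \(\theta'\) of the unobserved variables \(\theta = (a, b)\), drawn from a proposal density \(q(\theta' \mid \theta)\). For a symmetric random-walk proposal the \(q\) terms cancel, so only the difference of log joint densities remains. The last line justifies the test -randexp(rng) < logα used in the step function. If \(U\) is uniform on \((0, 1)\), then \(E = -\log U\) follows a standard exponential distribution. Since \(\log U < 0\) always holds, the \(\min\) with 0 can be dropped.

\[ \begin{aligned} \alpha(\theta, \theta') &= \min\left(1, \frac{p(\theta', x)\, q(\theta \mid \theta')}{p(\theta, x)\, q(\theta' \mid \theta)}\right) = \min\left(1, \frac{p(\theta', x)}{p(\theta, x)}\right) \quad \text{if } q(\theta' \mid \theta) = q(\theta \mid \theta'), \\ \log \alpha &= \min\bigl(0, \log p(\theta', x) - \log p(\theta, x)\bigr), \\ \text{accept} &\iff \log U < \log \alpha \iff -E < \log p(\theta', x) - \log p(\theta, x), \quad E = -\log U \sim \operatorname{Exp}(1). \end{aligned} \]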
struct MHSampler{T<:Real} <: AbstractMCMC.AbstractSampler
    sigma::T
end

MHSampler() = MHSampler(1)

function assume(context::SamplingContext{<:MHSampler}, varinfo, dist, var_id)
    sampler = context.sampler
    old_value = varinfo.values[var_id]

    # propose a random-walk step, i.e., add the current value to a random
    # value sampled from a Normal distribution centered at 0
    value = rand(context.rng, Normal(old_value, sampler.sigma))
    logp = Distributions.logpdf(dist, value)
    varinfo[var_id] = (value, logp)

    return value
end;
We need to define two step functions: one for the first step and one for the following steps. In the first step we sample values from the prior distributions, and in the following steps we sample with the random-walk proposal. Julia's multiple dispatch tells the two methods apart by their arguments: only the second one receives the previous state.
# The first step: Sampling from the prior distributions
function AbstractMCMC.step(
    rng::Random.AbstractRNG, model::MiniModel, sampler::MHSampler; kwargs...
)
    vi = VarInfo()
    ctx = SamplingContext(rng, PriorSampler())
    model.f(vi, ctx, values(model.data)...)
    return vi, vi
end

# The following steps: Sampling with random-walk proposal
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::MiniModel,
    sampler::MHSampler,
    prev_state::VarInfo; # is just the old trace
    kwargs...,
)
    vi = prev_state
    new_vi = deepcopy(vi)
    ctx = SamplingContext(rng, sampler)
    model.f(new_vi, ctx, values(model.data)...)

    # Compute log acceptance probability
    # Since the proposal is symmetric the computation can be simplified
    logα = sum(values(new_vi.logps)) - sum(values(vi.logps))

    # Accept proposal with computed acceptance probability
    if -randexp(rng) < logα
        return new_vi, new_vi
    else
        return prev_state, prev_state
    end
end;
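The logic of the two step methods can be reproduced without any packages. The sketch below is a hypothetical, hard-coded version for the example model:
- log_joint writes out the log joint density,
- the first state is drawn from the prior, and
- every later iteration makes one random-walk proposal with standard deviation 1 (the MHSampler() default) and applies the same -randexp(rng) < logα test.

The driver run_chain is a simplified stand-in for how a sampling loop uses the two step methods. It ignores thinning, multiple chains and progress logging, and it is not AbstractMCMC's actual code.

```julia
using Random, Statistics

# Hypothetical hard-coded log joint density of the example model
normal_logpdf(μ, σ, y) = -0.5 * ((y - μ) / σ)^2 - log(σ) - 0.5 * log(2π)
log_joint(a, b, x) = normal_logpdf(0.5, 1.0, a) + normal_logpdf(a, 2.0, b) + normal_logpdf(b, 0.5, x)

# First step: draw (a, b) from the prior
function first_step(rng)
    a = 0.5 + randn(rng)
    b = a + 2.0 * randn(rng)
    return (a, b)
end

# Following steps: random-walk proposal plus Metropolis-Hastings accept/reject
function next_step(rng, x, state; sigma=1.0)
    a, b = state
    a_new = a + sigma * randn(rng)
    b_new = b + sigma * randn(rng)
    logα = log_joint(a_new, b_new, x) - log_joint(a, b, x)
    return -randexp(rng) < logα ? (a_new, b_new) : state
end

# Simplified driver: one call to the first step, then repeated calls to the next step
function run_chain(rng, x, n)
    state = first_step(rng)
    samples = [state]
    for _ in 2:n
        state = next_step(rng, x, state)
        push!(samples, state)
    end
    return samples
end

samples = run_chain(MersenneTwister(1), 3.0, 200_000)
mean_a = mean(first.(samples))
mean_b = mean(last.(samples))
```

With enough iterations, the sample means of a and b should be close to the estimates reported for the mini language below.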
To make it easier to analyze the samples and compare them with results from Turing, we additionally define a version of AbstractMCMC.bundle_samples for our model and sampler that returns an MCMCChains.Chains object of samples.
function AbstractMCMC.bundle_samples(
    samples, model::MiniModel, ::MHSampler, ::Any, ::Type{Chains}; kwargs...
)
    # We get a vector of traces
    values = [sample.values for sample in samples]
    params = [key for key in keys(values[1]) if key ∉ keys(model.data)]
    vals = reduce(hcat, [value[p] for value in values] for p in params)
    # Compose the `Chains` data structure, for which MCMCChains provides analysis tools
    chains = Chains(vals, params)
    return chains
end;
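The core of bundle_samples is turning a vector of traces into a matrix with one row per iteration and one column per parameter. The sketch below repeats that step on three made-up traces. Each trace is represented as a plain Dict, like varinfo.values, and the observed variable x is filtered out as in the method above. MCMCChains is left out so the block runs with Base only.

```julia
# Three hypothetical traces, shaped like `varinfo.values` after three iterations
traces = [
    Dict(:a => 0.9, :b => 2.8, :x => 3.0),
    Dict(:a => 1.1, :b => 3.0, :x => 3.0),
    Dict(:a => 1.0, :b => 2.9, :x => 3.0),
]
data = (x=3.0,)

# Keep only the unobserved variables, i.e. the keys that are not part of the data
params = [key for key in keys(traces[1]) if key ∉ keys(data)]

# One column per parameter, one row per iteration
vals = reduce(hcat, [trace[p] for trace in traces] for p in params)
```

Each column of vals corresponds to the name at the same position in params, which is the pairing that Chains(vals, params) relies on. Dict key order is not guaranteed, so the column order may vary.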
Let us check how our mini probabilistic programming language works. We define the probabilistic model:
@mini_model function m(x)
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end;
We perform inference with data x = 3.0:
sample(MiniModel(m, (x=3.0,)), MHSampler(), 1_000_000; chain_type=Chains, progress=false)
Chains MCMC chain (1000000×2×1 Array{Float64, 3}):
Iterations        = 1:1:1000000
Number of chains  = 1
Samples per chain = 1000000
parameters        = a, b
Summary Statistics
  parameters      mean       std      mcse      ess_bulk      ess_tail      rh ⋯
      Symbol   Float64   Float64   Float64       Float64       Float64   Float ⋯
           a    0.9743    0.8984    0.0032    80019.8038   122371.9588    1.00 ⋯
           b    2.8822    0.4880    0.0012   171497.9637   213751.8569    1.00 ⋯
                                                               2 columns omitted
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
           a   -0.7840    0.3672    0.9755    1.5776    2.7415
           b    1.9276    2.5526    2.8820    3.2125    3.8359
We compare these results with Turing.
using Turing
using PDMats

@model function turing_m(x)
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end

sample(turing_m(3.0), MH(ScalMat(2, 1.0)), 1_000_000, progress=false)
Chains MCMC chain (1000000×3×1 Array{Float64, 3}):
Iterations        = 1:1:1000000
Number of chains  = 1
Samples per chain = 1000000
Wall duration     = 31.21 seconds
Compute duration  = 31.21 seconds
parameters        = a, b
internals         = lp
Summary Statistics
  parameters      mean       std      mcse      ess_bulk      ess_tail      rh ⋯
      Symbol   Float64   Float64   Float64       Float64       Float64   Float ⋯
           a    0.9776    0.8985    0.0032    81117.2647   120572.6440    1.00 ⋯
           b    2.8798    0.4874    0.0012   171989.5545   210994.9131    1.00 ⋯
                                                               2 columns omitted
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
           a   -0.7854    0.3704    0.9766    1.5832    2.7445
           b    1.9252    2.5515    2.8794    3.2090    3.8343
As you can see, with our simple probabilistic programming language and custom samplers we get results similar to those of Turing.
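Because the model is linear-Gaussian, the posterior can also be computed exactly, which gives a reference point for both sets of estimates. Integrating out the other latent variable shows that b has the marginal prior Normal(0.5, 1^2 + 2^2) and that x given a follows Normal(a, 2^2 + 0.5^2), with the second parameter denoting the variance. The standard conjugate update for a normal prior and a normal likelihood with known variance then yields, for x = 3:

\[ \begin{aligned} a \mid x = 3 &\sim \operatorname{Normal}\!\left(\frac{0.5/1 + 3/4.25}{1/1 + 1/4.25}, \frac{1}{1/1 + 1/4.25}\right) = \operatorname{Normal}\!\left(\tfrac{41}{42}, \tfrac{17}{21}\right) \approx \operatorname{Normal}(0.976, 0.900^2), \\ b \mid x = 3 &\sim \operatorname{Normal}\!\left(\frac{0.5/5 + 3/0.25}{1/5 + 1/0.25}, \frac{1}{1/5 + 1/0.25}\right) = \operatorname{Normal}\!\left(\tfrac{121}{42}, \tfrac{5}{21}\right) \approx \operatorname{Normal}(2.881, 0.488^2). \end{aligned} \]

Both samplers' posterior means and standard deviations agree with these values to within about 0.002.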