A Mini Turing Implementation II: Contexts

In the Mini Turing tutorial we developed a miniature version of the Turing language, to illustrate its core design. A passing mention was made of contexts. In this tutorial we develop that aspect of our mini Turing language further to demonstrate how and why contexts are an important part of Turing’s design.

Mini Turing expanded, now with more contexts

If you haven’t read Mini Turing yet, you should do that first. We start by repeating verbatim much of the code from there. Define the type for holding values for variables:

import MacroTools, Random, AbstractMCMC
using Distributions: Normal, logpdf
using MCMCChains: Chains
using AbstractMCMC: sample

struct VarInfo{V,L}
    values::V
    logps::L
end

VarInfo() = VarInfo(Dict{Symbol,Float64}(), Dict{Symbol,Float64}())

function Base.setindex!(varinfo::VarInfo, (value, logp), var_id)
    varinfo.values[var_id] = value
    varinfo.logps[var_id] = logp
    return varinfo
end
Define the macro that expands ~ expressions to calls to assume and observe:
# Methods will be defined for these later.
function assume end
function observe end

macro mini_model(expr)
    return esc(mini_model(expr))
end

function mini_model(expr)
    # Split the function definition into a dictionary with its name, arguments, body etc.
    def = MacroTools.splitdef(expr)

    # Replace tildes in the function body with calls to `assume` or `observe`
    def[:body] = MacroTools.postwalk(def[:body]) do sub_expr
        if MacroTools.@capture(sub_expr, var_ ~ dist_)
            if var in def[:args]
                # If the variable is an argument of the model function, it is observed
                return :($(observe)(context, varinfo, $dist, $(Meta.quot(var)), $var))
            else
                # Otherwise it is unobserved
                return :($var = $(assume)(context, varinfo, $dist, $(Meta.quot(var))))
            end
        else
            return sub_expr
        end
    end

    # Add `context` and `varinfo` arguments to the model function
    def[:args] = vcat(:varinfo, :context, def[:args])

    # Reassemble the function definition from its name, arguments, body etc.
    return MacroTools.combinedef(def)
end

struct MiniModel{F,D} <: AbstractMCMC.AbstractModel
    f::F
    data::D # a NamedTuple of all the data
end
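To see the tilde rewrite in isolation, the following sketch applies the same MacroTools.postwalk and @capture pattern to a single quoted expression. The helper name rewrite_tildes is hypothetical, and for brevity it covers only the unobserved branch and refers to assume by name rather than interpolating the function as mini_model does.

```julia
import MacroTools

# Hypothetical helper: only the "unobserved" branch of the rewrite done in `mini_model`.
function rewrite_tildes(body)
    return MacroTools.postwalk(body) do sub_expr
        if MacroTools.@capture(sub_expr, var_ ~ dist_)
            # @capture binds `var` and `dist` to the two sides of the tilde.
            return :($var = assume(context, varinfo, $dist, $(Meta.quot(var))))
        else
            return sub_expr
        end
    end
end

# The tilde expression becomes an assignment from an `assume` call.
rewritten = rewrite_tildes(:(a ~ Normal(0.5, 1)))
println(rewritten)
```

The rewritten expression is an assignment whose right-hand side is a call carrying the distribution expression and the quoted variable name, which is what lets assume know which variable it is handling.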
Define an example model:
@mini_model function m(x)
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end;

mini_m = MiniModel(m, (x=3.0,))
MiniModel{typeof(m), @NamedTuple{x::Float64}}(m, (x = 3.0,))
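As a rough picture of what the macro produces for this model, here is a hand-written equivalent of the expanded function, run with a context that only records which calls were made. The names m_expanded and RecordingContext are hypothetical, and returning the mean of the distribution from assume is an arbitrary choice made so that the example is deterministic.

```julia
using Distributions: Normal, mean

function assume end
function observe end

# Hypothetical context: records the sequence of calls instead of doing inference.
struct RecordingContext end

function assume(::RecordingContext, calls, dist, var_id)
    push!(calls, (:assume, var_id))
    return mean(dist)  # arbitrary deterministic value for the variable
end

function observe(::RecordingContext, calls, dist, var_id, var_value)
    push!(calls, (:observe, var_id))
    return nothing
end

# Hand-written equivalent of what `@mini_model` generates for `m`.
function m_expanded(varinfo, context, x)
    a = assume(context, varinfo, Normal(0.5, 1), :a)
    b = assume(context, varinfo, Normal(a, 2), :b)
    observe(context, varinfo, Normal(b, 0.5), :x, x)
    return nothing
end

calls = Tuple{Symbol,Symbol}[]
m_expanded(calls, RecordingContext(), 3.0)
println(calls)
```

Because x is an argument of the model it goes through observe, while a and b go through assume. Everything the model does is therefore decided by the methods defined for the context type.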
Previously in the mini Turing case, at this point we defined SamplingContext, a structure that holds a random number generator and a sampler, and gets passed to observe and assume. We then used it to implement a simple Metropolis-Hastings sampler.
The notion of a context may have seemed overly complicated just to implement the sampler, but there are other things we may want to do with a model than sample from the posterior. Having the context passing in place lets us do that without having to touch the above macro at all. For instance, let’s say we want to evaluate the log joint probability of the model for a given set of data and parameters. Using a new context type we can use the previously defined model function, but change its behavior by changing what the observe and assume functions do.
struct JointContext end

function observe(context::JointContext, varinfo, dist, var_id, var_value)
    logp = logpdf(dist, var_value)
    varinfo[var_id] = (var_value, logp)
    return nothing
end

function assume(context::JointContext, varinfo, dist, var_id)
    if !haskey(varinfo.values, var_id)
        error("Can't evaluate the log probability if the variable $(var_id) is not set.")
    end
    var_value = varinfo.values[var_id]
    logp = logpdf(dist, var_value)
    varinfo[var_id] = (var_value, logp)
    return var_value
end

function logjoint(model, parameter_values::NamedTuple)
    vi = VarInfo()
    for (var_id, value) in pairs(parameter_values)
        # Set the log prob to NaN for now. These will get overwritten when model.f is
        # called with JointContext.
        vi[var_id] = (value, NaN)
    end
    model.f(vi, JointContext(), values(model.data)...)
    return sum(values(vi.logps))
end

logjoint(mini_m, (a=0.5, b=1.0))
-10.788065599614018
When using the JointContext, no sampling whatsoever happens in calling mini_m. Rather, only the log probability of each given variable value is evaluated. logjoint then sums these results to get the total log joint probability.
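The value returned by logjoint can be checked by hand, since it is just the sum of three Normal log densities, one per tilde statement. The sketch below writes out the Normal log density from its textbook formula; the helper name normal_logpdf is hypothetical and stands in for logpdf from Distributions.

```julia
# Log density of Normal(μ, σ) at x, written out from the textbook formula.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

a, b, x = 0.5, 1.0, 3.0
terms = (
    a = normal_logpdf(a, 0.5, 1.0),  # a ~ Normal(0.5, 1)
    b = normal_logpdf(b, a, 2.0),    # b ~ Normal(a, 2)
    x = normal_logpdf(x, b, 0.5),    # x ~ Normal(b, 0.5)
)
total = sum(terms)
println(terms)
println(total)
```

The total agrees with the -10.788065599614018 reported above up to floating point rounding. Most of it comes from the x term, because the observation 3.0 lies four standard deviations away from b = 1.0.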
We can similarly define a context for evaluating the log prior probability:
struct PriorContext end

function observe(context::PriorContext, varinfo, dist, var_id, var_value)
    # Since we are evaluating the prior, the log probability of all the observations
    # is set to 0. This has the effect of ignoring the likelihood.
    varinfo[var_id] = (var_value, 0.0)
    return nothing
end

function assume(context::PriorContext, varinfo, dist, var_id)
    if !haskey(varinfo.values, var_id)
        error("Can't evaluate the log probability if the variable $(var_id) is not set.")
    end
    var_value = varinfo.values[var_id]
    logp = logpdf(dist, var_value)
    varinfo[var_id] = (var_value, logp)
    return var_value
end

function logprior(model, parameter_values::NamedTuple)
    vi = VarInfo()
    for (var_id, value) in pairs(parameter_values)
        vi[var_id] = (value, NaN)
    end
    model.f(vi, PriorContext(), values(model.data)...)
    return sum(values(vi.logps))
end

logprior(mini_m, (a=0.5, b=1.0))
-2.5622742469692907
Notice that the definition of assume(context::PriorContext, args...) is identical to the one for JointContext, and logprior and logjoint are also identical except for the context type they create. There’s clearly an opportunity here for some refactoring using abstract types, but that’s outside the scope of this tutorial. Rather, the point here is to demonstrate that we can extract different sorts of things from our model by defining different context types, and specialising observe and assume for them.
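For readers curious what such a refactoring might look like, here is one possible shape, offered only as a sketch. All names (AbstractLeafContext, SketchJoint, SketchPrior, obs_logp, sketch_assume, sketch_observe, sketch_logdensity) are hypothetical, the model is written out by hand instead of going through the macro, and plain dictionaries stand in for VarInfo so that the block runs on its own.

```julia
using Distributions: Normal, logpdf

abstract type AbstractLeafContext end
struct SketchJoint <: AbstractLeafContext end
struct SketchPrior <: AbstractLeafContext end

# The only behaviour that differs between the two contexts.
obs_logp(::SketchJoint, dist, value) = logpdf(dist, value)
obs_logp(::SketchPrior, dist, value) = 0.0

# Shared by both contexts: look up the stored value and score it under `dist`.
function sketch_assume(::AbstractLeafContext, vals, logps, dist, var_id)
    haskey(vals, var_id) || error("variable $(var_id) is not set")
    logps[var_id] = logpdf(dist, vals[var_id])
    return vals[var_id]
end

function sketch_observe(ctx::AbstractLeafContext, vals, logps, dist, var_id, value)
    vals[var_id] = value
    logps[var_id] = obs_logp(ctx, dist, value)
    return nothing
end

# One function replaces both `logjoint` and `logprior`; the context picks the behaviour.
function sketch_logdensity(ctx::AbstractLeafContext, params::NamedTuple, x)
    vals = Dict{Symbol,Float64}(pairs(params))
    logps = Dict{Symbol,Float64}()
    a = sketch_assume(ctx, vals, logps, Normal(0.5, 1), :a)
    b = sketch_assume(ctx, vals, logps, Normal(a, 2), :b)
    sketch_observe(ctx, vals, logps, Normal(b, 0.5), :x, x)
    return sum(values(logps))
end

joint = sketch_logdensity(SketchJoint(), (a=0.5, b=1.0), 3.0)
prior = sketch_logdensity(SketchPrior(), (a=0.5, b=1.0), 3.0)
println((joint, prior))
```

The two results match the logjoint and logprior values above, and their difference is the log likelihood of the observation.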
Contexts within contexts
Let’s use the above two contexts to provide a slightly more general definition of the SamplingContext and the Metropolis-Hastings sampler we wrote in the mini Turing tutorial.
struct SamplingContext{S<:AbstractMCMC.AbstractSampler,R<:Random.AbstractRNG}
    rng::R
    sampler::S
    subcontext::Union{PriorContext, JointContext}
end
The new aspect here is the subcontext field. Note that this is a context within a context! The idea is that we don’t need to hard code how the MCMC sampler evaluates the log probability, but rather can pass that work onto the subcontext. This way the same sampler can be used to sample from either the joint or the prior distribution.

The methods for SamplingContext are largely as in our earlier mini Turing case, except they now pass some of the work onto the subcontext:
function observe(context::SamplingContext, args...)
    # Sampling doesn't affect the observed values, so nothing to do here other than pass to
    # the subcontext.
    return observe(context.subcontext, args...)
end

struct PriorSampler <: AbstractMCMC.AbstractSampler end

function assume(context::SamplingContext{PriorSampler}, varinfo, dist, var_id)
    sample = Random.rand(context.rng, dist)
    varinfo[var_id] = (sample, NaN)
    # Once the value has been sampled, let the subcontext handle evaluating the log
    # probability.
    return assume(context.subcontext, varinfo, dist, var_id)
end;

# The subcontext field of the MHSampler determines which distribution this sampler
# samples from.
struct MHSampler{D, T<:Real} <: AbstractMCMC.AbstractSampler
    sigma::T
    subcontext::D
end

MHSampler(subcontext) = MHSampler(1, subcontext)

function assume(context::SamplingContext{<:MHSampler}, varinfo, dist, var_id)
    sampler = context.sampler
    old_value = varinfo.values[var_id]

    # propose a random-walk step, i.e., add the current value to a random
    # value sampled from a Normal distribution centered at 0
    value = rand(context.rng, Normal(old_value, sampler.sigma))
    varinfo[var_id] = (value, NaN)
    # Once the value has been sampled, let the subcontext handle evaluating the log
    # probability.
    return assume(context.subcontext, varinfo, dist, var_id)
end;

# The following three methods are identical to before, except for passing
# `sampler.subcontext` to the context SamplingContext.
function AbstractMCMC.step(
    rng::Random.AbstractRNG, model::MiniModel, sampler::MHSampler; kwargs...
)
    vi = VarInfo()
    ctx = SamplingContext(rng, PriorSampler(), sampler.subcontext)
    model.f(vi, ctx, values(model.data)...)
    return vi, vi
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::MiniModel,
    sampler::MHSampler,
    prev_state::VarInfo; # is just the old trace
    kwargs...,
)
    vi = prev_state
    new_vi = deepcopy(vi)
    ctx = SamplingContext(rng, sampler, sampler.subcontext)
    model.f(new_vi, ctx, values(model.data)...)

    # Compute log acceptance probability
    # Since the proposal is symmetric the computation can be simplified
    logα = sum(values(new_vi.logps)) - sum(values(vi.logps))

    # Accept proposal with computed acceptance probability
    if -Random.randexp(rng) < logα
        return new_vi, new_vi
    else
        return prev_state, prev_state
    end
end;

function AbstractMCMC.bundle_samples(
    samples, model::MiniModel, ::MHSampler, ::Any, ::Type{Chains}; kwargs...
)
    # We get a vector of traces
    values = [sample.values for sample in samples]
    params = [key for key in keys(values[1]) if key ∉ keys(model.data)]
    vals = reduce(hcat, [value[p] for value in values] for p in params)
    # Composing the `Chains` data-structure, of which analyzing infrastructure is provided
    chains = Chains(vals, params)
    return chains
end;
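The acceptance test in the second step method, -Random.randexp(rng) < logα, is a compact way of accepting with probability min(1, exp(logα)): if E is a standard exponential variate, then -E has the same distribution as the logarithm of a uniform variate on (0, 1). The sketch below isolates that test under the hypothetical name mh_accept and checks it empirically; the seed and the target probability 0.3 are arbitrary choices.

```julia
import Random

# Accept with probability min(1, exp(logα)), using -E in place of log(U).
mh_accept(rng, logα) = -Random.randexp(rng) < logα

rng = Random.MersenneTwister(1)
n = 100_000

# With logα = log(0.3), roughly 30% of proposals should be accepted.
freq = count(_ -> mh_accept(rng, log(0.3)), 1:n) / n
println(freq)

# A proposal that increases the log density (logα > 0) is always accepted,
# and one with zero density (logα = -Inf) never is.
println(mh_accept(rng, 1.0))
println(mh_accept(rng, -Inf))
```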
We can use this sampler to sample from the joint distribution just like before:
sample(MiniModel(m, (x=3.0,)), MHSampler(JointContext()), 1_000_000; chain_type=Chains, progress=false)
Chains MCMC chain (1000000×2×1 Array{Float64, 3}):

Iterations        = 1:1:1000000
Number of chains  = 1
Samples per chain = 1000000
parameters        = a, b

Summary Statistics
  parameters      mean       std      mcse      ess_bulk      ess_tail      rh ⋯
      Symbol   Float64   Float64   Float64       Float64       Float64   Float ⋯
           a    0.9713    0.9014    0.0032    80387.2713   119459.5514    1.00 ⋯
           b    2.8796    0.4892    0.0012   172170.0034   211488.7837    1.00 ⋯
                                                             2 columns omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
           a   -0.8075    0.3629    0.9732    1.5803    2.7342
           b    1.9202    2.5498    2.8803    3.2090    3.8386
Or we can choose to sample from the prior instead:
sample(MiniModel(m, (x=3.0,)), MHSampler(PriorContext()), 1_000_000; chain_type=Chains, progress=false)
Chains MCMC chain (1000000×2×1 Array{Float64, 3}):

Iterations        = 1:1:1000000
Number of chains  = 1
Samples per chain = 1000000
parameters        = a, b

Summary Statistics
  parameters      mean       std      mcse     ess_bulk      ess_tail     rha ⋯
      Symbol   Float64   Float64   Float64      Float64       Float64  Float6 ⋯
           a    0.4977    1.0017    0.0040   64312.4640   127157.3129   1.000 ⋯
           b    0.4929    2.2363    0.0135   27644.6544    51960.2825   1.000 ⋯
                                                            2 columns omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
           a   -1.4669   -0.1768    0.4980    1.1717    2.4600
           b   -3.8853   -1.0153    0.4988    2.0051    4.8568
Of course, using an MCMC algorithm to sample from the prior is unnecessary and silly (PriorSampler exists, after all), but the point is to illustrate the flexibility of the context system. We could, for instance, use the same setup to implement an Approximate Bayesian Computation (ABC) algorithm.
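Because the example model is linear and Gaussian, the summary statistics in both chains above can be checked against closed-form values. The sketch below uses the standard conjugate update for a Normal mean with known variance; the helper name normal_update is hypothetical, and the variances are read off the model definition (standard deviations 1, 2 and 0.5).

```julia
# Conjugate update: prior N(m0, v0) on a mean, one observation y with noise variance v.
function normal_update(m0, v0, y, v)
    post_var = 1 / (1 / v0 + 1 / v)
    post_mean = post_var * (m0 / v0 + y / v)
    return post_mean, post_var
end

x = 3.0

# Prior marginals: a ~ N(0.5, 1), and b = a + noise, so b has variance 1 + 2^2.
prior_b_var = 1.0 + 2.0^2
prior_b_std = sqrt(prior_b_var)

# Posterior of b given x: the observation noise variance is 0.5^2.
b_mean, b_var = normal_update(0.5, prior_b_var, x, 0.5^2)

# Posterior of a given x: with b integrated out, x | a has variance 2^2 + 0.5^2.
a_mean, a_var = normal_update(0.5, 1.0, x, 2.0^2 + 0.5^2)

println((prior_b_std, a_mean, sqrt(a_var), b_mean, sqrt(b_var)))
```

These closed-form means and standard deviations agree with the sampled ones to within about 0.01, for both the prior run and the joint run.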
The use of contexts also goes far beyond just evaluating log probabilities and sampling. Some examples from Turing are:
- FixedContext, which fixes some variables to given values and removes them completely from the evaluation of any log probabilities. They power the Turing.fix and Turing.unfix functions.
- ConditionContext conditions the model on fixed values for some parameters. They are used by Turing.condition and Turing.uncondition, i.e. the model | (parameter=value,) syntax. The difference between fix and condition is whether the log probability for the corresponding variable is included in the overall log density.
- PriorExtractorContext collects information about what the prior distribution of each variable is.
- PrefixContext adds prefixes to variable names, allowing models to be used within other models without variable name collisions.
- PointwiseLikelihoodContext records the log likelihood of each individual variable.
- DebugContext collects useful debugging information while executing the model.
All of the above are what Turing calls parent contexts, which is to say that they all keep a subcontext just like our above SamplingContext did. Their implementations of assume and observe call the implementation of the subcontext once they are done doing their own work of fixing/conditioning/prefixing/etc. Contexts are often chained, so that e.g. a DebugContext may wrap within it a PrefixContext, which may in turn wrap a ConditionContext, etc. The only contexts that don’t have a subcontext in Turing are the ones for evaluating the prior, likelihood, and joint distributions. These are called leaf contexts.
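The chaining of parent contexts down to a leaf can be sketched in a few lines. None of the names below (SketchLeaf, SketchPrefix, SketchLogger, sk_observe) are Turing's actual types or functions; they only imitate the pattern, in which each parent does its own small piece of work and then hands over to its subcontext.

```julia
using Distributions: Normal, logpdf

struct SketchLeaf end            # leaf: evaluates the log probability
struct SketchPrefix{C}           # parent: renames the variable, then delegates
    prefix::Symbol
    subcontext::C
end
struct SketchLogger{C}           # parent: records the variable name, then delegates
    seen::Vector{Symbol}
    subcontext::C
end

function sk_observe(::SketchLeaf, logps, dist, var_id, value)
    logps[var_id] = logpdf(dist, value)
    return nothing
end

function sk_observe(ctx::SketchPrefix, logps, dist, var_id, value)
    prefixed = Symbol(ctx.prefix, ".", var_id)
    return sk_observe(ctx.subcontext, logps, dist, prefixed, value)
end

function sk_observe(ctx::SketchLogger, logps, dist, var_id, value)
    push!(ctx.seen, var_id)
    return sk_observe(ctx.subcontext, logps, dist, var_id, value)
end

# A logger wrapping a prefixer wrapping a leaf.
logps = Dict{Symbol,Float64}()
seen = Symbol[]
ctx = SketchLogger(seen, SketchPrefix(:inner, SketchLeaf()))
sk_observe(ctx, logps, Normal(1.0, 0.5), :x, 3.0)
println(seen)
println(logps)
```

The logger sees the original name x, while the leaf stores the log probability under the prefixed name, so the order in which contexts are nested determines what each one observes.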
The above version of mini Turing is still much simpler than the full Turing language, but the principles of how contexts are used are the same.