A Mini Turing Implementation II: Contexts

In the Mini Turing tutorial we developed a miniature version of the Turing language, to illustrate its core design. A passing mention was made of contexts. In this tutorial we develop that aspect of our mini Turing language further to demonstrate how and why contexts are an important part of Turing’s design.

Mini Turing expanded, now with more contexts

If you haven’t read Mini Turing yet, you should do that first. We start by repeating verbatim much of the code from there. Define the type for holding values for variables:

import MacroTools, Random, AbstractMCMC
using Distributions: Normal, logpdf
using MCMCChains: Chains
using AbstractMCMC: sample

struct VarInfo{V,L}
    values::V
    logps::L
end

VarInfo() = VarInfo(Dict{Symbol,Float64}(), Dict{Symbol,Float64}())

function Base.setindex!(varinfo::VarInfo, (value, logp), var_id)
    varinfo.values[var_id] = value
    varinfo.logps[var_id] = logp
    return varinfo
end

Define the macro that expands ~ expressions to calls to assume and observe:

# Methods will be defined for these later.
function assume end
function observe end

macro mini_model(expr)
    return esc(mini_model(expr))
end

function mini_model(expr)
    # Split the function definition into a dictionary with its name, arguments, body etc.
    def = MacroTools.splitdef(expr)

    # Replace tildes in the function body with calls to `assume` or `observe`
    def[:body] = MacroTools.postwalk(def[:body]) do sub_expr
        if MacroTools.@capture(sub_expr, var_ ~ dist_)
            if var in def[:args]
                # If the variable is an argument of the model function, it is observed
                return :($(observe)(context, varinfo, $dist, $(Meta.quot(var)), $var))
            else
                # Otherwise it is unobserved
                return :($var = $(assume)(context, varinfo, $dist, $(Meta.quot(var))))
            end
        else
            return sub_expr
        end
    end

    # Add `context` and `varinfo` arguments to the model function
    def[:args] = vcat(:varinfo, :context, def[:args])

    # Reassemble the function definition from its name, arguments, body etc.
    return MacroTools.combinedef(def)
end


struct MiniModel{F,D} <: AbstractMCMC.AbstractModel
    f::F
    data::D # a NamedTuple of all the data
end

Define an example model:

@mini_model function m(x)
    a ~ Normal(0.5, 1)
    b ~ Normal(a, 2)
    x ~ Normal(b, 0.5)
    return nothing
end;

mini_m = MiniModel(m, (x=3.0,))
MiniModel{typeof(m), @NamedTuple{x::Float64}}(Main.Notebook.m, (x = 3.0,))
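
To make the effect of the macro concrete, here is a hand-written sketch of what @mini_model turns m into, following the rewriting rules described above. So that it runs on its own, it is wrapped in a module with stand-ins that are not part of the tutorial's code: a hypothetical LoggingContext, a FakeNormal struct in place of Distributions.Normal, and toy assume and observe methods that only record which calls the model body makes.

```julia
module ExpansionSketch

# Hypothetical context that only records the calls made by the model body.
struct LoggingContext end

# Stand-in for a distribution object; the tutorial itself uses Distributions.Normal.
struct FakeNormal
    mean::Float64
    std::Float64
end

# In this sketch `varinfo` is just a Vector that collects log entries.
function assume(::LoggingContext, varinfo, dist, var_id)
    push!(varinfo, (:assume, var_id, dist))
    return dist.mean  # pretend the variable takes the mean of its distribution
end

function observe(::LoggingContext, varinfo, dist, var_id, var_value)
    push!(varinfo, (:observe, var_id, dist))
    return nothing
end

# Hand-expanded equivalent of
#   @mini_model function m(x)
#       a ~ Normal(0.5, 1); b ~ Normal(a, 2); x ~ Normal(b, 0.5)
#       return nothing
#   end
function m_expanded(varinfo, context, x)
    a = assume(context, varinfo, FakeNormal(0.5, 1), :a)   # `a` is not an argument
    b = assume(context, varinfo, FakeNormal(a, 2), :b)     # `b` is not an argument
    observe(context, varinfo, FakeNormal(b, 0.5), :x, x)   # `x` is an argument, so it is observed
    return nothing
end

end # module

calls = Any[]
ExpansionSketch.m_expanded(calls, ExpansionSketch.LoggingContext(), 3.0)
[(kind, var_id) for (kind, var_id, _) in calls]
```

The recorded calls are assume for a, assume for b, and observe for x, in that order. Everything else about the model's behaviour is decided by the methods of assume and observe that the context selects.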

Previously in the mini Turing case, at this point we defined SamplingContext, a structure that holds a random number generator and a sampler, and gets passed to observe and assume. We then used it to implement a simple Metropolis-Hastings sampler.

The notion of a context may have seemed overly complicated just to implement the sampler, but there are other things we may want to do with a model than sample from the posterior. Having the context passing in place lets us do that without having to touch the above macro at all. For instance, let’s say we want to evaluate the log joint probability of the model for a given set of data and parameters. Using a new context type we can use the previously defined model function, but change its behavior by changing what the observe and assume functions do.

struct JointContext end

function observe(context::JointContext, varinfo, dist, var_id, var_value)
    logp = logpdf(dist, var_value)
    varinfo[var_id] = (var_value, logp)
    return nothing
end

function assume(context::JointContext, varinfo, dist, var_id)
    if !haskey(varinfo.values, var_id)
        error("Can't evaluate the log probability if the variable $(var_id) is not set.")
    end
    var_value = varinfo.values[var_id]
    logp = logpdf(dist, var_value)
    varinfo[var_id] = (var_value, logp)
    return var_value
end

function logjoint(model, parameter_values::NamedTuple)
    vi = VarInfo()
    for (var_id, value) in pairs(parameter_values)
        # Set the log prob to NaN for now. These will get overwritten when model.f is
        # called with JointContext.
        vi[var_id] = (value, NaN)
    end
    model.f(vi, JointContext(), values(model.data)...)
    return sum(values(vi.logps))
end

logjoint(mini_m, (a=0.5, b=1.0))
-10.788065599614018

When using the JointContext, no sampling whatsoever happens when the model function mini_m.f is called. Rather, only the log probability of each given variable value is evaluated. logjoint then sums these results to get the total log joint probability.
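
As a sanity check, the value above can be reproduced by hand from the three Normal densities in m, with a = 0.5, b = 1.0 and x = 3.0. Recall that the second argument of Normal is the standard deviation, so that log N(y; μ, σ) = -½ log(2π) - log σ - (y - μ)² / (2σ²). The individual terms below are rounded to five decimals.

```latex
\begin{aligned}
\log p(a, b, x)
  &= \log \mathcal{N}(a;\, 0.5,\, 1) + \log \mathcal{N}(b;\, a,\, 2) + \log \mathcal{N}(x;\, b,\, 0.5) \\
  &= \log \mathcal{N}(0.5;\, 0.5,\, 1) + \log \mathcal{N}(1.0;\, 0.5,\, 2) + \log \mathcal{N}(3.0;\, 1.0,\, 0.5) \\
  &\approx -0.91894 - 1.64334 - 8.22579 \\
  &\approx -10.78807
\end{aligned}
```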

We can similarly define a context for evaluating the log prior probability:

struct PriorContext end

function observe(context::PriorContext, varinfo, dist, var_id, var_value)
    # Since we are evaluating the prior, the log probability of all the observations
    # is set to 0. This has the effect of ignoring the likelihood.
    varinfo[var_id] = (var_value, 0.0)
    return nothing
end

function assume(context::PriorContext, varinfo, dist, var_id)
    if !haskey(varinfo.values, var_id)
        error("Can't evaluate the log probability if the variable $(var_id) is not set.")
    end
    var_value = varinfo.values[var_id]
    logp = logpdf(dist, var_value)
    varinfo[var_id] = (var_value, logp)
    return var_value
end

function logprior(model, parameter_values::NamedTuple)
    vi = VarInfo()
    for (var_id, value) in pairs(parameter_values)
        vi[var_id] = (value, NaN)
    end
    model.f(vi, PriorContext(), values(model.data)...)
    return sum(values(vi.logps))
end

logprior(mini_m, (a=0.5, b=1.0))
-2.5622742469692907

Notice that the definition of assume(context::PriorContext, args...) is identical to the one for JointContext, and logprior and logjoint are also identical except for the context type they create. There’s clearly an opportunity here for some refactoring using abstract types, but that’s outside the scope of this tutorial. Rather, the point here is to demonstrate that we can extract different sorts of things from our model by defining different context types, and specialising observe and assume for them.
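
For the curious, one possible shape for that refactoring is sketched below. It is only an illustration under several assumptions: the abstract type AbstractLogpContext, the Gaussian struct with its logdensity function (a stand-in for Normal and logpdf), and the function logdensity_with are all hypothetical names, the model is expanded by hand, and the values and log probabilities are kept in two plain dictionaries instead of a VarInfo. The module keeps these definitions from clashing with the ones in this tutorial.

```julia
module RefactorSketch

# Hypothetical abstract supertype for contexts that evaluate log probabilities.
abstract type AbstractLogpContext end
struct JointContext <: AbstractLogpContext end
struct PriorContext <: AbstractLogpContext end

# Stand-in for `Normal` and `logpdf` from Distributions.jl.
struct Gaussian
    μ::Float64
    σ::Float64
end
logdensity(d::Gaussian, x) = -0.5 * log(2π) - log(d.σ) - 0.5 * ((x - d.μ) / d.σ)^2

# A single `assume` method now serves both contexts.
function assume(::AbstractLogpContext, vals, logps, dist, var_id)
    haskey(vals, var_id) || error("Variable $(var_id) is not set.")
    logps[var_id] = logdensity(dist, vals[var_id])
    return vals[var_id]
end

# Only `observe` differs between the two contexts.
function observe(::JointContext, logps, dist, var_id, value)
    logps[var_id] = logdensity(dist, value)
    return nothing
end

function observe(::PriorContext, logps, dist, var_id, value)
    logps[var_id] = 0.0  # ignore the likelihood
    return nothing
end

# Hand-expanded version of the example model `m`.
function m(vals, logps, context, x)
    a = assume(context, vals, logps, Gaussian(0.5, 1), :a)
    b = assume(context, vals, logps, Gaussian(a, 2), :b)
    observe(context, logps, Gaussian(b, 0.5), :x, x)
    return nothing
end

# One function replaces both `logjoint` and `logprior`.
function logdensity_with(context::AbstractLogpContext, parameter_values::NamedTuple, x)
    vals = Dict{Symbol,Float64}(pairs(parameter_values))
    logps = Dict{Symbol,Float64}()
    m(vals, logps, context, x)
    return sum(values(logps))
end

end # module

joint = RefactorSketch.logdensity_with(RefactorSketch.JointContext(), (a=0.5, b=1.0), 3.0)
prior = RefactorSketch.logdensity_with(RefactorSketch.PriorContext(), (a=0.5, b=1.0), 3.0)
```

Up to floating-point rounding, joint and prior agree with the values returned by logjoint and logprior above.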

Contexts within contexts

Let’s use the above two contexts to provide a slightly more general definition of the SamplingContext and the Metropolis-Hastings sampler we wrote in the mini Turing tutorial.

struct SamplingContext{S<:AbstractMCMC.AbstractSampler,R<:Random.AbstractRNG}
    rng::R
    sampler::S
    subcontext::Union{PriorContext, JointContext}
end

The new aspect here is the subcontext field. Note that this is a context within a context! The idea is that we don’t need to hard-code how the MCMC sampler evaluates the log probability, but can instead pass that work on to the subcontext. This way the same sampler can be used to sample from either the joint or the prior distribution.

The methods for SamplingContext are largely as in our earlier mini Turing case, except they now pass some of the work on to the subcontext:

function observe(context::SamplingContext, args...)
    # Sampling doesn't affect the observed values, so nothing to do here other than pass to
    # the subcontext.
    return observe(context.subcontext, args...)
end

struct PriorSampler <: AbstractMCMC.AbstractSampler end

function assume(context::SamplingContext{PriorSampler}, varinfo, dist, var_id)
    sample = Random.rand(context.rng, dist)
    varinfo[var_id] = (sample, NaN)
    # Once the value has been sampled, let the subcontext handle evaluating the log
    # probability.
    return assume(context.subcontext, varinfo, dist, var_id)
end;

# The subcontext field of the MHSampler determines which distribution this sampler
# samples from.
struct MHSampler{D, T<:Real} <: AbstractMCMC.AbstractSampler
    sigma::T
    subcontext::D
end

MHSampler(subcontext) = MHSampler(1, subcontext)

function assume(context::SamplingContext{<:MHSampler}, varinfo, dist, var_id)
    sampler = context.sampler
    old_value = varinfo.values[var_id]

    # Propose a random-walk step, i.e., draw the new value from a Normal distribution
    # centered at the current value with standard deviation `sampler.sigma`
    value = rand(context.rng, Normal(old_value, sampler.sigma))
    varinfo[var_id] = (value, NaN)
    # Once the value has been sampled, let the subcontext handle evaluating the log
    # probability.
    return assume(context.subcontext, varinfo, dist, var_id)
end;

# The following three methods are identical to before, except for passing
# `sampler.subcontext` to the context SamplingContext.
function AbstractMCMC.step(
    rng::Random.AbstractRNG, model::MiniModel, sampler::MHSampler; kwargs...
)
    vi = VarInfo()
    ctx = SamplingContext(rng, PriorSampler(), sampler.subcontext)
    model.f(vi, ctx, values(model.data)...)
    return vi, vi
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::MiniModel,
    sampler::MHSampler,
    prev_state::VarInfo; # is just the old trace
    kwargs...,
)
    vi = prev_state
    new_vi = deepcopy(vi)
    ctx = SamplingContext(rng, sampler, sampler.subcontext)
    model.f(new_vi, ctx, values(model.data)...)

    # Compute log acceptance probability
    # Since the proposal is symmetric the computation can be simplified
    logα = sum(values(new_vi.logps)) - sum(values(vi.logps))

    # Accept proposal with computed acceptance probability
    if -Random.randexp(rng) < logα
        return new_vi, new_vi
    else
        return prev_state, prev_state
    end
end;

function AbstractMCMC.bundle_samples(
    samples, model::MiniModel, ::MHSampler, ::Any, ::Type{Chains}; kwargs...
)
    # We get a vector of traces
    values = [sample.values for sample in samples]
    params = [key for key in keys(values[1]) if key ∉ keys(model.data)]
    vals = reduce(hcat, [value[p] for value in values] for p in params)
    # Composing the `Chains` data-structure, of which analyzing infrastructure is provided
    chains = Chains(vals, params)
    return chains
end;
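
The acceptance test in the second step method, -Random.randexp(rng) < logα, may look cryptic at first. If E is a standard exponential random variable, then P(-E < log α) = P(E > -log α) = α whenever α ≤ 1, and the test always passes when α > 1. The proposal is therefore accepted with probability min(1, α), just as with the more familiar log(rand(rng)) < logα. The following sketch checks this empirically using only the Random standard library; acceptance_rate is a hypothetical helper and not part of the sampler.

```julia
using Random

# Empirical frequency with which the test `-randexp(rng) < logα` succeeds.
function acceptance_rate(rng, logα; n=1_000_000)
    return count(_ -> -Random.randexp(rng) < logα, 1:n) / n
end

rng = Random.MersenneTwister(42)
rate_low = acceptance_rate(rng, log(0.3))   # should be close to 0.3
rate_high = acceptance_rate(rng, log(2.0))  # logα > 0, so every proposal is accepted
```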

We can use this to sample from the joint distribution just like before:

sample(MiniModel(m, (x=3.0,)), MHSampler(JointContext()), 1_000_000; chain_type=Chains, progress=false)
Chains MCMC chain (1000000×2×1 Array{Float64, 3}):

Iterations        = 1:1:1000000
Number of chains  = 1
Samples per chain = 1000000
parameters        = a, b

Summary Statistics
  parameters      mean       std      mcse      ess_bulk      ess_tail      rh ⋯
      Symbol   Float64   Float64   Float64       Float64       Float64   Float ⋯

           a    0.9822    0.9002    0.0032    81339.5848   123233.4231    1.00 ⋯
           b    2.8818    0.4881    0.0012   172867.2614   215320.0136    1.00 ⋯
                                                               2 columns omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           a   -0.7823    0.3742    0.9811    1.5875    2.7505
           b    1.9286    2.5533    2.8815    3.2093    3.8436

Or we can choose to sample from the prior instead:

sample(MiniModel(m, (x=3.0,)), MHSampler(PriorContext()), 1_000_000; chain_type=Chains, progress=false)
Chains MCMC chain (1000000×2×1 Array{Float64, 3}):

Iterations        = 1:1:1000000
Number of chains  = 1
Samples per chain = 1000000
parameters        = a, b

Summary Statistics
  parameters      mean       std      mcse     ess_bulk      ess_tail      rha ⋯
      Symbol   Float64   Float64   Float64      Float64       Float64   Float6 ⋯

           a    0.5012    1.0000    0.0039   65206.1389   122070.5596    1.000 ⋯
           b    0.5191    2.2306    0.0136   26900.9154    51488.9717    1.000 ⋯
                                                               2 columns omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           a   -1.4693   -0.1714    0.5014    1.1733    2.4600
           b   -3.8564   -0.9838    0.5215    2.0251    4.8815

Of course, using an MCMC algorithm to sample from the prior is unnecessary and silly (PriorSampler exists, after all), but the point is to illustrate the flexibility of the context system. We could, for instance, use the same setup to implement an Approximate Bayesian Computation (ABC) algorithm.

The use of contexts also goes far beyond just evaluating log probabilities and sampling. Some examples from Turing are

  • FixedContext, which fixes some variables to given values and removes them completely from the evaluation of any log probabilities. It powers the Turing.fix and Turing.unfix functions.

  • ConditionContext, which conditions the model on fixed values for some parameters. It is used by Turing.condition and Turing.uncondition, i.e. the model | (parameter=value,) syntax. The difference between fix and condition is whether the log probability of the corresponding variable is included in the overall log density: it is for condition, and it is not for fix.

  • PriorExtractorContext collects information about what the prior distribution of each variable is.

  • PrefixContext adds prefixes to variable names, allowing models to be used within other models without variable name collisions.

  • PointwiseLikelihoodContext records the log likelihood of each individual variable.

  • DebugContext collects useful debugging information while executing the model.

All of the above are what Turing calls parent contexts, which is to say that they all keep a subcontext just like our above SamplingContext did. Their implementations of assume and observe call the implementation of the subcontext once they are done doing their own work of fixing/conditioning/prefixing/etc. Contexts are often chained, so that e.g. a DebugContext may wrap within it a PrefixContext, which may in turn wrap a ConditionContext, etc. The only contexts that don’t have a subcontext in Turing are the ones for evaluating the prior, likelihood, and joint distributions. These are called leaf contexts.
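
The chaining pattern itself is easy to sketch. The toy example below is not how Turing implements any of the contexts listed above. All names in it (RecordingLeaf, PrefixSketch, CountingSketch, and the two-argument assume) are hypothetical, and assume is stripped down to take only a context and a variable name. It shows two parent contexts that each do a small piece of work and then delegate to their subcontext, until a leaf context is reached.

```julia
module ChainSketch

# Leaf context: has no subcontext and does the final piece of work
# (here it just records the variable names it is asked about).
struct RecordingLeaf
    seen::Vector{Symbol}
end

# Parent context: prefixes the variable name, then delegates.
struct PrefixSketch{C}
    prefix::Symbol
    subcontext::C
end

# Parent context: counts how many times `assume` is called, then delegates.
mutable struct CountingSketch{C}
    n::Int
    subcontext::C
end

function assume(ctx::RecordingLeaf, var_id)
    push!(ctx.seen, var_id)
    return var_id
end

function assume(ctx::PrefixSketch, var_id)
    prefixed = Symbol(ctx.prefix, ".", var_id)
    return assume(ctx.subcontext, prefixed)
end

function assume(ctx::CountingSketch, var_id)
    ctx.n += 1
    return assume(ctx.subcontext, var_id)
end

end # module

leaf = ChainSketch.RecordingLeaf(Symbol[])
ctx = ChainSketch.CountingSketch(0, ChainSketch.PrefixSketch(:inner, leaf))
ChainSketch.assume(ctx, :a)
ChainSketch.assume(ctx, :b)
(ctx.n, leaf.seen)
```

After the two calls, the counting context has seen two calls and the leaf has recorded the prefixed names inner.a and inner.b. Neither parent context needs to know what kind of context it wraps.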

The above version of mini Turing is still much simpler than the full Turing language, but the principles of how contexts are used are the same.
