Advanced Usage

How to Define a Customized Distribution

Turing.jl supports the use of distributions from the Distributions.jl package. By extension, it also supports customized distributions: define them as subtypes of the Distribution type from Distributions.jl and implement the corresponding functions.

The workflow below shows how to define a customized distribution, using our own implementation of a Uniform distribution as a simple example.

1. Define the Distribution Type

First, define a type for the distribution as a subtype of the corresponding distribution type in the Distributions.jl package.

using Distributions, Random  # Random provides AbstractRNG, used in the next step
struct CustomUniform <: ContinuousUnivariateDistribution end

2. Implement Sampling and Evaluation of the log-pdf

Second, define rand and logpdf, which will be used to run the model.

# sample in [0, 1]
Distributions.rand(rng::AbstractRNG, d::CustomUniform) = rand(rng)

# p(x) = 1 → logp(x) = 0
Distributions.logpdf(d::CustomUniform, x::Real) = zero(x)
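The logpdf above returns 0 for every input, including values outside [0, 1]. As a minimal sketch of a stricter variant, the block below defines a separate type, StrictUniform (a hypothetical name used here only to avoid clashing with CustomUniform), whose log density is -Inf outside the support. It assumes only Distributions.jl and the Random standard library.

```julia
using Distributions
using Random

# Hypothetical variant of CustomUniform; the name StrictUniform is illustrative only.
struct StrictUniform <: ContinuousUnivariateDistribution end

# Sample uniformly from [0, 1).
Distributions.rand(rng::AbstractRNG, d::StrictUniform) = rand(rng)

# log p(x) = 0 inside [0, 1], and -Inf (zero density) outside.
Distributions.logpdf(d::StrictUniform, x::Real) = 0 <= x <= 1 ? 0.0 : -Inf

d = StrictUniform()
x = rand(MersenneTwister(0), d)  # a single draw with a fixed seed
```

Whether the extra support check matters depends on how the distribution is used; when a bijector keeps values inside [0, 1], the simpler definition behaves the same.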

3. Define Helper Functions

In most cases, you will also need to define some helper functions.

3.1 Domain Transformation

Certain samplers, such as HMC, require the domain of the priors to be unbounded. Therefore, to use our CustomUniform as a prior in a model, we also need to define how to transform samples from [0, 1] to ℝ. To do this, we simply need to define the corresponding Bijector from Bijectors.jl, which is what Turing.jl uses internally to deal with constrained distributions.

To transform from [0, 1] to ℝ we can use the Logit bijector:

Bijectors.bijector(d::CustomUniform) = Logit(0.0, 1.0)

You’d do the exact same thing for ContinuousMultivariateDistribution and ContinuousMatrixDistribution. For example, Wishart defines a distribution over positive-definite matrices and so bijector returns a PDBijector when called with a Wishart distribution as an argument. For discrete distributions, there is no need to define a bijector; the Identity bijector is used by default.
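For intuition, the logit map between an interval (a, b) and the real line can be written out by hand. The sketch below is illustrative only: the function names to_unconstrained and to_constrained are hypothetical, and this is not the code Bijectors.jl runs.

```julia
# Hand-written logit map for an interval (a, b); illustrative, not Bijectors.jl's implementation.
to_unconstrained(x, a, b) = log((x - a) / (b - x))

# Inverse map: a scaled and shifted logistic function.
to_constrained(y, a, b) = a + (b - a) / (1 + exp(-y))

y = to_unconstrained(0.25, 0.0, 1.0)  # a point in (0, 1) mapped to the real line
x = to_constrained(y, 0.0, 1.0)       # mapped back to the original interval
```

Points near the interval's endpoints are sent far out towards negative or positive infinity, which is what lets a sampler work on an unbounded domain.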

Alternatively, for a UnivariateDistribution we can define the minimum and maximum of the distribution:

Distributions.minimum(d::CustomUniform) = 0.0
Distributions.maximum(d::CustomUniform) = 1.0

and Bijectors.jl will return a default Bijector called TruncatedBijector, which uses minimum and maximum to derive the correct transformation.

Internally, Turing does roughly the following when it needs to convert a constrained distribution to an unconstrained distribution, e.g. when sampling using HMC:

dist = Gamma(2,3)
b = bijector(dist)
transformed_dist = transformed(dist, b) # results in distribution with transformed support + correction for logpdf
Bijectors.UnivariateTransformed{Distributions.Gamma{Float64}, Base.Fix1{typeof(broadcast), typeof(log)}}(
dist: Distributions.Gamma{Float64}(α=2.0, θ=3.0)
transform: Base.Fix1{typeof(broadcast), typeof(log)}(broadcast, log)
)

and then we can call rand and logpdf as usual, where

  • rand(transformed_dist) returns a sample in the unconstrained space, and
  • logpdf(transformed_dist, y) returns the log density of the original distribution, but with y living in the unconstrained space.
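The correction to the log density comes from the change-of-variables formula. As a minimal sketch, the block below reproduces it for the Gamma(2, 3) example with a hand-written log transform, using only Distributions.jl rather than Bijectors.jl. The helper name unconstrained_logpdf is hypothetical.

```julia
using Distributions

dist = Gamma(2, 3)

# For y = log(x) we have x = exp(y) and |dx/dy| = exp(y),
# so the log-Jacobian correction added to the original log density is simply y.
unconstrained_logpdf(y) = logpdf(dist, exp(y)) + y

# Sanity check: the transformed density should still integrate to about 1.
# A plain Riemann sum over a wide grid of y values is enough for that.
ys = -20:0.001:10
total = sum(exp(unconstrained_logpdf(y)) for y in ys) * step(ys)
```

Without the `+ y` term the function would no longer be a normalized density in the unconstrained space.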

To read more about Bijectors.jl, check out the project README.

Update the accumulated log probability in the model definition

Turing accumulates log probabilities in an internal data structure that is accessible through the variable __varinfo__ inside the model definition (see below for more details about model internals). However, since users should not have to deal with internal data structures, Turing provides the macro Turing.@addlogprob!, which increases the accumulated log probability. For instance, this allows you to include arbitrary terms in the likelihood:

using Turing

myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x)

@model function demo(x)
    μ ~ Normal()
    Turing.@addlogprob! myloglikelihood(x, μ)
end
demo (generic function with 2 methods)

and to reject samples:

using Turing
using LinearAlgebra

@model function demo(x)
    m ~ MvNormal(zero(x), I)
    if dot(m, x) < 0
        Turing.@addlogprob! -Inf
        # Exit the model evaluation early
        return nothing
    end

    x ~ MvNormal(m, I)
    return nothing
end
demo (generic function with 2 methods)

Note that @addlogprob! always increases the accumulated log probability, regardless of the provided sampling context. For instance, if you want Turing.@addlogprob! to apply only when computing the log likelihood and the log joint probability, and not when evaluating the prior of your model, check the type of the internal variable __context__, for example:

if DynamicPPL.leafcontext(__context__) !== Turing.PriorContext()
    Turing.@addlogprob! myloglikelihood(x, μ)
end

Model Internals

The @model macro accepts a function definition and rewrites it so that calling the function generates a Model struct for use by the sampler. Models can also be constructed by hand, without the macro. Taking the gdemo model as an example, the macro-based definition

using Turing

@model function gdemo(x)
    # Set priors.
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))

    # Observe each value of x.
    @. x ~ Normal(m, sqrt(s²))
end

model = gdemo([1.5, 2.0])
DynamicPPL.Model{typeof(gdemo), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DynamicPPL.DefaultContext}(Main.Notebook.gdemo, (x = [1.5, 2.0],), NamedTuple(), DynamicPPL.DefaultContext())

can also be implemented (a bit less generally) with the macro-free version

using Turing

# Create the model function.
function gdemo(model, varinfo, context, x)
    # Assume s² has an InverseGamma distribution.
    s², varinfo = DynamicPPL.tilde_assume!!(
        context, InverseGamma(2, 3), Turing.@varname(s²), varinfo
    )

    # Assume m has a Normal distribution.
    m, varinfo = DynamicPPL.tilde_assume!!(
        context, Normal(0, sqrt(s²)), Turing.@varname(m), varinfo
    )

    # Observe each value of x[i] according to a Normal distribution.
    return DynamicPPL.dot_tilde_observe!!(
        context, Normal(m, sqrt(s²)), x, Turing.@varname(x), varinfo
    )
end
gdemo(x) = Turing.Model(gdemo, (; x))

# Instantiate a Model object with our data variables.
model = gdemo([1.5, 2.0])
DynamicPPL.Model{typeof(gdemo), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DynamicPPL.DefaultContext}(Main.Notebook.gdemo, (x = [1.5, 2.0],), NamedTuple(), DynamicPPL.DefaultContext())

Reparametrization and generated_quantities

Often, the most natural parameterization for a model is not the most computationally feasible. Consider the following (efficiently reparametrized) implementation of Neal’s funnel (Neal, 2003):

@model function Neal()
    # Raw draws
    y_raw ~ Normal(0, 1)
    x_raw ~ arraydist([Normal(0, 1) for i in 1:9])

    # Transform:
    y = 3 * y_raw
    x = exp.(y ./ 2) .* x_raw

    # Return:
    return [x; y]
end

In this case, the random variables exposed in the chain (x_raw, y_raw) are not in a helpful form. What we're after are the deterministically transformed variables x and y.
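The transform in the model above can be tried outside of Turing. The following sketch, which assumes only the Random and Statistics standard libraries, draws the raw variables directly from standard normals and applies the same transform, so that y has standard deviation 3 and each x[i] has standard deviation exp(y / 2) given y.

```julia
using Random
using Statistics

rng = MersenneTwister(1)
n = 10_000  # number of independent draws

# Raw, independent standard-normal draws (the variables the sampler explores).
y_raw = randn(rng, n)
x_raw = randn(rng, 9, n)

# Deterministic transform back to the funnel's natural parameterization.
y = 3 .* y_raw
x = exp.(y' ./ 2) .* x_raw  # scale column j of x_raw by exp(y[j] / 2)
```

The raw variables are independent with unit scale, which is typically much easier geometry for a sampler than the funnel itself, where the scale of x changes sharply with y.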

More generally, there are often quantities in our models that we might be interested in viewing, but which are not explicitly present in our chain.

We can generate draws from these variables (in this case, x and y) by returning them from the model and then calling generated_quantities(model, chain). This function returns an array of the values specified in the model's return statement, one entry per posterior sample.

For example, in the above reparametrization, we sample from our model:

chain = sample(Neal(), NUTS(), 1000)

and then call:

generated_quantities(Neal(), chain)

to return an array for each posterior sample containing x1, x2, ... x9, y.

In this case, it might be useful to reorganize our output into a matrix for plotting:

reparam_chain = reduce(hcat, generated_quantities(Neal(), chain))'

We can then recover a vector of samples for each variable as follows:

x1_samples = reparam_chain[:, 1]
y_samples = reparam_chain[:, 10]
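The reshaping step can be checked on a small stand-in for the output of generated_quantities. In this sketch, draws is a hypothetical collection with one vector per posterior sample; only Base Julia is used.

```julia
# Hypothetical stand-in for generated_quantities output: one vector per draw.
draws = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

# hcat places each draw in its own column; the adjoint (') flips that,
# giving one row per draw and one column per returned quantity.
m = reduce(hcat, draws)'

first_quantity = m[:, 1]  # the first returned quantity across all draws
```

With this layout, column k of the matrix holds every sample of the k-th returned value, which is why column 10 corresponds to y in the funnel example.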

Task Copying

Turing copies Julia tasks to deliver efficient inference algorithms, but it also provides a slower alternative implementation as a fallback. Task copying is enabled by default. It requires tasks to be created with the TapedTask facility provided by Libtask.
