Custom Distributions

Turing.jl supports the use of distributions from the Distributions.jl package. By extension, it also supports the use of customised distributions by defining them as subtypes of the Distribution type in the Distributions.jl package, as well as corresponding functions.

This page shows a workflow of how to define a customised distribution, using our own implementation of a simple Uniform distribution as a simple example.

using Distributions, Turing, Random

Distribution type

First, define a type of the distribution, as a subtype of a corresponding distribution type in the Distributions.jl package.

struct CustomUniform <: ContinuousUnivariateDistribution end

Sampling and log-probability evaluation

Second, implement the rand and logpdf functions for your new distribution, which will be used to run the model.

# sample in [0, 1]
Base.rand(rng::AbstractRNG, ::CustomUniform) = rand(rng)

# p(x) = 1 → log[p(x)] = 0
Distributions.logpdf(::CustomUniform, x::Real) = zero(x)

Bijectors

Once you have defined the above, you should be able to use your distribution in a Turing model and sample with Prior().

@model function demo()
    x ~ CustomUniform()
end

mean(sample(demo(), Prior(), 100))
Sampling   0%|                                          |  ETA: N/A
Sampling   1%|▍                                         |  ETA: 0:02:26
Sampling   2%|▉                                         |  ETA: 0:01:14
Sampling   3%|█▎                                        |  ETA: 0:00:49
Sampling   4%|█▋                                        |  ETA: 0:00:36
Sampling   5%|██▏                                       |  ETA: 0:00:29
Sampling   6%|██▌                                       |  ETA: 0:00:24
Sampling   7%|███                                       |  ETA: 0:00:20
Sampling   8%|███▍                                      |  ETA: 0:00:17
Sampling   9%|███▊                                      |  ETA: 0:00:15
Sampling  10%|████▎                                     |  ETA: 0:00:14
Sampling  11%|████▋                                     |  ETA: 0:00:12
Sampling  12%|█████                                     |  ETA: 0:00:11
Sampling  13%|█████▌                                    |  ETA: 0:00:10
Sampling  14%|█████▉                                    |  ETA: 0:00:09
Sampling  15%|██████▎                                   |  ETA: 0:00:09
Sampling  16%|██████▊                                   |  ETA: 0:00:08
Sampling  17%|███████▏                                  |  ETA: 0:00:07
Sampling  18%|███████▌                                  |  ETA: 0:00:07
Sampling  19%|████████                                  |  ETA: 0:00:06
Sampling  20%|████████▍                                 |  ETA: 0:00:06
Sampling  21%|████████▉                                 |  ETA: 0:00:06
Sampling  22%|█████████▎                                |  ETA: 0:00:05
Sampling  23%|█████████▋                                |  ETA: 0:00:05
Sampling  24%|██████████▏                               |  ETA: 0:00:05
Sampling  25%|██████████▌                               |  ETA: 0:00:05
Sampling  26%|██████████▉                               |  ETA: 0:00:04
Sampling  27%|███████████▍                              |  ETA: 0:00:04
Sampling  28%|███████████▊                              |  ETA: 0:00:04
Sampling  29%|████████████▏                             |  ETA: 0:00:04
Sampling  30%|████████████▋                             |  ETA: 0:00:04
Sampling  31%|█████████████                             |  ETA: 0:00:03
Sampling  32%|█████████████▌                            |  ETA: 0:00:03
Sampling  33%|█████████████▉                            |  ETA: 0:00:03
Sampling  34%|██████████████▎                           |  ETA: 0:00:03
Sampling  35%|██████████████▊                           |  ETA: 0:00:03
Sampling  36%|███████████████▏                          |  ETA: 0:00:03
Sampling  37%|███████████████▌                          |  ETA: 0:00:03
Sampling  38%|████████████████                          |  ETA: 0:00:02
Sampling  39%|████████████████▍                         |  ETA: 0:00:02
Sampling  40%|████████████████▊                         |  ETA: 0:00:02
Sampling  41%|█████████████████▎                        |  ETA: 0:00:02
Sampling  42%|█████████████████▋                        |  ETA: 0:00:02
Sampling  43%|██████████████████                        |  ETA: 0:00:02
Sampling  44%|██████████████████▌                       |  ETA: 0:00:02
Sampling  45%|██████████████████▉                       |  ETA: 0:00:02
Sampling  46%|███████████████████▍                      |  ETA: 0:00:02
Sampling  47%|███████████████████▊                      |  ETA: 0:00:02
Sampling  48%|████████████████████▏                     |  ETA: 0:00:02
Sampling  49%|████████████████████▋                     |  ETA: 0:00:02
Sampling  50%|█████████████████████                     |  ETA: 0:00:02
Sampling  51%|█████████████████████▍                    |  ETA: 0:00:01
Sampling  52%|█████████████████████▉                    |  ETA: 0:00:01
Sampling  53%|██████████████████████▎                   |  ETA: 0:00:01
Sampling  54%|██████████████████████▋                   |  ETA: 0:00:01
Sampling  55%|███████████████████████▏                  |  ETA: 0:00:01
Sampling  56%|███████████████████████▌                  |  ETA: 0:00:01
Sampling  57%|████████████████████████                  |  ETA: 0:00:01
Sampling  58%|████████████████████████▍                 |  ETA: 0:00:01
Sampling  59%|████████████████████████▊                 |  ETA: 0:00:01
Sampling  60%|█████████████████████████▎                |  ETA: 0:00:01
Sampling  61%|█████████████████████████▋                |  ETA: 0:00:01
Sampling  62%|██████████████████████████                |  ETA: 0:00:01
Sampling  63%|██████████████████████████▌               |  ETA: 0:00:01
Sampling  64%|██████████████████████████▉               |  ETA: 0:00:01
Sampling  65%|███████████████████████████▎              |  ETA: 0:00:01
Sampling  66%|███████████████████████████▊              |  ETA: 0:00:01
Sampling  67%|████████████████████████████▏             |  ETA: 0:00:01
Sampling  68%|████████████████████████████▌             |  ETA: 0:00:01
Sampling  69%|█████████████████████████████             |  ETA: 0:00:01
Sampling  70%|█████████████████████████████▍            |  ETA: 0:00:01
Sampling  71%|█████████████████████████████▉            |  ETA: 0:00:01
Sampling  72%|██████████████████████████████▎           |  ETA: 0:00:01
Sampling  73%|██████████████████████████████▋           |  ETA: 0:00:01
Sampling  74%|███████████████████████████████▏          |  ETA: 0:00:01
Sampling  75%|███████████████████████████████▌          |  ETA: 0:00:01
Sampling  76%|███████████████████████████████▉          |  ETA: 0:00:00
Sampling  77%|████████████████████████████████▍         |  ETA: 0:00:00
Sampling  78%|████████████████████████████████▊         |  ETA: 0:00:00
Sampling  79%|█████████████████████████████████▏        |  ETA: 0:00:00
Sampling  80%|█████████████████████████████████▋        |  ETA: 0:00:00
Sampling  81%|██████████████████████████████████        |  ETA: 0:00:00
Sampling  82%|██████████████████████████████████▌       |  ETA: 0:00:00
Sampling  83%|██████████████████████████████████▉       |  ETA: 0:00:00
Sampling  84%|███████████████████████████████████▎      |  ETA: 0:00:00
Sampling  85%|███████████████████████████████████▊      |  ETA: 0:00:00
Sampling  86%|████████████████████████████████████▏     |  ETA: 0:00:00
Sampling  87%|████████████████████████████████████▌     |  ETA: 0:00:00
Sampling  88%|█████████████████████████████████████     |  ETA: 0:00:00
Sampling  89%|█████████████████████████████████████▍    |  ETA: 0:00:00
Sampling  90%|█████████████████████████████████████▊    |  ETA: 0:00:00
Sampling  91%|██████████████████████████████████████▎   |  ETA: 0:00:00
Sampling  92%|██████████████████████████████████████▋   |  ETA: 0:00:00
Sampling  93%|███████████████████████████████████████   |  ETA: 0:00:00
Sampling  94%|███████████████████████████████████████▌  |  ETA: 0:00:00
Sampling  95%|███████████████████████████████████████▉  |  ETA: 0:00:00
Sampling  96%|████████████████████████████████████████▍ |  ETA: 0:00:00
Sampling  97%|████████████████████████████████████████▊ |  ETA: 0:00:00
Sampling  98%|█████████████████████████████████████████▏|  ETA: 0:00:00
Sampling  99%|█████████████████████████████████████████▋|  ETA: 0:00:00
Sampling 100%|██████████████████████████████████████████| Time: 0:00:01
Sampling 100%|██████████████████████████████████████████| Time: 0:00:03
Mean

  parameters      mean 
      Symbol   Float64 

           x    0.4997

However, to make this work with other samplers (and in particular HMC and NUTS), we also have to define the corresponding bijectors for our distribution. This is because HMC and NUTS operate in an unconstrained space.

Turing v0.43 onwards uses the Bijectors.VectorBijectors interface, which is documented here.

The most important functions to define are Bijectors.VectorBijectors.from_linked_vec and Bijectors.VectorBijectors.to_linked_vec, which define how to transform from the constrained space to the unconstrained space and back, respectively. On top of that, you also need to define ChangesOfVariables.with_logabsdet_jacobian for the resulting function. (Bijectors reexports that function for convenience.)

Specifically, to_linked_vec maps from samples (i.e. floats in [0, 1]) to a vector where every element is independent and unconstrained; and from_linked_vec is the inverse. Note that both these functions take the distribution as the argument, and return the transformation function! Then with_logabsdet_jacobian takes that function and its argument, and returns a tuple of the transformed value and the log-absolute-determinant of the Jacobian of the transformation.

NoteJacobians

If you are not familiar with Jacobians, this is explained in more detail at Variable Transformations.

This API is most easily explained by example. The forward transform (to_linked_vec) can be accomplished with a logit:

import Bijectors
using StatsFuns: logit, logistic, log1pexp

function veclogit(x::Real)
    return [logit(x)]
end

Bijectors.VectorBijectors.to_linked_vec(::CustomUniform) = veclogit

Bijectors.with_logabsdet_jacobian(::typeof(veclogit), x) = begin
    logit_x = logit(x)
    return [logit_x], -logit_x
end

Often it can be easier to create and dispatch on a callable struct; we’ll demonstrate that for the inverse transform.

import Bijectors

struct OnlyLogistic end
(::OnlyLogistic)(y::AbstractVector{<:Real}) = 1/(1 + exp(-y[1]))

Bijectors.VectorBijectors.from_linked_vec(::CustomUniform) = OnlyLogistic()

function Bijectors.with_logabsdet_jacobian(::OnlyLogistic, y::AbstractVector{<:Real})
    yi = y[]
    res = logistic(yi)
    logjac = yi - (2 * log1pexp(yi))
    return res, logjac
end

Once you have defined these, you should be able to sample with HMC and NUTS as well:

sample(demo(), NUTS(), 100)
Sampling   0%|                                          |  ETA: N/A
Info: Found initial step size
  ϵ = 3.25
Sampling   1%|▎                                         |  ETA: 0:04:47
Sampling   1%|▌                                         |  ETA: 0:03:13
Sampling   2%|▉                                         |  ETA: 0:02:11
Sampling   3%|█▏                                        |  ETA: 0:01:38
Sampling   3%|█▍                                        |  ETA: 0:01:17
Sampling   4%|█▋                                        |  ETA: 0:01:04
Sampling   5%|██                                        |  ETA: 0:00:55
Sampling   5%|██▎                                       |  ETA: 0:00:47
Sampling   6%|██▌                                       |  ETA: 0:00:42
Sampling   7%|██▊                                       |  ETA: 0:00:37
Sampling   7%|███▏                                      |  ETA: 0:00:34
Sampling   8%|███▍                                      |  ETA: 0:00:31
Sampling   9%|███▋                                      |  ETA: 0:00:28
Sampling   9%|███▉                                      |  ETA: 0:00:26
Sampling  10%|████▎                                     |  ETA: 0:00:24
Sampling  11%|████▌                                     |  ETA: 0:00:22
Sampling  11%|████▊                                     |  ETA: 0:00:21
Sampling  12%|█████                                     |  ETA: 0:00:20
Sampling  13%|█████▍                                    |  ETA: 0:00:18
Sampling  13%|█████▋                                    |  ETA: 0:00:17
Sampling  14%|█████▉                                    |  ETA: 0:00:16
Sampling  15%|██████▏                                   |  ETA: 0:00:16
Sampling  15%|██████▌                                   |  ETA: 0:00:15
Sampling  16%|██████▊                                   |  ETA: 0:00:14
Sampling  17%|███████                                   |  ETA: 0:00:13
Sampling  17%|███████▎                                  |  ETA: 0:00:13
Sampling  18%|███████▌                                  |  ETA: 0:00:12
Sampling  19%|███████▉                                  |  ETA: 0:00:12
Sampling  19%|████████▏                                 |  ETA: 0:00:11
Sampling  20%|████████▍                                 |  ETA: 0:00:11
Sampling  21%|████████▋                                 |  ETA: 0:00:10
Sampling  21%|█████████                                 |  ETA: 0:00:10
Sampling  22%|█████████▎                                |  ETA: 0:00:09
Sampling  23%|█████████▌                                |  ETA: 0:00:09
Sampling  23%|█████████▊                                |  ETA: 0:00:09
Sampling  24%|██████████▏                               |  ETA: 0:00:08
Sampling  25%|██████████▍                               |  ETA: 0:00:08
Sampling  25%|██████████▋                               |  ETA: 0:00:08
Sampling  26%|██████████▉                               |  ETA: 0:00:08
Sampling  27%|███████████▎                              |  ETA: 0:00:07
Sampling  27%|███████████▌                              |  ETA: 0:00:07
Sampling  28%|███████████▊                              |  ETA: 0:00:07
Sampling  29%|████████████                              |  ETA: 0:00:07
Sampling  29%|████████████▍                             |  ETA: 0:00:06
Sampling  30%|████████████▋                             |  ETA: 0:00:06
Sampling  31%|████████████▉                             |  ETA: 0:00:06
Sampling  31%|█████████████▏                            |  ETA: 0:00:06
Sampling  32%|█████████████▌                            |  ETA: 0:00:06
Sampling  33%|█████████████▊                            |  ETA: 0:00:06
Sampling  33%|██████████████                            |  ETA: 0:00:05
Sampling  34%|██████████████▎                           |  ETA: 0:00:05
Sampling  35%|██████████████▌                           |  ETA: 0:00:05
Sampling  35%|██████████████▉                           |  ETA: 0:00:05
Sampling  36%|███████████████▏                          |  ETA: 0:00:05
Sampling  37%|███████████████▍                          |  ETA: 0:00:05
Sampling  37%|███████████████▋                          |  ETA: 0:00:05
Sampling  38%|████████████████                          |  ETA: 0:00:05
Sampling  39%|████████████████▎                         |  ETA: 0:00:04
Sampling  39%|████████████████▌                         |  ETA: 0:00:04
Sampling  40%|████████████████▊                         |  ETA: 0:00:04
Sampling  41%|█████████████████▏                        |  ETA: 0:00:04
Sampling  41%|█████████████████▍                        |  ETA: 0:00:04
Sampling  42%|█████████████████▋                        |  ETA: 0:00:04
Sampling  43%|█████████████████▉                        |  ETA: 0:00:04
Sampling  43%|██████████████████▎                       |  ETA: 0:00:04
Sampling  44%|██████████████████▌                       |  ETA: 0:00:04
Sampling  45%|██████████████████▊                       |  ETA: 0:00:03
Sampling  45%|███████████████████                       |  ETA: 0:00:03
Sampling  46%|███████████████████▍                      |  ETA: 0:00:03
Sampling  47%|███████████████████▋                      |  ETA: 0:00:03
Sampling  47%|███████████████████▉                      |  ETA: 0:00:03
Sampling  48%|████████████████████▏                     |  ETA: 0:00:03
Sampling  49%|████████████████████▌                     |  ETA: 0:00:03
Sampling  49%|████████████████████▊                     |  ETA: 0:00:03
Sampling  50%|█████████████████████                     |  ETA: 0:00:03
Sampling  51%|█████████████████████▎                    |  ETA: 0:00:03
Sampling  51%|█████████████████████▌                    |  ETA: 0:00:03
Sampling  52%|█████████████████████▉                    |  ETA: 0:00:03
Sampling  53%|██████████████████████▏                   |  ETA: 0:00:02
Sampling  53%|██████████████████████▍                   |  ETA: 0:00:02
Sampling  54%|██████████████████████▋                   |  ETA: 0:00:02
Sampling  55%|███████████████████████                   |  ETA: 0:00:02
Sampling  55%|███████████████████████▎                  |  ETA: 0:00:02
Sampling  56%|███████████████████████▌                  |  ETA: 0:00:02
Sampling  57%|███████████████████████▊                  |  ETA: 0:00:02
Sampling  57%|████████████████████████▏                 |  ETA: 0:00:02
Sampling  58%|████████████████████████▍                 |  ETA: 0:00:02
Sampling  59%|████████████████████████▋                 |  ETA: 0:00:02
Sampling  59%|████████████████████████▉                 |  ETA: 0:00:02
Sampling  60%|█████████████████████████▎                |  ETA: 0:00:02
Sampling  61%|█████████████████████████▌                |  ETA: 0:00:02
Sampling  61%|█████████████████████████▊                |  ETA: 0:00:02
Sampling  62%|██████████████████████████                |  ETA: 0:00:02
Sampling  63%|██████████████████████████▍               |  ETA: 0:00:02
Sampling  63%|██████████████████████████▋               |  ETA: 0:00:02
Sampling  64%|██████████████████████████▉               |  ETA: 0:00:02
Sampling  65%|███████████████████████████▏              |  ETA: 0:00:02
Sampling  65%|███████████████████████████▌              |  ETA: 0:00:01
Sampling  66%|███████████████████████████▊              |  ETA: 0:00:01
Sampling  67%|████████████████████████████              |  ETA: 0:00:01
Sampling  67%|████████████████████████████▎             |  ETA: 0:00:01
Sampling  68%|████████████████████████████▌             |  ETA: 0:00:01
Sampling  69%|████████████████████████████▉             |  ETA: 0:00:01
Sampling  69%|█████████████████████████████▏            |  ETA: 0:00:01
Sampling  70%|█████████████████████████████▍            |  ETA: 0:00:01
Sampling  71%|█████████████████████████████▋            |  ETA: 0:00:01
Sampling  71%|██████████████████████████████            |  ETA: 0:00:01
Sampling  72%|██████████████████████████████▎           |  ETA: 0:00:01
Sampling  73%|██████████████████████████████▌           |  ETA: 0:00:01
Sampling  73%|██████████████████████████████▊           |  ETA: 0:00:01
Sampling  74%|███████████████████████████████▏          |  ETA: 0:00:01
Sampling  75%|███████████████████████████████▍          |  ETA: 0:00:01
Sampling  75%|███████████████████████████████▋          |  ETA: 0:00:01
Sampling  76%|███████████████████████████████▉          |  ETA: 0:00:01
Sampling  77%|████████████████████████████████▎         |  ETA: 0:00:01
Sampling  77%|████████████████████████████████▌         |  ETA: 0:00:01
Sampling  78%|████████████████████████████████▊         |  ETA: 0:00:01
Sampling  79%|█████████████████████████████████         |  ETA: 0:00:01
Sampling  79%|█████████████████████████████████▍        |  ETA: 0:00:01
Sampling  80%|█████████████████████████████████▋        |  ETA: 0:00:01
Sampling  81%|█████████████████████████████████▉        |  ETA: 0:00:01
Sampling  81%|██████████████████████████████████▏       |  ETA: 0:00:01
Sampling  82%|██████████████████████████████████▌       |  ETA: 0:00:01
Sampling  83%|██████████████████████████████████▊       |  ETA: 0:00:01
Sampling  83%|███████████████████████████████████       |  ETA: 0:00:01
Sampling  84%|███████████████████████████████████▎      |  ETA: 0:00:01
Sampling  85%|███████████████████████████████████▌      |  ETA: 0:00:01
Sampling  85%|███████████████████████████████████▉      |  ETA: 0:00:00
Sampling  86%|████████████████████████████████████▏     |  ETA: 0:00:00
Sampling  87%|████████████████████████████████████▍     |  ETA: 0:00:00
Sampling  87%|████████████████████████████████████▋     |  ETA: 0:00:00
Sampling  88%|█████████████████████████████████████     |  ETA: 0:00:00
Sampling  89%|█████████████████████████████████████▎    |  ETA: 0:00:00
Sampling  89%|█████████████████████████████████████▌    |  ETA: 0:00:00
Sampling  90%|█████████████████████████████████████▊    |  ETA: 0:00:00
Sampling  91%|██████████████████████████████████████▏   |  ETA: 0:00:00
Sampling  91%|██████████████████████████████████████▍   |  ETA: 0:00:00
Sampling  92%|██████████████████████████████████████▋   |  ETA: 0:00:00
Sampling  93%|██████████████████████████████████████▉   |  ETA: 0:00:00
Sampling  93%|███████████████████████████████████████▎  |  ETA: 0:00:00
Sampling  94%|███████████████████████████████████████▌  |  ETA: 0:00:00
Sampling  95%|███████████████████████████████████████▊  |  ETA: 0:00:00
Sampling  95%|████████████████████████████████████████  |  ETA: 0:00:00
Sampling  96%|████████████████████████████████████████▍ |  ETA: 0:00:00
Sampling  97%|████████████████████████████████████████▋ |  ETA: 0:00:00
Sampling  97%|████████████████████████████████████████▉ |  ETA: 0:00:00
Sampling  98%|█████████████████████████████████████████▏|  ETA: 0:00:00
Sampling  99%|█████████████████████████████████████████▌|  ETA: 0:00:00
Sampling  99%|█████████████████████████████████████████▊|  ETA: 0:00:00
Sampling 100%|██████████████████████████████████████████| Time: 0:00:02
Sampling 100%|██████████████████████████████████████████| Time: 0:00:03
Chains MCMC chain (100×15×1 Array{Float64, 3}):

Iterations        = 51:1:150
Number of chains  = 1
Samples per chain = 100
Wall duration     = 2.84 seconds
Compute duration  = 2.84 seconds
parameters        = x
internals         = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, logprior, loglikelihood, logjoint

Use `describe(chains)` for summary statistics and quantiles.

There is a lot of functionality that is already in Bijectors.jl which you could reuse. In fact, for bounded univariate distributions Bijectors already defines default transforms, so in this specific instance the above code is not strictly necessary.

Back to top