using Distributions, Turing, RandomCustom Distributions
Turing.jl supports the use of distributions from the Distributions.jl package. By extension, it also supports the use of customised distributions by defining them as subtypes of the Distribution type in the Distributions.jl package, as well as corresponding functions.
This page shows a workflow of how to define a customised distribution, using our own implementation of a simple Uniform distribution as a simple example.
Distribution type
First, define a type of the distribution, as a subtype of a corresponding distribution type in the Distributions.jl package.
struct CustomUniform <: ContinuousUnivariateDistribution endSampling and log-probability evaluation
Second, implement the rand and logpdf functions for your new distribution, which will be used to run the model.
# sample in [0, 1]
Base.rand(rng::AbstractRNG, ::CustomUniform) = rand(rng)
# p(x) = 1 → log[p(x)] = 0
Distributions.logpdf(::CustomUniform, x::Real) = zero(x)Bijectors
Once you have defined the above, you should be able to use your distribution in a Turing model and sample with Prior().
@model function demo()
x ~ CustomUniform()
end
mean(sample(demo(), Prior(), 100))Sampling 0%| | ETA: N/A Sampling 1%|▍ | ETA: 0:02:26 Sampling 2%|▉ | ETA: 0:01:14 Sampling 3%|█▎ | ETA: 0:00:49 Sampling 4%|█▋ | ETA: 0:00:36 Sampling 5%|██▏ | ETA: 0:00:29 Sampling 6%|██▌ | ETA: 0:00:24 Sampling 7%|███ | ETA: 0:00:20 Sampling 8%|███▍ | ETA: 0:00:17 Sampling 9%|███▊ | ETA: 0:00:15 Sampling 10%|████▎ | ETA: 0:00:14 Sampling 11%|████▋ | ETA: 0:00:12 Sampling 12%|█████ | ETA: 0:00:11 Sampling 13%|█████▌ | ETA: 0:00:10 Sampling 14%|█████▉ | ETA: 0:00:09 Sampling 15%|██████▎ | ETA: 0:00:09 Sampling 16%|██████▊ | ETA: 0:00:08 Sampling 17%|███████▏ | ETA: 0:00:07 Sampling 18%|███████▌ | ETA: 0:00:07 Sampling 19%|████████ | ETA: 0:00:06 Sampling 20%|████████▍ | ETA: 0:00:06 Sampling 21%|████████▉ | ETA: 0:00:06 Sampling 22%|█████████▎ | ETA: 0:00:05 Sampling 23%|█████████▋ | ETA: 0:00:05 Sampling 24%|██████████▏ | ETA: 0:00:05 Sampling 25%|██████████▌ | ETA: 0:00:05 Sampling 26%|██████████▉ | ETA: 0:00:04 Sampling 27%|███████████▍ | ETA: 0:00:04 Sampling 28%|███████████▊ | ETA: 0:00:04 Sampling 29%|████████████▏ | ETA: 0:00:04 Sampling 30%|████████████▋ | ETA: 0:00:04 Sampling 31%|█████████████ | ETA: 0:00:03 Sampling 32%|█████████████▌ | ETA: 0:00:03 Sampling 33%|█████████████▉ | ETA: 0:00:03 Sampling 34%|██████████████▎ | ETA: 0:00:03 Sampling 35%|██████████████▊ | ETA: 0:00:03 Sampling 36%|███████████████▏ | ETA: 0:00:03 Sampling 37%|███████████████▌ | ETA: 0:00:03 Sampling 38%|████████████████ | ETA: 0:00:02 Sampling 39%|████████████████▍ | ETA: 0:00:02 Sampling 40%|████████████████▊ | ETA: 0:00:02 Sampling 41%|█████████████████▎ | ETA: 0:00:02 Sampling 42%|█████████████████▋ | ETA: 0:00:02 Sampling 43%|██████████████████ | ETA: 0:00:02 Sampling 44%|██████████████████▌ | ETA: 0:00:02 Sampling 45%|██████████████████▉ | ETA: 0:00:02 Sampling 46%|███████████████████▍ | ETA: 0:00:02 Sampling 47%|███████████████████▊ | ETA: 0:00:02 Sampling 48%|████████████████████▏ | ETA: 0:00:02 Sampling 49%|████████████████████▋ | ETA: 0:00:02 Sampling 50%|█████████████████████ | ETA: 0:00:02 Sampling 51%|█████████████████████▍ | ETA: 0:00:01 Sampling 52%|█████████████████████▉ | ETA: 0:00:01 Sampling 53%|██████████████████████▎ | ETA: 0:00:01 Sampling 54%|██████████████████████▋ | ETA: 0:00:01 Sampling 55%|███████████████████████▏ | ETA: 0:00:01 Sampling 56%|███████████████████████▌ | ETA: 0:00:01 Sampling 57%|████████████████████████ | ETA: 0:00:01 Sampling 58%|████████████████████████▍ | ETA: 0:00:01 Sampling 59%|████████████████████████▊ | ETA: 0:00:01 Sampling 60%|█████████████████████████▎ | ETA: 0:00:01 Sampling 61%|█████████████████████████▋ | ETA: 0:00:01 Sampling 62%|██████████████████████████ | ETA: 0:00:01 Sampling 63%|██████████████████████████▌ | ETA: 0:00:01 Sampling 64%|██████████████████████████▉ | ETA: 0:00:01 Sampling 65%|███████████████████████████▎ | ETA: 0:00:01 Sampling 66%|███████████████████████████▊ | ETA: 0:00:01 Sampling 67%|████████████████████████████▏ | ETA: 0:00:01 Sampling 68%|████████████████████████████▌ | ETA: 0:00:01 Sampling 69%|█████████████████████████████ | ETA: 0:00:01 Sampling 70%|█████████████████████████████▍ | ETA: 0:00:01 Sampling 71%|█████████████████████████████▉ | ETA: 0:00:01 Sampling 72%|██████████████████████████████▎ | ETA: 0:00:01 Sampling 73%|██████████████████████████████▋ | ETA: 0:00:01 Sampling 74%|███████████████████████████████▏ | ETA: 0:00:01 Sampling 75%|███████████████████████████████▌ | ETA: 0:00:01 Sampling 76%|███████████████████████████████▉ | ETA: 0:00:00 Sampling 77%|████████████████████████████████▍ | ETA: 0:00:00 Sampling 78%|████████████████████████████████▊ | ETA: 0:00:00 Sampling 79%|█████████████████████████████████▏ | ETA: 0:00:00 Sampling 80%|█████████████████████████████████▋ | ETA: 0:00:00 Sampling 81%|██████████████████████████████████ | ETA: 0:00:00 Sampling 82%|██████████████████████████████████▌ | ETA: 0:00:00 Sampling 83%|██████████████████████████████████▉ | ETA: 0:00:00 Sampling 84%|███████████████████████████████████▎ | ETA: 0:00:00 Sampling 85%|███████████████████████████████████▊ | ETA: 0:00:00 Sampling 86%|████████████████████████████████████▏ | ETA: 0:00:00 Sampling 87%|████████████████████████████████████▌ | ETA: 0:00:00 Sampling 88%|█████████████████████████████████████ | ETA: 0:00:00 Sampling 89%|█████████████████████████████████████▍ | ETA: 0:00:00 Sampling 90%|█████████████████████████████████████▊ | ETA: 0:00:00 Sampling 91%|██████████████████████████████████████▎ | ETA: 0:00:00 Sampling 92%|██████████████████████████████████████▋ | ETA: 0:00:00 Sampling 93%|███████████████████████████████████████ | ETA: 0:00:00 Sampling 94%|███████████████████████████████████████▌ | ETA: 0:00:00 Sampling 95%|███████████████████████████████████████▉ | ETA: 0:00:00 Sampling 96%|████████████████████████████████████████▍ | ETA: 0:00:00 Sampling 97%|████████████████████████████████████████▊ | ETA: 0:00:00 Sampling 98%|█████████████████████████████████████████▏| ETA: 0:00:00 Sampling 99%|█████████████████████████████████████████▋| ETA: 0:00:00 Sampling 100%|██████████████████████████████████████████| Time: 0:00:01 Sampling 100%|██████████████████████████████████████████| Time: 0:00:03
Mean parameters mean Symbol Float64 x 0.4997
However, to make this work with other samplers (and in particular HMC and NUTS), we also have to define the corresponding bijectors for our distribution. This is because HMC and NUTS operate in an unconstrained space.
Turing v0.43 onwards uses the Bijectors.VectorBijectors interface, which is documented here.
The most important functions to define are Bijectors.VectorBijectors.from_linked_vec and Bijectors.VectorBijectors.to_linked_vec, which define how to transform from the constrained space to the unconstrained space and back, respectively. On top of that, you also need to define ChangesOfVariables.with_logabsdet_jacobian for the resulting function. (Bijectors reexports that function for convenience.)
Specifically, to_linked_vec maps from samples (i.e. floats in [0, 1]) to a vector where every element is independent and unconstrained; and from_linked_vec is the inverse. Note that both these functions take the distribution as the argument, and return the transformation function! Then with_logabsdet_jacobian takes that function and its argument, and returns a tuple of the transformed value and the log-absolute-determinant of the Jacobian of the transformation.
If you are not familiar with Jacobians, this is explained in more detail at Variable Transformations.
This API is most easily explained by example. The forward transform (to_linked_vec) can be accomplished with a logit:
import Bijectors
using StatsFuns: logit, logistic, log1pexp
function veclogit(x::Real)
return [logit(x)]
end
Bijectors.VectorBijectors.to_linked_vec(::CustomUniform) = veclogit
Bijectors.with_logabsdet_jacobian(::typeof(veclogit), x) = begin
logit_x = logit(x)
return [logit_x], -logit_x
endOften it can be easier to create and dispatch on a callable struct; we’ll demonstrate that for the inverse transform.
import Bijectors
struct OnlyLogistic end
(::OnlyLogistic)(y::AbstractVector{<:Real}) = 1/(1 + exp(-y[1]))
Bijectors.VectorBijectors.from_linked_vec(::CustomUniform) = OnlyLogistic()
function Bijectors.with_logabsdet_jacobian(::OnlyLogistic, y::AbstractVector{<:Real})
yi = y[]
res = logistic(yi)
logjac = yi - (2 * log1pexp(yi))
return res, logjac
endOnce you have defined these, you should be able to sample with HMC and NUTS as well:
sample(demo(), NUTS(), 100)Sampling 0%| | ETA: N/A ┌ Info: Found initial step size └ ϵ = 3.25 Sampling 1%|▎ | ETA: 0:04:47 Sampling 1%|▌ | ETA: 0:03:13 Sampling 2%|▉ | ETA: 0:02:11 Sampling 3%|█▏ | ETA: 0:01:38 Sampling 3%|█▍ | ETA: 0:01:17 Sampling 4%|█▋ | ETA: 0:01:04 Sampling 5%|██ | ETA: 0:00:55 Sampling 5%|██▎ | ETA: 0:00:47 Sampling 6%|██▌ | ETA: 0:00:42 Sampling 7%|██▊ | ETA: 0:00:37 Sampling 7%|███▏ | ETA: 0:00:34 Sampling 8%|███▍ | ETA: 0:00:31 Sampling 9%|███▋ | ETA: 0:00:28 Sampling 9%|███▉ | ETA: 0:00:26 Sampling 10%|████▎ | ETA: 0:00:24 Sampling 11%|████▌ | ETA: 0:00:22 Sampling 11%|████▊ | ETA: 0:00:21 Sampling 12%|█████ | ETA: 0:00:20 Sampling 13%|█████▍ | ETA: 0:00:18 Sampling 13%|█████▋ | ETA: 0:00:17 Sampling 14%|█████▉ | ETA: 0:00:16 Sampling 15%|██████▏ | ETA: 0:00:16 Sampling 15%|██████▌ | ETA: 0:00:15 Sampling 16%|██████▊ | ETA: 0:00:14 Sampling 17%|███████ | ETA: 0:00:13 Sampling 17%|███████▎ | ETA: 0:00:13 Sampling 18%|███████▌ | ETA: 0:00:12 Sampling 19%|███████▉ | ETA: 0:00:12 Sampling 19%|████████▏ | ETA: 0:00:11 Sampling 20%|████████▍ | ETA: 0:00:11 Sampling 21%|████████▋ | ETA: 0:00:10 Sampling 21%|█████████ | ETA: 0:00:10 Sampling 22%|█████████▎ | ETA: 0:00:09 Sampling 23%|█████████▌ | ETA: 0:00:09 Sampling 23%|█████████▊ | ETA: 0:00:09 Sampling 24%|██████████▏ | ETA: 0:00:08 Sampling 25%|██████████▍ | ETA: 0:00:08 Sampling 25%|██████████▋ | ETA: 0:00:08 Sampling 26%|██████████▉ | ETA: 0:00:08 Sampling 27%|███████████▎ | ETA: 0:00:07 Sampling 27%|███████████▌ | ETA: 0:00:07 Sampling 28%|███████████▊ | ETA: 0:00:07 Sampling 29%|████████████ | ETA: 0:00:07 Sampling 29%|████████████▍ | ETA: 0:00:06 Sampling 30%|████████████▋ | ETA: 0:00:06 Sampling 31%|████████████▉ | ETA: 0:00:06 Sampling 31%|█████████████▏ | ETA: 0:00:06 Sampling 32%|█████████████▌ | ETA: 0:00:06 Sampling 33%|█████████████▊ | ETA: 0:00:06 Sampling 33%|██████████████ | ETA: 0:00:05 Sampling 34%|██████████████▎ | ETA: 0:00:05 Sampling 35%|██████████████▌ | ETA: 0:00:05 Sampling 35%|██████████████▉ | ETA: 0:00:05 Sampling 36%|███████████████▏ | ETA: 0:00:05 Sampling 37%|███████████████▍ | ETA: 0:00:05 Sampling 37%|███████████████▋ | ETA: 0:00:05 Sampling 38%|████████████████ | ETA: 0:00:05 Sampling 39%|████████████████▎ | ETA: 0:00:04 Sampling 39%|████████████████▌ | ETA: 0:00:04 Sampling 40%|████████████████▊ | ETA: 0:00:04 Sampling 41%|█████████████████▏ | ETA: 0:00:04 Sampling 41%|█████████████████▍ | ETA: 0:00:04 Sampling 42%|█████████████████▋ | ETA: 0:00:04 Sampling 43%|█████████████████▉ | ETA: 0:00:04 Sampling 43%|██████████████████▎ | ETA: 0:00:04 Sampling 44%|██████████████████▌ | ETA: 0:00:04 Sampling 45%|██████████████████▊ | ETA: 0:00:03 Sampling 45%|███████████████████ | ETA: 0:00:03 Sampling 46%|███████████████████▍ | ETA: 0:00:03 Sampling 47%|███████████████████▋ | ETA: 0:00:03 Sampling 47%|███████████████████▉ | ETA: 0:00:03 Sampling 48%|████████████████████▏ | ETA: 0:00:03 Sampling 49%|████████████████████▌ | ETA: 0:00:03 Sampling 49%|████████████████████▊ | ETA: 0:00:03 Sampling 50%|█████████████████████ | ETA: 0:00:03 Sampling 51%|█████████████████████▎ | ETA: 0:00:03 Sampling 51%|█████████████████████▌ | ETA: 0:00:03 Sampling 52%|█████████████████████▉ | ETA: 0:00:03 Sampling 53%|██████████████████████▏ | ETA: 0:00:02 Sampling 53%|██████████████████████▍ | ETA: 0:00:02 Sampling 54%|██████████████████████▋ | ETA: 0:00:02 Sampling 55%|███████████████████████ | ETA: 0:00:02 Sampling 55%|███████████████████████▎ | ETA: 0:00:02 Sampling 56%|███████████████████████▌ | ETA: 0:00:02 Sampling 57%|███████████████████████▊ | ETA: 0:00:02 Sampling 57%|████████████████████████▏ | ETA: 0:00:02 Sampling 58%|████████████████████████▍ | ETA: 0:00:02 Sampling 59%|████████████████████████▋ | ETA: 0:00:02 Sampling 59%|████████████████████████▉ | ETA: 0:00:02 Sampling 60%|█████████████████████████▎ | ETA: 0:00:02 Sampling 61%|█████████████████████████▌ | ETA: 0:00:02 Sampling 61%|█████████████████████████▊ | ETA: 0:00:02 Sampling 62%|██████████████████████████ | ETA: 0:00:02 Sampling 63%|██████████████████████████▍ | ETA: 0:00:02 Sampling 63%|██████████████████████████▋ | ETA: 0:00:02 Sampling 64%|██████████████████████████▉ | ETA: 0:00:02 Sampling 65%|███████████████████████████▏ | ETA: 0:00:02 Sampling 65%|███████████████████████████▌ | ETA: 0:00:01 Sampling 66%|███████████████████████████▊ | ETA: 0:00:01 Sampling 67%|████████████████████████████ | ETA: 0:00:01 Sampling 67%|████████████████████████████▎ | ETA: 0:00:01 Sampling 68%|████████████████████████████▌ | ETA: 0:00:01 Sampling 69%|████████████████████████████▉ | ETA: 0:00:01 Sampling 69%|█████████████████████████████▏ | ETA: 0:00:01 Sampling 70%|█████████████████████████████▍ | ETA: 0:00:01 Sampling 71%|█████████████████████████████▋ | ETA: 0:00:01 Sampling 71%|██████████████████████████████ | ETA: 0:00:01 Sampling 72%|██████████████████████████████▎ | ETA: 0:00:01 Sampling 73%|██████████████████████████████▌ | ETA: 0:00:01 Sampling 73%|██████████████████████████████▊ | ETA: 0:00:01 Sampling 74%|███████████████████████████████▏ | ETA: 0:00:01 Sampling 75%|███████████████████████████████▍ | ETA: 0:00:01 Sampling 75%|███████████████████████████████▋ | ETA: 0:00:01 Sampling 76%|███████████████████████████████▉ | ETA: 0:00:01 Sampling 77%|████████████████████████████████▎ | ETA: 0:00:01 Sampling 77%|████████████████████████████████▌ | ETA: 0:00:01 Sampling 78%|████████████████████████████████▊ | ETA: 0:00:01 Sampling 79%|█████████████████████████████████ | ETA: 0:00:01 Sampling 79%|█████████████████████████████████▍ | ETA: 0:00:01 Sampling 80%|█████████████████████████████████▋ | ETA: 0:00:01 Sampling 81%|█████████████████████████████████▉ | ETA: 0:00:01 Sampling 81%|██████████████████████████████████▏ | ETA: 0:00:01 Sampling 82%|██████████████████████████████████▌ | ETA: 0:00:01 Sampling 83%|██████████████████████████████████▊ | ETA: 0:00:01 Sampling 83%|███████████████████████████████████ | ETA: 0:00:01 Sampling 84%|███████████████████████████████████▎ | ETA: 0:00:01 Sampling 85%|███████████████████████████████████▌ | ETA: 0:00:01 Sampling 85%|███████████████████████████████████▉ | ETA: 0:00:00 Sampling 86%|████████████████████████████████████▏ | ETA: 0:00:00 Sampling 87%|████████████████████████████████████▍ | ETA: 0:00:00 Sampling 87%|████████████████████████████████████▋ | ETA: 0:00:00 Sampling 88%|█████████████████████████████████████ | ETA: 0:00:00 Sampling 89%|█████████████████████████████████████▎ | ETA: 0:00:00 Sampling 89%|█████████████████████████████████████▌ | ETA: 0:00:00 Sampling 90%|█████████████████████████████████████▊ | ETA: 0:00:00 Sampling 91%|██████████████████████████████████████▏ | ETA: 0:00:00 Sampling 91%|██████████████████████████████████████▍ | ETA: 0:00:00 Sampling 92%|██████████████████████████████████████▋ | ETA: 0:00:00 Sampling 93%|██████████████████████████████████████▉ | ETA: 0:00:00 Sampling 93%|███████████████████████████████████████▎ | ETA: 0:00:00 Sampling 94%|███████████████████████████████████████▌ | ETA: 0:00:00 Sampling 95%|███████████████████████████████████████▊ | ETA: 0:00:00 Sampling 95%|████████████████████████████████████████ | ETA: 0:00:00 Sampling 96%|████████████████████████████████████████▍ | ETA: 0:00:00 Sampling 97%|████████████████████████████████████████▋ | ETA: 0:00:00 Sampling 97%|████████████████████████████████████████▉ | ETA: 0:00:00 Sampling 98%|█████████████████████████████████████████▏| ETA: 0:00:00 Sampling 99%|█████████████████████████████████████████▌| ETA: 0:00:00 Sampling 99%|█████████████████████████████████████████▊| ETA: 0:00:00 Sampling 100%|██████████████████████████████████████████| Time: 0:00:02 Sampling 100%|██████████████████████████████████████████| Time: 0:00:03
Chains MCMC chain (100×15×1 Array{Float64, 3}):
Iterations = 51:1:150
Number of chains = 1
Samples per chain = 100
Wall duration = 2.84 seconds
Compute duration = 2.84 seconds
parameters = x
internals = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, logprior, loglikelihood, logjoint
Use `describe(chains)` for summary statistics and quantiles.
There is a lot of functionality that is already in Bijectors.jl which you could reuse. In fact, for bounded univariate distributions Bijectors already defines default transforms, so in this specific instance the above code is not strictly necessary.