Querying Model Probabilities

The easiest way to manipulate and query Turing models is via the DynamicPPL probability interface.

Let’s use a simple model of normally-distributed data as an example.

using Turing
using DynamicPPL
using Random

@model function gdemo(n)
    μ ~ Normal(0, 1)
    x ~ MvNormal(fill(μ, n), I)
end
gdemo (generic function with 2 methods)

We generate some data using μ = 0:

Random.seed!(1776)
dataset = randn(100)
dataset[1:5]
5-element Vector{Float64}:
  0.8488780584442736
 -0.31936138249336765
 -1.3982098801744465
 -0.05198933163879332
 -1.1465116601038348

Conditioning and Deconditioning

Bayesian models can be transformed with two main operations: conditioning and deconditioning. Conditioning takes a variable and fixes its value to a known quantity, so that it is treated as observed rather than random; deconditioning reverses this. We condition a model by passing it, together with a collection of conditioned variables, to |, or its alias, condition:

# (equivalently)
# conditioned_model = condition(gdemo(length(dataset)), (x=dataset, μ=0))
conditioned_model = gdemo(length(dataset)) | (x=dataset, μ=0)
Model{typeof(gdemo), (:n,), (), (), Tuple{Int64}, Tuple{}, ConditionContext{@NamedTuple{x::Vector{Float64}, μ::Int64}, DefaultContext}}(gdemo, (n = 100,), NamedTuple(), ConditionContext((x = [0.8488780584442736, -0.31936138249336765, -1.3982098801744465, -0.05198933163879332, -1.1465116601038348, -0.6306168227545849, 0.6862766694322289, -0.5485073478947856, -0.17212004616875684, 1.2883226251958486, -0.13661316034377538, 2.4316115122026973, 0.2251319215717449, -0.5115708179083417, -0.7810712258995324, -1.0191704692490737, 1.1210038448250719, -1.6944509713762377, -0.27314823183454695, 0.25273963222687423, 1.3914215917992434, 0.7525340831125464, 0.847154387311101, -0.7130402796655171, 0.2983575202861233, -0.1785631526879386, 0.08659477535701691, -0.5167265137098563, 2.111309740316035, 0.3957655443124509, -0.0804390853521051, 1.255042471667049, -0.07882822403959532, 1.2261373761992618, 0.43953618247769816, -0.40640013183427787, -0.6868635949523503, 1.7380713294668497, 0.13685965156352295, 0.1485185624825999, -0.7798816720822024, 2.2595105995080846, -0.13609014938597142, 0.22785777205259913, -2.1005250433485725, 0.44205288222935385, -1.238456637875994, -2.3727125492433427, -0.24406624959402184, -0.04488042525902438, 0.27510026183444175, 0.42472846594528796, 1.0337924022589282, 0.9126364433535069, -0.9006583845907805, 0.8665471057463393, 1.4924737539852484, 1.2886591566091432, 1.037264411147446, 1.4731954133339449, -0.31874662373651885, 1.2255399151799211, -1.6642044048811695, -0.5717328092786154, -1.2700237196779645, 0.5748199649058684, 0.16467729820692942, -1.195290550625328, -0.37133526877621703, -0.3018979982049836, -2.0183406292097397, -0.9588803575112745, 0.7177183994733006, -1.0133440177662316, -1.0881357990941283, 1.0487446580734279, 2.627227367991459, -1.59963908284846, -0.3122512299247273, -1.0265333654194488, 0.5557085182114885, -0.3206725445321106, -1.4314746067673778, 1.5740113510560039, -0.6566477752702335, 0.31342313477927125, 0.33135361418686027, -1.0489180508346863, -0.2670759024309527, 0.4683952221006179, 0.04918061587657951, 1.239814741442417, 2.2239462179369296, 1.8507671783064434, 1.756319462015174, -0.6577450354719728, 2.2795431083561626, -0.492273906928334, 0.7045614632761499, 0.11260553216111485], μ = 0), DynamicPPL.DefaultContext()))

This operation can be reversed by applying decondition:

original_model = decondition(conditioned_model)
Model{typeof(gdemo), (:n,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}(gdemo, (n = 100,), NamedTuple(), DefaultContext())
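
As a quick check (a sketch, not part of the original listing), sampling from the deconditioned model now draws x as well as μ, since x is once again treated as random:

keys(rand(original_model))  # expected to contain both μ and x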

We can also decondition only some of the variables:

partially_conditioned = decondition(conditioned_model, :μ)
Model{typeof(gdemo), (:n,), (), (), Tuple{Int64}, Tuple{}, ConditionContext{@NamedTuple{x::Vector{Float64}}, DefaultContext}}(gdemo, (n = 100,), NamedTuple(), ConditionContext((x = [0.8488780584442736, -0.31936138249336765, -1.3982098801744465, -0.05198933163879332, -1.1465116601038348, -0.6306168227545849, 0.6862766694322289, -0.5485073478947856, -0.17212004616875684, 1.2883226251958486, -0.13661316034377538, 2.4316115122026973, 0.2251319215717449, -0.5115708179083417, -0.7810712258995324, -1.0191704692490737, 1.1210038448250719, -1.6944509713762377, -0.27314823183454695, 0.25273963222687423, 1.3914215917992434, 0.7525340831125464, 0.847154387311101, -0.7130402796655171, 0.2983575202861233, -0.1785631526879386, 0.08659477535701691, -0.5167265137098563, 2.111309740316035, 0.3957655443124509, -0.0804390853521051, 1.255042471667049, -0.07882822403959532, 1.2261373761992618, 0.43953618247769816, -0.40640013183427787, -0.6868635949523503, 1.7380713294668497, 0.13685965156352295, 0.1485185624825999, -0.7798816720822024, 2.2595105995080846, -0.13609014938597142, 0.22785777205259913, -2.1005250433485725, 0.44205288222935385, -1.238456637875994, -2.3727125492433427, -0.24406624959402184, -0.04488042525902438, 0.27510026183444175, 0.42472846594528796, 1.0337924022589282, 0.9126364433535069, -0.9006583845907805, 0.8665471057463393, 1.4924737539852484, 1.2886591566091432, 1.037264411147446, 1.4731954133339449, -0.31874662373651885, 1.2255399151799211, -1.6642044048811695, -0.5717328092786154, -1.2700237196779645, 0.5748199649058684, 0.16467729820692942, -1.195290550625328, -0.37133526877621703, -0.3018979982049836, -2.0183406292097397, -0.9588803575112745, 0.7177183994733006, -1.0133440177662316, -1.0881357990941283, 1.0487446580734279, 2.627227367991459, -1.59963908284846, -0.3122512299247273, -1.0265333654194488, 0.5557085182114885, -0.3206725445321106, -1.4314746067673778, 1.5740113510560039, -0.6566477752702335, 0.31342313477927125, 0.33135361418686027, -1.0489180508346863, -0.2670759024309527, 0.4683952221006179, 0.04918061587657951, 1.239814741442417, 2.2239462179369296, 1.8507671783064434, 1.756319462015174, -0.6577450354719728, 2.2795431083561626, -0.492273906928334, 0.7045614632761499, 0.11260553216111485],), DynamicPPL.DefaultContext()))

We can see which of the variables in a model have been conditioned with DynamicPPL.conditioned:

DynamicPPL.conditioned(partially_conditioned)
(x = [0.8488780584442736, -0.31936138249336765, -1.3982098801744465, -0.05198933163879332, -1.1465116601038348, -0.6306168227545849, 0.6862766694322289, -0.5485073478947856, -0.17212004616875684, 1.2883226251958486  …  0.04918061587657951, 1.239814741442417, 2.2239462179369296, 1.8507671783064434, 1.756319462015174, -0.6577450354719728, 2.2795431083561626, -0.492273906928334, 0.7045614632761499, 0.11260553216111485],)
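
For comparison (a minimal sketch reusing conditioned_model from above), the same query on the fully conditioned model reports both variables:

DynamicPPL.conditioned(conditioned_model)  # a NamedTuple containing both x and μ
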
Note

Sometimes it is helpful to define convenience functions for conditioning on some variable(s). For instance, in this example we might want to define a version of gdemo that conditions on some observations of x:

gdemo(x::AbstractVector{<:Real}) = gdemo(length(x)) | (; x)
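
For example (a hypothetical call, not evaluated here), this method conditions the model in one step:

gdemo(dataset)  # same as gdemo(length(dataset)) | (x=dataset,)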

For illustrative purposes, however, we do not use this function in the examples below.

Probabilities and Densities

We often want to calculate the (unnormalized) probability density of an event. This might be a prior, a likelihood, or a posterior (joint) density. DynamicPPL provides convenient functions for all of these. To begin, let's condition the gdemo model defined above on a dataset and draw a sample. The returned sample only contains μ, since the value of x has already been fixed:

model = gdemo(length(dataset)) | (x=dataset,)

Random.seed!(124)
sample = rand(model)
(μ = -0.6680014719649068,)

We can then calculate the joint probability of a set of samples (here drawn from the prior) with logjoint.

logjoint(model, sample)
-181.7247437162069
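
As a sanity check (a sketch that assumes the gdemo model above; not part of the original page), the same value can be reproduced by summing, by hand, the log densities of the two distributions appearing in the model:

using LinearAlgebra: I  # imported explicitly to keep this sketch self-contained

μ_val = sample.μ
manual = logpdf(Normal(0, 1), μ_val) + logpdf(MvNormal(fill(μ_val, length(dataset)), I), dataset)
manual ≈ logjoint(model, sample)  # expected to hold up to floating-point error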

For models with many variables, rand(model) can be prohibitively slow, since it returns a NamedTuple of samples from the prior distribution of the unconditioned variables. In this case we recommend working with samples of type DataStructures.OrderedDict (which Turing re-exports, so it can be used directly):

Random.seed!(124)
sample_dict = rand(OrderedDict, model)
OrderedDict{VarName, Any} with 1 entry:
  μ => -0.668001
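
The keys of this dictionary are VarNames, so individual values can be looked up with the @varname macro (a quick sketch using the sample just drawn):

sample_dict[@varname(μ)]  # same value as sample.μ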

logjoint can also be used on this sample:

logjoint(model, sample_dict)
-181.7247437162069

The prior probability and the likelihood of a set of samples can be calculated with the functions logprior and loglikelihood respectively. The log joint probability is the sum of these two quantities:

logjoint(model, sample) ≈ loglikelihood(model, sample) + logprior(model, sample)
true
logjoint(model, sample_dict) ≈ loglikelihood(model, sample_dict) + logprior(model, sample_dict)
true
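
The two components can also be inspected on their own (a small sketch using the sample drawn above):

logprior(model, sample)       # log density of μ under its Normal(0, 1) prior
loglikelihood(model, sample)  # log density of the conditioned data x given μ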

Example: Cross-validation

To give an example of the probability interface in use, we can use it to estimate the performance of our model with cross-validation. In cross-validation, we split the dataset into several equal parts; each part in turn serves as the validation set, while the remaining parts are used for training. Here, we measure fit using the cross entropy (Bayes loss).1 (For the sake of simplicity, the following code requires that nfolds divides the number of data points. For a more robust implementation, see MLUtils.jl.)

# Calculate the train/validation splits across `nfolds` partitions, assuming `nfolds` divides `length(dataset)`
function kfolds(dataset::Array{<:Real}, nfolds::Int)
    fold_size, remaining = divrem(length(dataset), nfolds)
    if remaining != 0
        error("The number of folds must divide the number of data points.")
    end
    first_idx = firstindex(dataset)
    last_idx = lastindex(dataset)
    splits = map(0:(nfolds - 1)) do i
        start_idx = first_idx + i * fold_size
        end_idx = start_idx + fold_size
        train_set_indices = [first_idx:(start_idx - 1); end_idx:last_idx]
        return (view(dataset, train_set_indices), view(dataset, start_idx:(end_idx - 1)))
    end
    return splits
end

function cross_val(
    dataset::Vector{<:Real};
    nfolds::Int=5,
    nsamples::Int=1_000,
    rng::Random.AbstractRNG=Random.default_rng(),
)
    # Initialize `loss` in a way such that the loop below does not change its type
    model = gdemo(1) | (x=[first(dataset)],)
    loss = zero(logjoint(model, rand(rng, model)))

    for (train, validation) in kfolds(dataset, nfolds)
        # First, we train the model on the training set, i.e., we obtain samples from the posterior.
        # For normally-distributed data, the posterior can be computed in closed form.
        # For general models, however, typically samples will be generated using MCMC with Turing.
        posterior = Normal(mean(train), 1)
        samples = rand(rng, posterior, nsamples)

        # Evaluation on the validation set.
        validation_model = gdemo(length(validation)) | (x=validation,)
        loss += sum(samples) do sample
            logjoint(validation_model, (μ=sample,))
        end
    end

    return loss
end

cross_val(dataset)
-212760.30282411768
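
The keyword arguments can be adjusted as needed; for example (a hypothetical call, assuming nfolds still divides the number of data points):

cross_val(dataset; nfolds=10, nsamples=2_000)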

Footnotes

  1. See ParetoSmooth.jl for a faster and more accurate implementation of cross-validation than the one provided here.