Threadsafe Evaluation

A common technique to speed up Julia code is to use multiple threads to run computations in parallel. The Julia manual has a section on multithreading, which is a good introduction to the topic.

We assume that the reader is familiar with some threading constructs in Julia, and the general concept of data races. This page specifically discusses Turing’s support for threadsafe model evaluation.

Note

Please note that this is a rapidly-moving topic, and things may change in future releases of Turing. If you are ever unsure about what works and doesn’t, please don’t hesitate to ask on Slack or Discourse.

MCMC sampling

For complete clarity, this page has nothing to do with parallel sampling of MCMC chains using

sample(model, sampler, MCMCThreads(), N, nchains)

That parallelisation exists outside of the model evaluation, and thus is independent of the model contents. This page only discusses threading inside Turing models.

Threading in Turing models

Given that Turing models mostly contain ‘plain’ Julia code, one might expect that all threading constructs such as Threads.@threads or Threads.@spawn can be used inside Turing models.

This is, to some extent, true: for example, you can use threading constructs to speed up deterministic computations. Here, we use parallelism to speed up a transformation of x:

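using Turing, LinearAlgebra

@model function f(y)
    # `x` has one element per observation
    x ~ MvNormal(zeros(length(y)), I)
    x_transformed = similar(x)
    # Only deterministic work happens inside the threaded loop
    Threads.@threads for i in eachindex(x)
        x_transformed[i] = exp(x[i])  # stand-in for some expensive function
    end
    y ~ MvNormal(x_transformed, I)
end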

In general, for code that does not involve tilde-statements (x ~ dist), threading works exactly as it does in regular Julia code.

However, extra care must be taken when using tilde-statements (x ~ dist) inside threaded blocks. This is because tilde-statements modify the internal VarInfo object used for model evaluation. Essentially, x ~ dist expands to something like

x, __varinfo__ = DynamicPPL.tilde_assume!!(..., __varinfo__)

and writing into __varinfo__ is, in general, not threadsafe. Thus, parallelising tilde-statements can lead to data races as described in the Julia manual.
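The same kind of data race can be reproduced in plain Julia, without Turing. In the minimal sketch below (Base only; both function names are hypothetical), racy_sum has every iteration update one shared total, which can lose updates when Julia runs with more than one thread, while chunked_sum gives each task its own partial sum and combines the partial sums only after every task has finished.

```julia
using Base.Threads

# Racy: every iteration reads and writes the same Ref without synchronisation,
# so with more than one thread some updates can be lost.
function racy_sum(xs)
    total = Ref(0.0)
    @threads for i in eachindex(xs)
        total[] += xs[i]
    end
    return total[]
end

# Safe: each task owns its partial sum; the partial sums are combined
# only after all tasks have finished.
function chunked_sum(xs; nchunks = nthreads())
    chunksize = max(1, cld(length(xs), nchunks))
    tasks = map(Iterators.partition(eachindex(xs), chunksize)) do idx
        @spawn sum(@view xs[idx])
    end
    return sum(fetch.(tasks))
end

xs = ones(100_000)
chunked_sum(xs)  # always 100000.0
racy_sum(xs)     # may be less than 100000.0 when nthreads() > 1
```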

Threaded tilde-statements

As of version 0.41, Turing only supports the use of tilde-statements inside threaded blocks when these are observations (i.e., likelihood terms).

This means that the following code is safe to use:

using Turing 

@model function threaded_obs(N)
    x ~ Normal()
    y = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        y[i] ~ Normal(x)
    end
end

N = 100
y = randn(N)
model = threaded_obs(N) | (; y = y)

Evaluating this model is threadsafe, in the sense that Turing guarantees correct results from functions such as:

logjoint(model, (; x = 0.0))
-146.48329567145132

We can compare this with the true value:

logpdf(Normal(), 0.0) + sum(logpdf.(Normal(0.0), y))
-146.48329567145134
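For reference, the true value computed above is the log joint density of the model at $x = 0$: a standard normal prior term for $x$ plus one unit-variance normal log-likelihood term per observation $y_1, \dots, y_N$. The two printed values differ only in their final digit, which is what one would expect if the same terms are summed in a different order.

$$
\log p(x = 0, y) = \log \mathcal{N}(0 \mid 0, 1) + \sum_{i=1}^{N} \log \mathcal{N}(y_i \mid 0, 1)
= -\frac{N + 1}{2}\log(2\pi) - \frac{1}{2}\sum_{i=1}^{N} y_i^2
$$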

When sampling, you must disable model checking, but otherwise results will be correct:

sample(model, NUTS(), 100; check_model=false, progress=false)
Info: Found initial step size
  ϵ = 0.2
Chains MCMC chain (100×15×1 Array{Float64, 3}):

Iterations        = 51:1:150
Number of chains  = 1
Samples per chain = 100
Wall duration     = 6.85 seconds
Compute duration  = 6.85 seconds
parameters        = x
internals         = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, lp, logprior, loglikelihood

Use `describe(chains)` for summary statistics and quantiles.
Warning: Upcoming changes

Starting from DynamicPPL 0.39, if you use tilde-statements or @addlogprob! inside threaded blocks, you will have to declare this upfront using:

model = threaded_obs(N) | (; y = randn(N))
threadsafe_model = setthreadsafe(model, true)

Then you can sample from threadsafe_model as before.

This change was made because threadsafe evaluation comes with a performance cost, which can sometimes be substantial. Previously, threadsafe evaluation was always enabled, i.e., this cost was incurred whenever Julia was launched with more than one thread. However, the number of threads Julia was launched with is not an appropriate way to determine whether threadsafe evaluation is needed!

On the other hand, parallelising the sampling of latent values is not supported. Attempting to do this will either error or give wrong results.

@model function threaded_assume_bad(N)
    x = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        x[i] ~ Normal()
    end
    return x
end

model = threaded_assume_bad(100)

# This will throw an error (and probably a different error
# each time it's run...)
model()
TaskFailedException

    nested task error: AssertionError: Multiple concurrent writes to Dict detected!

...and 3 more exceptions.

In particular, this means that you cannot currently use predict to sample new data in parallel.

Note: Threaded predict

Support for threaded predict will be added in DynamicPPL 0.39 (see this pull request).

That is, even for threaded_obs, where y was originally an observed term, you cannot do:

model = threaded_obs(N) | (; y = y)
chn = sample(model, NUTS(), 100; check_model=false, progress=false)

pmodel = threaded_obs(N)  # don't condition on data
predict(pmodel, chn)
Info: Found initial step size
  ϵ = 0.4
TaskFailedException

    nested task error: BoundsError: attempt to access 11-element BitVector at index [14]

...and 3 more exceptions.

Alternatives to threaded observation

An alternative to threaded observations is to calculate the log-likelihood term manually (this calculation can be parallelised using any of Julia’s standard mechanisms) and then, outside the threaded block, add it to the model using @addlogprob!.

For example:

# Note that `y` has to be passed as an argument; you can't
# condition on it because otherwise `y[i]` won't be defined.
@model function threaded_obs_addlogprob(N, y)
    x ~ Normal()

    # Instead of this:
    # Threads.@threads for i in 1:N
    #     y[i] ~ Normal(x)
    # end

    # Do this instead:
    lls = map(1:N) do i
        Threads.@spawn begin
            logpdf(Normal(x), y[i])
        end
    end
    @addlogprob! sum(fetch.(lls))
end
threaded_obs_addlogprob (generic function with 2 methods)
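Spawning one task per observation, as threaded_obs_addlogprob does, can add noticeable scheduling overhead when N is large. A common variation is to split the indices into a few chunks and give each task one chunk. The sketch below assumes the same unit-variance normal likelihood as the model above and writes the log density out by hand so that it runs with Base alone; the helper name chunked_normal_loglik is hypothetical. Inside a model like threaded_obs_addlogprob, its result could be passed to @addlogprob! in place of sum(fetch.(lls)).

```julia
# Hypothetical helper: sum over i of log N(y[i] | x, 1), computed in parallel chunks.
# Each task returns its own partial sum, so no shared state is mutated.
function chunked_normal_loglik(x, y; nchunks = Threads.nthreads())
    chunksize = max(1, cld(length(y), nchunks))
    tasks = map(Iterators.partition(eachindex(y), chunksize)) do idx
        Threads.@spawn sum(i -> -0.5 * log(2π) - 0.5 * (y[i] - x)^2, idx)
    end
    # Combine the partial sums in chunk order, after all tasks have finished.
    return sum(fetch.(tasks))
end

# Usage inside a model (sketch): @addlogprob! chunked_normal_loglik(x, y)
chunked_normal_loglik(0.0, [0.5, -1.0, 2.0]; nchunks = 2)
```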

In a similar way, you can also use your favourite parallelism package, such as FLoops.jl or OhMyThreads.jl. See this Discourse post for some examples.

We make no promises about using tilde-statements with these packages (indeed, this will most likely error), but as long as you use them only to parallelise regular Julia code (i.e., not tilde-statements), they will work as intended.

The main downsides of this approach are:

  1. You can’t use conditioning syntax to provide data; it has to be passed as an argument or otherwise included inside the model.
  2. You can’t use predict to sample new data.

On the other hand, one benefit of rewriting the model this way is that sampling from this model with MCMCThreads() will always be reproducible.

using Random
N = 100
y = randn(N)
model = threaded_obs_addlogprob(N, y)
nuts_kwargs = (check_model=false, progress=false, verbose=false)

chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x])  # should be identical
(0.04022869247908463, 0.04022869247908463)
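In threaded_obs_addlogprob, the per-observation terms are collected with fetch.(lls) in index order and only then summed, so that final sum does not depend on which thread happened to run which task. A minimal Base-only check of this property is below (the function name and the per-element computation are hypothetical). It only illustrates the combination step and is not meant to explain the behaviour of the original threaded_obs.

```julia
# Each task computes one term; the terms are fetched and summed in index order,
# so repeated calls return bit-identical results regardless of scheduling.
function ordered_task_sum(v)
    tasks = map(eachindex(v)) do i
        Threads.@spawn sin(v[i]) * v[i]
    end
    return sum(fetch.(tasks))
end

v = randn(1_000)
ordered_task_sum(v) === ordered_task_sum(v)  # true
```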

In contrast, the original threaded_obs (which used tilde inside Threads.@threads) is not reproducible when using MCMCThreads(). (In principle, we would like to fix this bug, but we haven’t yet investigated where it stems from.)

model = threaded_obs(N) | (; y = y)
nuts_kwargs = (check_model=false, progress=false, verbose=false)
chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x])  # oops!
(0.03942747774797233, 0.03567888986946416)

AD support

Finally, if you are using Turing with automatic differentiation, you also need to keep track of which AD backends support threadsafe evaluation.

ForwardDiff is the only AD backend that we find to work reliably with threaded model evaluation.

In particular:

Under the hood

Note

This part will likely only be of interest to DynamicPPL developers and the very curious user.

Why is VarInfo not threadsafe?

As alluded to above, the issue with threaded tilde-statements stems from the fact that these tilde-statements modify the VarInfo object used for model evaluation, leading to potential data races.

Traditionally, VarInfo objects contain both metadata and accumulators. Metadata is where information about the random variables’ values is stored. It is a Dict-like structure, so pushing to it from multiple threads is not threadsafe (Julia’s Dict has similar limitations).
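In plain Julia, the usual way to make concurrent writes to a shared Dict safe is to serialise them with a lock. The sketch below (Base only; the function name is hypothetical) illustrates that general pattern. It is not what DynamicPPL does: DynamicPPL currently does not support assume tilde-statements inside threaded blocks at all.

```julia
# Many tasks insert into one Dict. Without the lock, concurrent setindex! calls
# can corrupt the Dict (Julia may report "Multiple concurrent writes to Dict detected!").
function fill_dict_locked(n)
    d = Dict{Int,Float64}()
    lk = ReentrantLock()
    Threads.@threads for i in 1:n
        val = sqrt(i)      # work done outside the lock
        lock(lk) do
            d[i] = val     # only one task at a time mutates the Dict
        end
    end
    return d
end

d = fill_dict_locked(10_000)
length(d)  # 10000
```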

Accumulators, on the other hand, are used to store outputs of the model, such as log-probabilities. DynamicPPL’s threadsafe evaluation works by creating one set of accumulators per thread and then combining the results at the end of model evaluation.

In this way, any function call that involves only accumulators can be made threadsafe. For example, this is why observations are supported: they do not need to modify metadata, and only the log-likelihood accumulator needs to be updated.
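The accumulate-then-combine idea can be sketched in plain Julia. The toy type and functions below are hypothetical and much simpler than DynamicPPL’s accumulators: each task gets its own log-likelihood accumulator, observations only ever update that task’s accumulator, and the accumulators are merged once all tasks are done. Because no task writes to anything another task reads, no lock is needed.

```julia
# Hypothetical toy accumulator holding a running log-likelihood (Float64 only).
struct ToyLogLikAcc
    logp::Float64
end

# Adding an observation's log density returns a new accumulator (no mutation).
accumulate_obs(acc::ToyLogLikAcc, logdensity) = ToyLogLikAcc(acc.logp + logdensity)

# Combining per-task accumulators at the end of evaluation.
combine(a::ToyLogLikAcc, b::ToyLogLikAcc) = ToyLogLikAcc(a.logp + b.logp)

# Log density of y under Normal(μ, 1), written out to avoid dependencies.
normal_logpdf(μ, y) = -0.5 * log(2π) - 0.5 * (y - μ)^2

function evaluate_observations(x, y; ntasks = Threads.nthreads())
    chunksize = max(1, cld(length(y), ntasks))
    tasks = map(Iterators.partition(eachindex(y), chunksize)) do idx
        Threads.@spawn begin
            acc = ToyLogLikAcc(0.0)  # one accumulator per task
            for i in idx
                acc = accumulate_obs(acc, normal_logpdf(x, y[i]))
            end
            acc
        end
    end
    return reduce(combine, fetch.(tasks))
end

y = [0.1, -0.4, 1.3, 0.7]
evaluate_observations(0.0, y; ntasks = 2).logp
```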

However, assume tilde-statements (i.e., those for latent variables) always modify the metadata, and thus cannot currently be made threadsafe.

OnlyAccsVarInfo

As it happens, much of what DynamicPPL needs can be constructed so that it relies only on accumulators.

For example, as long as there is no need to sample new values of random variables, it is actually fine to omit the metadata object completely. This is the case for LogDensityFunction: since the values are provided as the input vector, there is no need to store them in metadata. We need only calculate the associated log-prior probability, which is stored in an accumulator. Thus, starting from DynamicPPL v0.39, LogDensityFunction itself will in fact be completely threadsafe.

Technically speaking, this is achieved using OnlyAccsVarInfo, which is a subtype of VarInfo that only contains accumulators, and no metadata at all. It implements enough of the VarInfo interface to be used in model evaluation, but will error if any functions attempt to modify or read its metadata.
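One way to see why accumulator-only evaluation can be called from several threads at once is that it behaves like a pure function of its input vector: values are read from the input, and only freshly computed quantities are returned. The minimal Base-only sketch below assumes a hypothetical log density (independent standard-normal terms for each element), evaluates it concurrently for many input vectors, and checks the results against serial evaluation.

```julia
# Hypothetical stand-in for the log density of a parameter vector θ:
# independent standard-normal terms, written out by hand.
logdensity(θ::AbstractVector) = sum(t -> -0.5 * log(2π) - 0.5 * t^2, θ)

thetas = [randn(5) for _ in 1:200]

# Each iteration reads its own input and writes only its own slot of `out`,
# so no two tasks touch the same mutable state.
out = Vector{Float64}(undef, length(thetas))
Threads.@threads for k in eachindex(thetas)
    out[k] = logdensity(thetas[k])
end

out == logdensity.(thetas)  # true
```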

There is currently an ongoing push to use OnlyAccsVarInfo in as many settings as we possibly can. For example, this is why predict will be threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a ValuesAsInModelAccumulator and combine them at the end of evaluation.

However, propagating these changes up to Turing will require a substantial amount of additional work, since there are many places in Turing which currently rely on a full VarInfo (with metadata). See, e.g., this PR for more information.
