println("This notebook is being run with $(Threads.nthreads()) threads.")This notebook is being run with 4 threads.
A common technique to speed up Julia code is to use multiple threads to run computations in parallel. The Julia manual has a section on multithreading, which is a good introduction to the topic.
We assume that the reader is familiar with some threading constructs in Julia, and with the general concept of data races. This page specifically discusses Turing’s support for threadsafe model evaluation.
Given that Turing models mostly contain ‘plain’ Julia code, one might expect that all threading constructs such as Threads.@threads or Threads.@spawn can be used inside Turing models.
This is, to some extent, true: you can use threading constructs to speed up deterministic computations. For example, here we use parallelism to speed up a transformation of x:
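# An illustrative sketch of such a model (the exact body here is assumed):
using Turing, LinearAlgebra

@model function parallel(N)
    x ~ MvNormal(zeros(N), I)
    # The threaded loop contains no tilde-statements, only ordinary
    # deterministic Julia code, so it is safe without any extra setup.
    y = similar(x)
    Threads.@threads for i in eachindex(x)
        y[i] = sin(x[i])^2  # illustrative transformation
    end
    return y
end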
┌ Warning: It looks like you are using `Threads.@threads` in your model definition.
│
│ Note that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default. If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`.
│
│ Threadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements.
│
│ Please see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required.
└ @ DynamicPPL ~/.julia/packages/DynamicPPL/Hza15/src/compiler.jl:383
parallel (generic function with 2 methods)
In general, for code that does not involve tilde-statements (x ~ dist), threading works exactly as it does in regular Julia code.
However, extra care must be taken when using tilde-statements (x ~ dist), or @addlogprob!, inside threaded blocks.
Tilde-statements are expanded by the @model macro into something that modifies the internal VarInfo object used for model evaluation. Essentially, x ~ dist expands to something like
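# A simplified sketch of the generated code (the exact expansion varies
# between DynamicPPL versions):
x, __varinfo__ = DynamicPPL.tilde_assume!!(__context__, dist, @varname(x), __varinfo__)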
and writing into __varinfo__ is, in general, not threadsafe. Thus, parallelising tilde-statements can lead to data races as described in the Julia manual.
As of version 0.42, Turing supports the use of tilde-statements inside threaded blocks only when these are observations (i.e., likelihood terms).
However, such models must be marked by the user as requiring threadsafe evaluation, using setthreadsafe.
This means that the following code is safe to use:
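# An illustrative sketch of such a model (the exact body is assumed):
N = 100
y = randn(N)

@model function threaded_obs(N)
    x ~ Normal()
    y = Vector{Float64}(undef, N)
    # These tilde-statements are observations, because we condition on y
    # below; parallelising them is supported, provided the model is marked
    # as threadsafe.
    Threads.@threads for i in 1:N
        y[i] ~ Normal(x)
    end
end

model = setthreadsafe(threaded_obs(N) | (; y = y), true)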
┌ Warning: It looks like you are using `Threads.@threads` in your model definition.
│
│ Note that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default. If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`.
│
│ Threadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements.
│
│ Please see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required.
└ @ DynamicPPL ~/.julia/packages/DynamicPPL/Hza15/src/compiler.jl:383
DynamicPPL.Model{typeof(threaded_obs), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.ConditionContext{@NamedTuple{y::Vector{Float64}}, DynamicPPL.DefaultContext}, true}(threaded_obs, (N = 100,), NamedTuple(), ConditionContext((y = [0.06134293596462265, 1.7818006913112332, 1.0813935903220444, -0.6606194360419726, -1.0460877464213474, 1.3858839677750066, -0.4634214335612833, -1.1836077304728008, -0.5055882419832128, 0.05507754976766605, -1.4924827149509021, -1.0886099942837364, 0.2598174546411505, 0.42343390918012025, 0.13580041307836394, -0.7142356299847703, 2.005554709222987, -1.0763596500992225, 0.07763525477542918, -0.6794776906982068, 0.988366770444963, -0.07210227919002261, -1.020122233678636, -0.7821491672670942, -0.5239140452697657, -0.07888594289690391, -1.293432822562895, -0.6348465012963322, 0.5763115589815008, 0.14084035677732165, 0.6924493190825315, 0.12658030070018123, 0.10712578551765896, 0.654520998081005, 0.018495761424209493, 1.6643710147836621, -0.34833906269280385, 1.1457592979475195, 0.24586637291596078, 0.8456469371700077, -2.4433467304963, 0.13854410287386876, 0.3873493449077946, 2.5816969665235803, 1.1646245214661362, 1.4316078597489645, -0.4088411306768671, -0.6526642184260992, 1.8982409155709146, 1.0783138390453246, 0.690354292610506, -0.19878565764670017, -0.47510919713980776, 1.0217339453821457, -1.4783884595349672, -1.0037335167214543, -0.5117257738993258, -0.4595117048910306, -0.604297826389685, 0.47294761234196986, 0.22950919450473248, 0.4134804265090527, -0.9435380939081741, -1.186678273076102, 1.1936260198453474, -0.4445994506342458, 0.07858531771834427, -0.49582368330669696, 1.2709001820476704, -1.7747455419002476, -0.4619810620301185, -0.8942584744638183, 1.0084375851801417, -2.4402132096627556, 0.5302162324032571, -0.6529850943222767, 2.251490265007009, -1.1814848905885744, -0.306233716412341, 0.35301904840245896, -0.5809575845849657, 0.16631856744707382, -0.16794958377000999, 0.444201789656806, -0.6728810302237078, -0.43902289682695794, -0.5762965203789076, -0.9020932212860142, -0.5058524828921526, 1.6078960852936743, 0.73244793421824, -0.038067451123256164, 1.0346698810311292, -0.8159228941394348, -1.1000350787194466, -1.031183619469035, 0.5572923921179068, 0.9071932061111072, -0.3027320866976259, 0.46661548816476306],), DynamicPPL.DefaultContext()))
Evaluating this model is threadsafe, in the sense that Turing guarantees the correct result in functions such as:
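# For example, evaluating the log-joint density (this call is illustrative):
logjoint(model, (; x = 0.5))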
(we can compare with the true value)
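# Analytic log-joint for the sketch model above (an assumption):
logpdf(Normal(), 0.5) + sum(logpdf.(Normal(0.5), y))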
Note that if you do not use setthreadsafe, the above code may give wrong results, or even error:
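# Illustrative: the same model without the threadsafe flag.
unsafe_model = threaded_obs(N) | (; y = y)
logjoint(unsafe_model, (; x = 0.5))  # may error, or silently return a wrong value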
You can sample from this model and safely use functions such as predict or returned, as long as the model is always marked as threadsafe:
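# Illustrative calls matching the output below (exact arguments assumed):
chain = sample(model, NUTS(), 100; progress=false)
predictions = predict(setthreadsafe(threaded_obs(N), true), chain)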
┌ Info: Found initial step size
└   ϵ = 0.4
Chains MCMC chain (100×15×1 Array{Float64, 3}):
Iterations = 51:1:150
Number of chains = 1
Samples per chain = 100
Wall duration = 6.17 seconds
Compute duration = 6.17 seconds
parameters = x
internals = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, logprior, loglikelihood, logjoint
Use `describe(chains)` for summary statistics and quantiles.
Chains MCMC chain (100×100×1 Array{Float64, 3}):
Iterations = 1:1:100
Number of chains = 1
Samples per chain = 100
parameters = y[1], y[2], y[3], y[4], y[5], y[6], y[7], y[8], y[9], y[10], y[11], y[12], y[13], y[14], y[15], y[16], y[17], y[18], y[19], y[20], y[21], y[22], y[23], y[24], y[25], y[76], y[77], y[78], y[79], y[80], y[81], y[82], y[83], y[84], y[85], y[86], y[87], y[88], y[89], y[90], y[91], y[92], y[93], y[94], y[95], y[96], y[97], y[98], y[99], y[100], y[51], y[52], y[53], y[54], y[55], y[56], y[57], y[58], y[59], y[60], y[61], y[62], y[63], y[64], y[65], y[66], y[67], y[68], y[69], y[70], y[71], y[72], y[73], y[74], y[75], y[26], y[27], y[28], y[29], y[30], y[31], y[32], y[33], y[34], y[35], y[36], y[37], y[38], y[39], y[40], y[41], y[42], y[43], y[44], y[45], y[46], y[47], y[48], y[49], y[50]
internals =
Use `describe(chains)` for summary statistics and quantiles.
Up until Turing v0.41, you did not need to use setthreadsafe to enable threadsafe evaluation, and it was automatically enabled whenever Julia was launched with more than one thread.
There were several reasons for this change: a major one is that threadsafe evaluation comes with a performance cost, which can sometimes be substantial (see below).
Furthermore, the number of threads Julia was launched with is not an appropriate way to determine whether threadsafe evaluation is needed: that depends on whether the model itself parallelises tilde-statements, not on how many threads are available.
On the other hand, parallelising the sampling of latent values is not supported. Attempting to do this will either error or give wrong results.
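For example, a model that draws latent values x[i] inside a threaded loop fails (this definition is a sketch consistent with the error below):

@model function threaded_assume_bad(N)
    x = Vector{Float64}(undef, N)
    # Latent (assumed) variables inside a threaded block: not supported.
    Threads.@threads for i in 1:N
        x[i] ~ Normal()
    end
end

threaded_assume_bad(N)()  # evaluating the model throws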
┌ Warning: It looks like you are using `Threads.@threads` in your model definition.
│
│ Note that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default. If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`.
│
│ Threadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements.
│
│ Please see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required.
└ @ DynamicPPL ~/.julia/packages/DynamicPPL/Hza15/src/compiler.jl:383
TaskFailedException

    nested task error: UndefRefError: access to undefined reference
    Stacktrace:
      [1] getindex
        @ ./essentials.jl:917 [inlined]
      [2] iterate
        @ ./array.jl:902 [inlined]
      [3] iterate
        @ ./generator.jl:45 [inlined]
      [4] _any(f::typeof(identity), itr::Base.Generator{Vector{AbstractPPL.VarName}, DynamicPPL.var"#123#124"{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName, Int64}, Vector{Distribution}, Vector{AbstractPPL.VarName}, Vector{Real}}, DynamicPPL.AccumulatorTuple{3, @NamedTuple{LogPrior::DynamicPPL.LogPriorAccumulator{Float64}, LogJacobian::DynamicPPL.LogJacobianAccumulator{Float64}, LogLikelihood::DynamicPPL.LogLikelihoodAccumulator{Float64}}}}}}, ::Colon)
        @ Base ./reduce.jl:1243
      [5] any
        @ ./reduce.jl:1228 [inlined]
      [6] any
        @ ./reduce.jl:1154 [inlined]
      [7] is_transformed
        @ ~/.julia/packages/DynamicPPL/Hza15/src/varinfo.jl:1485 [inlined]
      [8] tilde_assume!!(ctx::DynamicPPL.InitContext{Random.TaskLocalRNG, InitFromPrior}, dist::Normal{Float64}, vn::AbstractPPL.VarName{:x, Accessors.IndexLens{Tuple{Int64}}}, vi::DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName, Int64}, Vector{Distribution}, Vector{AbstractPPL.VarName}, Vector{Real}}, DynamicPPL.AccumulatorTuple{3, @NamedTuple{LogPrior::DynamicPPL.LogPriorAccumulator{Float64}, LogJacobian::DynamicPPL.LogJacobianAccumulator{Float64}, LogLikelihood::DynamicPPL.LogLikelihoodAccumulator{Float64}}}})
        @ DynamicPPL ~/.julia/packages/DynamicPPL/Hza15/src/contexts/init.jl:329
      [9] (::var"#88#threadsfor_fun#8"{var"#88#threadsfor_fun#7#9"{DynamicPPL.Model{typeof(threaded_assume_bad), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.InitContext{Random.TaskLocalRNG, InitFromPrior}, false}, UnitRange{Int64}}})(tid::Int64; onethread::Bool)
        @ Main.Notebook ./threadingconstructs.jl:253
     [10] #88#threadsfor_fun
        @ ./threadingconstructs.jl:220 [inlined]
     [11] (::Base.Threads.var"#1#2"{var"#88#threadsfor_fun#8"{var"#88#threadsfor_fun#7#9"{DynamicPPL.Model{typeof(threaded_assume_bad), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.InitContext{Random.TaskLocalRNG, InitFromPrior}, false}, UnitRange{Int64}}}, Int64})()
        @ Base.Threads ./threadingconstructs.jl:154

...and 3 more exceptions.

Stacktrace:
  [1] threading_run(fun::var"#88#threadsfor_fun#8"{var"#88#threadsfor_fun#7#9"{DynamicPPL.Model{typeof(threaded_assume_bad), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.InitContext{Random.TaskLocalRNG, InitFromPrior}, false}, UnitRange{Int64}}}, static::Bool)
    @ Base.Threads ./threadingconstructs.jl:173
  [2] macro expansion
    @ ./threadingconstructs.jl:190 [inlined]
  [3] threaded_assume_bad
    @ ~/work/docs/docs/usage/threadsafe-evaluation/index.qmd:140 [inlined]
  [4] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/Hza15/src/model.jl:997 [inlined]
  [5] evaluate!!
    @ ~/.julia/packages/DynamicPPL/Hza15/src/model.jl:983 [inlined]
  [6] init!!
    @ ~/.julia/packages/DynamicPPL/Hza15/src/model.jl:938 [inlined]
  [7] init!!
    @ ~/.julia/packages/DynamicPPL/Hza15/src/model.jl:936 [inlined]
  [8] Model
    @ ~/.julia/packages/DynamicPPL/Hza15/src/model.jl:911 [inlined]
  [9] (::DynamicPPL.Model{typeof(threaded_assume_bad), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.DefaultContext, false})()
    @ DynamicPPL ~/.julia/packages/DynamicPPL/Hza15/src/model.jl:904
 [10] top-level scope
    @ ~/work/docs/docs/usage/threadsafe-evaluation/index.qmd:150
You only need to enable threadsafe evaluation if you are using tilde-statements or @addlogprob! inside threaded blocks.
Specifically, you do not need to enable threadsafe evaluation if:
You have parallelism inside the model, but it does not involve tilde-statements or @addlogprob!.
You are sampling from a model using MCMCThreads(), but the model itself does not contain any parallel tilde-statements or @addlogprob!.
As described above, one of the major considerations behind the introduction of setthreadsafe is that threadsafe evaluation comes with a performance cost.
Consider a simple model that does not use threading:
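# An illustrative definition (the exact body is assumed):
@model function gdemo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    1.5 ~ Normal(m, sqrt(s))
    2.0 ~ Normal(m, sqrt(s))
end

model = gdemo()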
DynamicPPL.Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext, true}(gdemo, NamedTuple(), NamedTuple(), DynamicPPL.DefaultContext())
One can see that evaluation of the threadsafe model is substantially slower:
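# An illustrative benchmark using Chairmarks.jl: we evaluate the plain
# model and its threadsafe-marked counterpart on the same VarInfo.
using Chairmarks
vi = DynamicPPL.VarInfo(model)
@b DynamicPPL.evaluate!!(model, vi)                      # plain model
@b DynamicPPL.evaluate!!(setthreadsafe(model, true), vi) # threadsafe model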
286.814 ns (8 allocs: 464 bytes)
3.361 μs (49 allocs: 2.766 KiB)
In previous versions of Turing, this cost would always be incurred whenever Julia was launched with multiple threads, even if the model did not use any threading at all!
An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia’s standard mechanisms) and then, outside of the threaded block, add it to the model using @addlogprob!.
For example:
# Note that `y` has to be passed as an argument; you can't
# condition on it because otherwise `y[i]` won't be defined.
@model function threaded_obs_addlogprob(N, y)
x ~ Normal()
# Instead of this:
# Threads.@threads for i in 1:N
# y[i] ~ Normal(x)
# end
# Do this instead:
lls = map(1:N) do i
Threads.@spawn begin
logpdf(Normal(x), y[i])
end
end
@addlogprob! sum(fetch.(lls))
end
threaded_obs_addlogprob (generic function with 2 methods)
In a similar way, you can also use your favourite parallelism package, such as FLoops.jl or OhMyThreads.jl. See this Discourse post for some examples.
We make no promises about the use of tilde-statements with these packages (indeed, it will most likely error), but as long as you only use them to parallelise regular Julia code (i.e., not tilde-statements), they will work as intended.
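For instance, a sketch of this pattern using OhMyThreads.jl’s tmapreduce (the model shown here is illustrative):

using OhMyThreads: tmapreduce

@model function threaded_obs_omt(N, y)
    x ~ Normal()
    # tmapreduce handles the threading internally; @addlogprob! is then
    # invoked once, outside any threaded block.
    @addlogprob! tmapreduce(i -> logpdf(Normal(x), y[i]), +, 1:N)
end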
The main downside of this approach is that y no longer appears in a tilde-statement, so you cannot, for example, use predict to sample new data.
On the other hand, one benefit of rewriting the model this way is that sampling from it with MCMCThreads() will always be reproducible:
using Random
N = 100
y = randn(N)
# Note that since `@addlogprob!` is outside of the threaded block, we don't
# need to use `setthreadsafe`.
model = threaded_obs_addlogprob(N, y)
nuts_kwargs = (progress=false, verbose=false)
chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x]) # should be identical
(-0.14560885471493212, -0.14560885471493212)
In contrast, the original threaded_obs (which used tilde inside Threads.@threads) is not reproducible when using MCMCThreads(). (In principle, we would like to fix this bug, but we haven’t yet investigated where it stems from.)
model = setthreadsafe(threaded_obs(N) | (; y = y), true)
nuts_kwargs = (progress=false, verbose=false)
chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x]) # oops!
(-0.14163600098373222, -0.1459215230651795)
Finally, if you are using Turing with automatic differentiation, you also need to keep track of which AD backends support threadsafe evaluation.
ForwardDiff is the only AD backend that we find to work reliably with threaded model evaluation.
In particular:
This part will likely only be of interest to DynamicPPL developers and the very curious user.
As alluded to above, the issue with threaded tilde-statements stems from the fact that these tilde-statements modify the VarInfo object used for model evaluation, leading to potential data races.
Traditionally, VarInfo objects contain both metadata and accumulators. Metadata is where information about the random variables’ values is stored. It is a Dict-like structure, and pushing to it from multiple threads is therefore not threadsafe (Julia’s Dict has similar limitations).
On the other hand, accumulators are used to store outputs of the model, such as log-probabilities. The way DynamicPPL’s threadsafe evaluation works is to create one set of accumulators per thread, and then combine the results at the end of model evaluation.
In this way, any operation that involves only accumulators can be made threadsafe. For example, this is why observations are supported: there is no need to modify metadata, and only the log-likelihood accumulator needs to be updated.
However, assume tilde-statements (i.e., those for latent variables) always modify the metadata, and thus cannot currently be made threadsafe.
As it happens, much of what is needed in DynamicPPL can be constructed so that it relies only on accumulators.
For example, as long as there is no need to sample new values of random variables, it is actually fine to omit the metadata object completely. This is the case for LogDensityFunction: since the values are provided in the input vector, there is no need to store them in metadata. We need only calculate the associated log-prior probability, which is stored in an accumulator. Thus, since DynamicPPL v0.39, LogDensityFunction itself is completely threadsafe.
Technically speaking, this is achieved using OnlyAccsVarInfo, which is a subtype of VarInfo that only contains accumulators, and no metadata at all. It implements enough of the VarInfo interface to be used in model evaluation, but will error if any functions attempt to modify or read its metadata.
There is currently an ongoing push to use OnlyAccsVarInfo in as many settings as we possibly can. For example, this is why predict is threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a ValuesAsInModelAccumulator instead, and combine them at the end of evaluation.
However, propagating these changes up to Turing will require a substantial amount of additional work, since there are many places in Turing which currently rely on a full VarInfo (with metadata). See, e.g., this PR for more information.