```julia
println("This notebook is being run with $(Threads.nthreads()) threads.")
```
```
This notebook is being run with 4 threads.
```
A common technique to speed up Julia code is to use multiple threads to run computations in parallel. The Julia manual has a section on multithreading, which is a good introduction to the topic.
We assume that the reader is familiar with some threading constructs in Julia, as well as the general concept of data races. This page specifically discusses Turing’s support for threadsafe model evaluation.
To a first approximation, Turing completely supports multithreaded code inside models.
For example, you can use Threads.@threads to parallelise ‘ordinary’ Julia code inside a model. Here is an example of parallelising some expensive computation inside a model:
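The code for this example is not preserved on this page, but it presumably looked something like the following sketch (the function name `parallel` matches the output below; everything else is an assumption): the expensive work is ordinary Julia code parallelised with `Threads.@threads`, and the tilde-statements stay outside the threaded block.

```julia
using Turing

# Stand-in for some costly, purely deterministic computation.
expensive_computation(x, i) = sin(x) + i

@model function parallel(y)
    x ~ Normal()
    means = Vector{Float64}(undef, length(y))
    # Ordinary Julia code can be parallelised freely: no tilde-statements
    # are executed inside the threaded block.
    Threads.@threads for i in eachindex(y)
        means[i] = expensive_computation(x, i)
    end
    for i in eachindex(y)
        y[i] ~ Normal(means[i])
    end
end
```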
```
[ Info: [Turing]: progress logging is disabled globally
┌ Warning: It looks like you are using `Threads.@threads` in your model definition.
│
│ Note that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default. If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`.
│
│ Threadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements.
│
│ Please see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required.
└ @ DynamicPPL ~/.julia/packages/DynamicPPL/pqVkv/src/compiler.jl:357
```
```
parallel (generic function with 2 methods)
```
An example like the above, where the parallelisation is separate from the modelling syntax (i.e., tilde-statements), will work without any special considerations.
However, extra care must be taken when using tilde-statements (x ~ dist), or @addlogprob!, inside threaded blocks. Specifically, if you do this, you must mark the model as requiring threadsafe evaluation, using setthreadsafe. For example:
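The code that produced the output below is not preserved here; a hypothetical reconstruction, consistent with the `Model` printout that follows (`N = 20`, conditioned on data `y`), might look like this. The placeholder data `y` is an assumption; the original values are not shown.

```julia
using Turing, DynamicPPL

@model function threaded(N)
    x ~ Normal()
    y = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        # Tilde-statement inside a threaded block: this requires
        # threadsafe evaluation to be enabled explicitly.
        y[i] ~ Normal(x)
    end
end

y = randn(20)  # placeholder data for illustration
model = setthreadsafe(threaded(20) | (; y = y), true)
```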
```
DynamicPPL.Model{typeof(threaded), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.CondFixContext{DynamicPPL.Condition, VarNamedTuple{(:y,), Tuple{Vector{Float64}}}, DynamicPPL.DefaultContext}, true}(threaded, (N = 20,), NamedTuple(), CondFixContext{DynamicPPL.Condition}(VarNamedTuple(y = [-0.14059620812490148, 0.26352305218039623, 0.6036508073517312, -0.39827757594870167, 2.1248963644114207, -0.23847524181107085, -0.24105549651319977, 0.7497761210486046, 0.5108398379873107, -1.775504295607224, 2.054292316177274, 0.25484577513109, -0.894018850307107, -0.13218665716527775, -1.2967804895479091, 2.621254699899919, -0.9362202042047076, -0.695355165552344, -0.44086223384179357, -2.275135753395398],), DynamicPPL.DefaultContext()))
```
Tilde-statements are expanded by the @model macro into something that modifies the internal AbstractVarInfo object used during model evaluation. Essentially, x ~ dist expands to something like
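The exact expansion is more involved, but schematically it looks something like the following (the names here are illustrative, modelled on DynamicPPL's internal tilde pipeline, not the literal macro output):

```julia
# Schematic only -- `x ~ dist` becomes, roughly:
x, __abstractvarinfo__ = DynamicPPL.tilde_assume!!(
    __context__, dist, @varname(x), __abstractvarinfo__
)
```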
and writing into __abstractvarinfo__ is, in general, not threadsafe. Thus, parallelising tilde-statements can lead to data races as described in the Julia manual.
Turing’s threadsafe flag works by creating one AbstractVarInfo per thread, and then combining the results at the end of model evaluation.
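The same pattern can be illustrated in plain Julia, without Turing: give each task its own accumulator slot and reduce over the slots at the end.

```julia
# One accumulator slot per chunk; each task writes only to its own slot,
# so there is no data race, and the partial results are combined at the end.
function threaded_logsum(xs)
    nchunks = Threads.nthreads()
    partials = zeros(nchunks)
    Threads.@threads for c in 1:nchunks
        for i in c:nchunks:length(xs)
            partials[c] += log(xs[i])
        end
    end
    return sum(partials)
end
```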
Once the model has been marked as threadsafe, Turing guarantees to provide the correct result in functions such as:
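The original list of functions is not preserved on this page, but typical examples are the DynamicPPL log-density helpers. Assuming `model` has been marked with `setthreadsafe(model, true)` and has a length-20 parameter vector `x` (an assumption for illustration), calls such as these return correct results:

```julia
# `model` is assumed to be a threadsafe-marked model with parameter vector x.
logjoint(model, (; x = zeros(20)))
logprior(model, (; x = zeros(20)))
loglikelihood(model, (; x = zeros(20)))
```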
(we can compare with the true value)
Note that if you do not use setthreadsafe, the above code may give wrong results, or even error:
You can sample from this model and safely use functions such as predict or returned, as long as the model is always marked as threadsafe:
```
┌ Info: Found initial step size
└ ϵ = 1.6
```
```
╭─FlexiChain (100 iterations, 1 chain) ────────────────────────────────────────╮
│ ↓ iter = 51:150                                                              │
│ → chain = 1:1                                                                │
│                                                                              │
│ Parameters (1) ── AbstractPPL.VarName                                        │
│ Vector{Float64} x                                                            │
│                                                                              │
│ Extras (14)                                                                  │
│ Int64 n_steps, tree_depth                                                    │
│ Bool is_accept, numerical_error                                              │
│ Float64 acceptance_rate, log_density, hamiltonian_energy,                    │
│   hamiltonian_energy_error, max_hamiltonian_energy_error, step_size,         │
│   nom_step_size, logprior, loglikelihood, logjoint                           │
╰──────────────────────────────────────────────────────────────────────────────╯
```
```
╭─FlexiChain (100 iterations, 1 chain) ────────────────────────────────────────╮
│ ↓ iter = 51:150                                                              │
│ → chain = 1:1                                                                │
│                                                                              │
│ Parameters (2) ── AbstractPPL.VarName                                        │
│ Vector{Float64} x, y                                                         │
│                                                                              │
│ Extras (14)                                                                  │
│ Int64 n_steps, tree_depth                                                    │
│ Bool is_accept, numerical_error                                              │
│ Float64 acceptance_rate, log_density, hamiltonian_energy,                    │
│   hamiltonian_energy_error, max_hamiltonian_energy_error, step_size,         │
│   nom_step_size, logprior, loglikelihood, logjoint                           │
╰──────────────────────────────────────────────────────────────────────────────╯
```
Up until Turing v0.41, you did not need to use setthreadsafe to enable threadsafe evaluation, and it was automatically enabled whenever Julia was launched with more than one thread.
There were several reasons for this change: a major one is that threadsafe evaluation comes with a performance cost, which can sometimes be substantial (see below).
Furthermore, the number of threads is not an appropriate way to determine whether threadsafe evaluation is needed!
Note that, due to reasons we do not yet fully understand (but which likely relate to race conditions in mutating the random number generator), threadsafe evaluation is not always fully deterministic when assume-statements, i.e. random variables, are parallelised.
In the model above, the x[i]’s are random variables, since they appear on the left-hand side of a tilde-statement and are not conditioned on. In contrast, the y[i]’s are data, not random variables.
This means if your model contains parallelised random variables, you are not guaranteed to get the same results every time, even if you set the random seed:
```
mean(chn[@varname(x[1])]) = 0.04030210823718552
mean(chn[@varname(x[1])]) = -0.0748048704794504
```
Some samplers do indeed yield the same results (but NUTS is not one of them, and we cannot make any concrete guarantees at this point in time):
```
mean(chn[@varname(x[1])]) = -0.2719257503311838
mean(chn[@varname(x[1])]) = -0.6054794003783021
```
Now consider a different situation where you only have parallelised data, and not random variables. In this case we do guarantee that sampling is fully deterministic:
```julia
@model function threaded_data(N)
    x ~ Normal()
    y = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        y[i] ~ Normal(x)
    end
end

threadsafe_model_data_only = setthreadsafe(threaded_data(N) | (; y = y), true)

chn = sample(Xoshiro(468), threadsafe_model_data_only, NUTS(), 100; verbose=false)
@show mean(chn[@varname(x)])
chn = sample(Xoshiro(468), threadsafe_model_data_only, NUTS(), 100; verbose=false)
@show mean(chn[@varname(x)]);
```
```
mean(chn[@varname(x)]) = 0.04874095875418817
mean(chn[@varname(x)]) = 0.04874095875406785
```
You only need to enable threadsafe evaluation if you are using tilde-statements or @addlogprob! inside threaded blocks.
Specifically, you do not need to enable threadsafe evaluation if:

- You have parallelism inside the model, but it does not involve tilde-statements or @addlogprob!.
- You are sampling from a model using MCMCThreads(), but the model itself does not contain any parallel tilde-statements or @addlogprob!.
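For instance, multi-chain parallel sampling on its own is fine (a sketch; `model` here stands for any model without parallel tilde-statements):

```julia
# Chains run on separate threads, but each individual model evaluation is
# single-threaded, so no setthreadsafe call is needed.
chn = sample(model, NUTS(), MCMCThreads(), 1000, 4)
```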
As described above, one of the major considerations behind the introduction of setthreadsafe is that threadsafe evaluation comes with a performance cost.
Consider a simple model that does not use threading:
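The definition is not shown on this page; a minimal stand-in consistent with the printout below (a zero-argument gdemo model) could be the following. This is illustrative only: any small model without Threads.@threads would do.

```julia
@model function gdemo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    1.5 ~ Normal(m, sqrt(s))
    2.0 ~ Normal(m, sqrt(s))
end

model = gdemo()
# The same model, but with the threadsafe machinery enabled, for comparison.
threadsafe_model = setthreadsafe(model, true)
```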
```
DynamicPPL.Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext, true}(gdemo, NamedTuple(), NamedTuple(), DynamicPPL.DefaultContext())
```
One can see that evaluation of the threadsafe model is substantially slower:
```
33.550 ns
```
```
294.719 ns (13 allocs: 528 bytes)
```
In previous versions of Turing, this cost would always be incurred whenever Julia was launched with multiple threads, even if the model did not use any threading at all!
Finally, if you are using Turing with automatic differentiation, you also need to keep track of which AD backends support threadsafe evaluation.
ForwardDiff and Enzyme are the only AD backends that we have found to work reliably with threaded model evaluation. Note that for Enzyme you should use a relatively recent version (at least v0.13.140), as prior to that, reverse-mode could yield incorrect results.
In contrast, ReverseDiff sometimes gives correct results, but quite often computes incorrect gradients. Mooncake currently does not support multithreading at all.
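In practice this means passing an adtype from the known-good backends when constructing the sampler, e.g. (a sketch; AutoForwardDiff comes from ADTypes.jl and is re-exported by Turing):

```julia
# ForwardDiff is one of the backends observed to work with threaded evaluation.
sampler = NUTS(; adtype = AutoForwardDiff())
```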
For more details you can take a look at the threaded_... models on ADTests.
This part will likely only be of interest to DynamicPPL developers and the very curious user.
Code in DynamicPPL that uses VarInfo is not threadsafe in general: observe statements are threadsafe, but assume statements are not.
In contrast, code that uses OnlyAccsVarInfo is completely threadsafe.
Virtually all of DynamicPPL and Turing now uses OnlyAccsVarInfo, which means that most of DynamicPPL and Turing is threadsafe. You only need to worry about edge cases if you are still using VarInfo directly.
As alluded to above, the issue with threaded tilde-statements stems from the fact that these tilde-statements modify the VarInfo object used for model evaluation, leading to potential data races.
Traditionally, VarInfo objects contain both metadata as well as accumulators. Metadata is where information about the random variables’ values is stored. It is a Dict-like structure, and pushing to it from multiple threads is therefore not threadsafe (Julia’s Dict has similar limitations).
On the other hand, accumulators are used to store outputs of the model, such as log-probabilities. The way DynamicPPL’s threadsafe evaluation works is to create one set of accumulators per thread, and then combine the results at the end of model evaluation.
In this way, any function call that solely involves accumulators can be made threadsafe. For example, this is why observations are supported: there is no need to modify metadata, and only the log-likelihood accumulator needs to be updated.
However, assume tilde-statements always modify the metadata, and thus cannot currently be made threadsafe.
As it happens, much of the functionality needed in DynamicPPL can be constructed so that it relies only on accumulators.
For example, as long as there is no need to sample new values of random variables, it is actually fine to completely omit the metadata object. This is the case for LogDensityFunction: since values are provided as the input vector, there is no need to store them in metadata. We need only calculate the associated log-prior probability, which is stored in an accumulator. Thus, since DynamicPPL v0.39, LogDensityFunction itself is completely threadsafe.
Technically speaking, this is achieved using OnlyAccsVarInfo, which is a subtype of VarInfo that only contains accumulators, and no metadata at all. It implements enough of the VarInfo interface to be used in model evaluation, but will error if any functions attempt to modify or read its metadata.