Performance Tips

This section briefly summarises a few common techniques for getting good performance out of Turing. For general advice on writing performant Julia code, we refer to the Julia documentation.

Use multivariate distributions

It is generally preferable to use multivariate distributions if possible.

The following example:

using Turing
@model function gmodel(x)
    m ~ Normal()
    for i in eachindex(x)
        x[i] ~ Normal(m, 0.2)
    end
end
gmodel (generic function with 2 methods)

can be directly expressed more efficiently using a simple transformation:

using FillArrays, LinearAlgebra

@model function gmodel(x)
    m ~ Normal()
    # 0.04 * I is the covariance matrix: 0.2^2 times the identity.
    return x ~ MvNormal(Fill(m, length(x)), 0.04 * I)
end
gmodel (generic function with 2 methods)
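As a rough sanity check, you can time a short run of each version. The sketch below uses hypothetical names (gmodel_loop, gmodel_mv) and arbitrary synthetic data, since both definitions above share the name gmodel:

using Turing, FillArrays, LinearAlgebra

# Arbitrary synthetic data for illustration.
x = randn(1_000) .+ 0.5

@model function gmodel_loop(x)
    m ~ Normal()
    for i in eachindex(x)
        x[i] ~ Normal(m, 0.2)
    end
end

@model function gmodel_mv(x)
    m ~ Normal()
    return x ~ MvNormal(Fill(m, length(x)), 0.04 * I)
end

# Time short NUTS runs of both versions (timings are illustrative only).
@time sample(gmodel_loop(x), NUTS(), 500)
@time sample(gmodel_mv(x), NUTS(), 500)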

Choose your AD backend

Automatic differentiation (AD) makes it possible to use modern, efficient gradient-based samplers like NUTS and HMC. This, however, also means that using a performant AD system is incredibly important. Turing currently supports several AD backends, including ForwardDiff (the default), Mooncake, and ReverseDiff.

For many common types of models, the default ForwardDiff backend performs well, and there is no need to worry about changing it. However, if you need more speed, you can try different backends via the standard ADTypes interface by passing an AbstractADType to the sampler with the optional adtype argument, e.g. NUTS(; adtype = AutoMooncake()).
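For example, the following sketch constructs samplers with different backends. It assumes the Mooncake and ReverseDiff packages are installed; a backend package must be loaded before its corresponding AutoX type can be used:

using Turing
import Mooncake, ReverseDiff

nuts_fd = NUTS(; adtype=AutoForwardDiff())   # the default backend
nuts_mc = NUTS(; adtype=AutoMooncake())      # reverse mode via Mooncake
nuts_rd = NUTS(; adtype=AutoReverseDiff())   # reverse mode via ReverseDiff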

Generally, adtype = AutoForwardDiff() is likely to be the fastest and most reliable choice for models with few parameters (say, fewer than 20 or so), while reverse-mode backends such as AutoMooncake() or AutoReverseDiff() tend to perform better for models with many parameters or heavy linear algebra operations. If in doubt, benchmark your model with several backends to see which performs best. See the Automatic Differentiation page for details.
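One crude but simple way to benchmark is to time short sampling runs of the same model under each backend. The sketch below assumes all three backend packages are installed and uses a hypothetical toy model purely for illustration; for more careful gradient-level benchmarks, see the Automatic Differentiation page:

using Turing
import Mooncake, ReverseDiff

# A toy model purely for illustration.
@model function demo(y)
    m ~ Normal()
    y ~ Normal(m, 1.0)
end
model = demo(0.5)

# Time a short run with each backend (timings are illustrative only).
for adtype in (AutoForwardDiff(), AutoMooncake(), AutoReverseDiff())
    @time sample(model, NUTS(; adtype=adtype), 1_000; progress=false)
end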

Special care for ReverseDiff with a compiled tape

For large models, the fastest option is often ReverseDiff with a compiled tape, specified as adtype=AutoReverseDiff(; compile=true). However, if your model contains any branching code, such as if-else statements, the gradients computed from a compiled tape may be silently incorrect, leading to erroneous results. If you use this option for the (considerable) speedup it can provide, make sure to check your code for branching and verify that it does not affect the gradients. It is also a good idea to cross-check your gradients against another backend.
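To illustrate the pitfall, here is a hypothetical model where the branch taken depends on a sampled parameter. A compiled tape records only whichever branch was active when the tape was built, and silently reuses it for every subsequent gradient evaluation:

using Turing
import ReverseDiff

@model function branchy(x)
    m ~ Normal()
    if m > 0    # the branch depends on the parameter m
        x ~ Normal(m, 1.0)
    else
        x ~ Normal(-2m, 1.0)
    end
end

# Risky: gradients may be silently wrong once m crosses zero.
# sample(branchy(1.0), NUTS(; adtype=AutoReverseDiff(; compile=true)), 1000)

# Safer: without compilation the tape is re-recorded on every evaluation.
sample(branchy(1.0), NUTS(; adtype=AutoReverseDiff(; compile=false)), 1000)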

Ensure that types in your model can be inferred

For efficient gradient-based inference, e.g. using HMC, NUTS or ADVI, it is important to ensure the types in your model can be inferred.

The following example with abstract types

@model function tmodel(x, y)
    p, n = size(x)
    params = Vector{Real}(undef, n)
    for i in 1:n
        params[i] ~ truncated(Normal(); lower=0)
    end

    a = x * params
    return y ~ MvNormal(a, I)
end
tmodel (generic function with 2 methods)

can be transformed into the following representation with concrete types:

@model function tmodel(x, y, ::Type{T}=Float64) where {T}
    p, n = size(x)
    params = Vector{T}(undef, n)
    for i in 1:n
        params[i] ~ truncated(Normal(); lower=0)
    end

    a = x * params
    return y ~ MvNormal(a, I)
end
tmodel (generic function with 4 methods)

Alternatively, you could use filldist in this example:

@model function tmodel(x, y)
    params ~ filldist(truncated(Normal(); lower=0), size(x, 2))
    a = x * params
    return y ~ MvNormal(a, I)
end
tmodel (generic function with 4 methods)

You can use DynamicPPL’s debugging utilities to find types in your model definition that the compiler cannot infer. These will be marked in red in the Julia REPL (much like when using the @code_warntype macro).

For example, consider the following model:

@model function tmodel(x)
    p = Vector{Real}(undef, 1)
    p[1] ~ Normal()
    p = p .+ 1
    return x ~ Normal(p[1])
end
tmodel (generic function with 6 methods)

Because the element type of p is an abstract type (Real), the compiler cannot infer a concrete type for p[1]. To detect this, we can use

model = tmodel(1.0)

using DynamicPPL
DynamicPPL.DebugUtils.model_warntype(model)

In this particular model, the following call to getindex should be highlighted in red (the exact numbers may vary):

[...]
│    %120 = p::AbstractVector
│    %121 = Base.getindex(%120, 1)::Any
[...]
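Following the pattern shown earlier, a possible fix is to give p a concrete element type via a type parameter (a sketch of one option):

@model function tmodel(x, ::Type{T}=Float64) where {T}
    p = Vector{T}(undef, 1)
    p[1] ~ Normal()
    p = p .+ 1
    return x ~ Normal(p[1])
end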