Automatic Differentiation

What is Automatic Differentiation?

Automatic differentiation (AD) is a technique used in Turing.jl to evaluate the gradient of a function at a given set of arguments. In the context of Turing.jl, the function being differentiated is the log probability density of a model, and the arguments are the parameters of the model (i.e. the values of the random variables). The gradient of the log probability density is used by various algorithms in Turing.jl, such as HMC (including NUTS), mode estimation (which uses gradient-based optimization), and variational inference.

The Julia ecosystem has a number of AD libraries. You can switch between these using the unified ADTypes.jl interface, which provides a corresponding AutoBackend type for each AD backend (e.g. AutoMooncake for Mooncake.jl; see the ADTypes documentation for more details). For example, to use the Mooncake.jl package for AD, you can run the following:

# Turing re-exports AutoForwardDiff, AutoReverseDiff, and AutoMooncake.
# Other ADTypes must be explicitly imported from ADTypes.jl or
# DifferentiationInterface.jl.
using Turing
setprogress!(false)

# Note that if you specify a custom AD backend, you must also import it.
import Mooncake

@model function f()
    x ~ Normal()
    # Rest of your model here
end

sample(f(), HMC(0.1, 5; adtype=AutoMooncake(; config=nothing)), 100)
[ Info: [Turing]: progress logging is disabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as false
Chains MCMC chain (100×11×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
Wall duration     = 67.33 seconds
Compute duration  = 67.33 seconds
parameters        = x
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           x   -0.1689    1.0376    0.3195    11.1261    50.6924    1.0660     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x   -2.0115   -0.9245   -0.0993    0.6586    1.6279

If you do not specify a backend, Turing defaults to ForwardDiff.jl. In this case, you do not need to import ForwardDiff yourself, as it is already a dependency of Turing.
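For example, the following two calls are equivalent. This is a minimal sketch reusing the model f() defined above; passing AutoForwardDiff() explicitly is shown purely for illustration.

# Omitting adtype falls back to ForwardDiff, so these two calls use the same backend.
sample(f(), HMC(0.1, 5), 100)
sample(f(), HMC(0.1, 5; adtype=AutoForwardDiff()), 100)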

Choosing an AD Backend

There are two aspects to choosing an AD backend: firstly, what backends are available; and secondly, which backend is best for your model.

Usable AD Backends

Turing.jl uses the functionality in DifferentiationInterface.jl (‘DI’) to interface with AD libraries in a unified way. In principle, any AD library that DI provides an interface for can be used with Turing; you should consult the DI documentation for an up-to-date list of compatible AD libraries.

Note, however, that not all of these AD libraries are thoroughly tested on Turing models. It is therefore possible that some of them will error (because they do not know how to differentiate through Turing’s code) or, if you are very unlucky, silently give incorrect results. Turing is most extensively tested with ForwardDiff.jl (the default), ReverseDiff.jl, and Mooncake.jl. We also run a smaller set of tests with Enzyme.jl.
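As an illustration, any backend that DI supports can be requested by constructing its ADTypes type and loading the backend package yourself. The sketch below uses Zygote.jl purely as an example of a backend that Turing does not re-export; it is not among the most thoroughly tested backends, so check ADTests (see below) before relying on it for your model.

# AutoZygote is not re-exported by Turing, so import it from ADTypes,
# and load the backend package itself as well.
using ADTypes: AutoZygote
import Zygote

sample(f(), HMC(0.1, 5; adtype=AutoZygote()), 100)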

ADTests

Before describing how to choose the best AD backend for your model, we should mention that we also publish a table of results for various models and AD backends on the ADTests website. These models aim to capture a variety of different features of Turing.jl and Julia in general, so that you can see which AD backends may be compatible with your model. Benchmark timings are also included, although it should be noted that many of the models in ADTests are small, so the timings may not be representative of larger, real-life models.

If you have suggestions for other models to include, please do let us know by creating an issue on GitHub!

The Best AD Backend for Your Model

Given the number of possible backends, how do you choose the best one for your model?

A simple heuristic is to look at the number of parameters in your model. The log density of the model, i.e. the function being differentiated, maps \(\mathbb{R}^n \to \mathbb{R}\), where \(n\) is the number of parameters in your model. For models with a small number of parameters (say, up to around 20), forward-mode AD (e.g. ForwardDiff) is generally faster because of its lower overhead. For models with a large number of parameters, reverse-mode AD (e.g. ReverseDiff or Mooncake) is generally faster, as it computes the gradient with respect to all parameters in a single pass.
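To make this heuristic concrete, here is a hypothetical helper function (the name and the 20-parameter cutoff are our own illustration, not a rule built into Turing):

# Hypothetical helper: forward mode for small models, reverse mode otherwise.
choose_adtype(n_parameters::Integer) = n_parameters <= 20 ? AutoForwardDiff() : AutoReverseDiff()

choose_adtype(5)    # AutoForwardDiff()
choose_adtype(500)  # AutoReverseDiff()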

The most reliable way to ensure you are using the fastest AD backend that works for your model is to benchmark the candidates using the functionality in DynamicPPL (see the API documentation):

using ADTypes
using DynamicPPL.TestUtils.AD: run_ad, ADResult
using ForwardDiff, ReverseDiff

@model function gdemo(x, y)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x ~ Normal(m, sqrt(s²))
    return y ~ Normal(m, sqrt(s²))
end
model = gdemo(1.5, 2)

for adtype in [AutoForwardDiff(), AutoReverseDiff()]
    result = run_ad(model, adtype; benchmark=true)
    @show result.time_vs_primal
end
[ Info: Running AD on gdemo with ADTypes.AutoForwardDiff()
       params : [-0.22646121565377667, -0.7010358307415183]
       actual : (-11.450427132876431, [8.18345036538685, 7.027159490855991])
     expected : (-11.450427132876431, [8.18345036538685, 7.027159490855991])
grad / primal : 1.3685095127434868
result.time_vs_primal = 1.3685095127434868
[ Info: Running AD on gdemo with ADTypes.AutoReverseDiff()
       params : [0.7438301214313114, 0.5227486505377502]
       actual : (-5.399368509888411, [-1.2636279373989772, 0.9181433940487849])
     expected : (-5.399368509888411, [-1.2636279373989772, 0.918143394048785])
grad / primal : 31.71040785704542
result.time_vs_primal = 31.71040785704542

In this specific instance, ForwardDiff is clearly faster (due to the small size of the model).
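If you want to automate this comparison, one possible sketch is to collect the time_vs_primal ratios from run_ad and keep the backend with the smallest one. The candidate list here is only an example; extend it with whichever backends you have imported.

# Sketch: benchmark each candidate backend and keep the one with the
# smallest gradient-to-primal time ratio.
candidates = [AutoForwardDiff(), AutoReverseDiff()]
ratios = Dict(adtype => run_ad(model, adtype; benchmark=true).time_vs_primal
              for adtype in candidates)
best_ratio, best_adtype = findmin(ratios)  # findmin on a Dict returns (value, key)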

A note about ReverseDiff’s compile argument

AutoReverseDiff accepts an additional keyword argument, compile, which specifies whether to pre-record the tape once and then reuse it for subsequent gradient evaluations. By default compile=false, meaning the tape is re-recorded on every call. Setting compile=true can substantially improve performance, but risks silently incorrect results if not used with care: a pre-recorded tape should only be used if you are absolutely certain that the sequence of operations performed in your model does not change between executions, for example because there is no control flow that depends on the parameter values.
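For example, to opt into tape compilation (a sketch reusing the gdemo model defined earlier; only do this if your model has no value-dependent control flow):

# Pre-record the tape once and reuse it for every subsequent gradient evaluation.
sample(gdemo(1.5, 2), HMC(0.1, 5; adtype=AutoReverseDiff(; compile=true)), 100)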

Compositional Sampling with Differing AD Modes

When using Gibbs sampling, Turing also supports using different automatic differentiation backends for different groups of variables. The following snippet shows how one can use ForwardDiff to sample the mean (m) parameter, and ReverseDiff for the variance (s²) parameter:

using Turing
using ReverseDiff

# Sample using Gibbs and varying autodiff backends.
c = sample(
    gdemo(1.5, 2),
    Gibbs(
        :m => HMC(0.1, 5; adtype=AutoForwardDiff()),
    :s² => HMC(0.1, 5; adtype=AutoReverseDiff()),
    ),
    1000,
    progress=false,
)
Chains MCMC chain (1000×3×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 14.56 seconds
Compute duration  = 14.56 seconds
parameters        = s², m
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          s²    2.4812    2.4768    0.3069   109.0650    87.3141    1.0035     ⋯
           m    1.4194    0.9456    0.1247    71.0295    57.6947    1.0018     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          s²    0.6652    1.1763    1.7349    2.6510   10.2629
           m   -0.0115    0.8007    1.2944    1.8525    3.6754