Turing AD tests
This page is intended as a brief overview of how different AD backends perform on a variety of Turing.jl models. Note that the inclusion of any AD backend here does not imply an endorsement from the Turing team; this table is purely for information.
- The definitions of the models and AD types below can be found on GitHub.
- Each number is the time taken to calculate the gradient of the log density of the model using the specified AD type, divided by the time taken to calculate the log density itself (in AD speak, the primal). Smaller means faster; a value of 1 would mean the gradient costs the same as the primal.
- 'wrong' means that AD ran but the result was not correct. If this happens you should be very wary! Correctness is checked by comparing against the result obtained using ForwardDiff, so ForwardDiff is by definition always 'correct'.
- 'error' means that AD didn't run.
- Some of the 'wrong' or 'error' entries have question marks next to them. These will link to a GitHub issue or other page that describes the problem.
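To make the rules above concrete, here is a minimal Python sketch of how a single table cell could be derived. It is not the code used by these tests (which are written in Julia). The function name `table_entry`, its parameters, and the tolerance `rtol` are all hypothetical and chosen only for illustration.

```python
import math
from typing import Callable, Sequence


def table_entry(
    grad_fn: Callable[[], Sequence[float]],
    reference_grad: Sequence[float],
    primal_time: float,
    grad_time: float,
    rtol: float = 1e-6,  # assumed tolerance, not taken from the actual tests
) -> str:
    """Return the text for one cell: a ratio, 'wrong', or 'error'."""
    # 'error': the AD backend failed to produce a gradient at all.
    try:
        grad = grad_fn()
    except Exception:
        return "error"

    # 'wrong': AD ran, but the gradient disagrees with the reference
    # (in the table above, the reference is the ForwardDiff result).
    matches = len(grad) == len(reference_grad) and all(
        math.isclose(g, r, rel_tol=rtol, abs_tol=rtol)
        for g, r in zip(grad, reference_grad)
    )
    if not matches:
        return "wrong"

    # Otherwise report gradient time relative to primal time.
    return f"{grad_time / primal_time:.3f}"
```

For example, a backend whose gradient matches the reference and takes 5.0 time units against a primal of 2.0 would be reported as `2.500`, whereas one that raises an exception would be reported as `error`, whatever its timings.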
Results
Model name \ AD type | EnzymeForward | EnzymeReverse | FiniteDifferences | ForwardDiff | Mooncake | ReverseDiff | ReverseDiffCompiled | Zygote |
---|---|---|---|---|---|---|---|---|
assume_beta | 2.524 | (?) error | 26.675 | 1.780 | 8.477 | 16.842 | 2.021 | 600.984 |
assume_dirichlet | 3.216 | (?) error | wrong | 1.469 | 11.310 | 27.315 | 3.250 | 643.017 |
assume_lkjcholu | 16.503 | (?) error | wrong | 7.038 | 11.270 | 29.953 | 4.759 | error |
assume_mvnormal | (?) error | 3.918 | 44.905 | 1.315 | 8.621 | 17.098 | 2.188 | 548.526 |
assume_normal | 2.790 | (?) error | 30.070 | 1.504 | 9.377 | 23.569 | 2.882 | 826.264 |
assume_submodel | 2.956 | 4.067 | 32.747 | 1.230 | 8.325 | 11.507 | 1.208 | 712.017 |
assume_wishart | (?) error | (?) error | error | 1.085 | 10.582 | 28.520 | 3.289 | 299.048 |
control_flow | 2.761 | 3.194 | 41.135 | 1.452 | 9.511 | 20.245 | wrong | 957.387 |
dot_assume | 3.145 | 2.330 | 82.908 | 1.345 | 7.622 | 29.399 | 3.297 | error |
dot_assume_observe_index | wrong | wrong | 57.018 | 1.521 | 7.293 | 45.622 | 4.681 | error |
dot_observe | 3.015 | 2.841 | 31.025 | 1.752 | 10.603 | 62.663 | 6.569 | error |
dynamic_constraint | 2.638 | 2.793 | 40.342 | 1.326 | 8.887 | 26.579 | 3.113 | 1065.369 |
multiple_constraints_same_var | 4.436 | 23.112 | wrong | 1.254 | 16.918 | 19.399 | 1.925 | error |
n010 | 4.576 | 2.064 | 149.439 | 1.798 | 5.541 | 35.134 | 3.469 | error |
n050 | 20.138 | 2.882 | 642.196 | 9.590 | 5.266 | 40.796 | 3.967 | error |
n100 | 41.451 | 2.872 | 1251.383 | 16.076 | 4.759 | 41.108 | 3.910 | error |
n500 | 205.450 | 2.634 | 6835.753 | 60.516 | 4.763 | 34.646 | 3.385 | error |
observe_index | 2.988 | 2.658 | 28.706 | 1.797 | 10.774 | 61.258 | 6.520 | error |
observe_literal | 2.999 | 3.321 | 30.317 | 1.569 | 12.260 | 36.578 | 4.104 | 1182.073 |
observe_multivariate | 2.961 | 4.546 | 55.125 | 1.514 | 7.064 | 32.938 | 3.156 | error |
observe_submodel | 2.906 | 4.738 | 19.526 | 1.290 | 8.018 | 14.172 | 1.438 | 809.574 |