A basic scientific problem is to mathematically model a system of interest, then compare this model to the observable reality around us. Such models often involve dynamical systems of differential equations. In practice, these equations often have unknown parameters that we would like to estimate. The “forward problem” of simulation consists of solving the differential equations for a given set of parameters, while the “inverse problem,” also known as parameter estimation, is the process of using data to determine these model parameters. Bayesian inference provides a robust approach to parameter estimation with quantified uncertainty.
using Turing
using DifferentialEquations

# Load StatsPlots for visualizations and diagnostics.
using StatsPlots

using LinearAlgebra
using Distributions

# Set a seed for reproducibility.
using Random
Random.seed!(14);
The Lotka–Volterra Model
The Lotka–Volterra equations, also known as the predator–prey equations, are a pair of first-order nonlinear differential equations. These differential equations are frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey. The populations change through time according to the pair of equations

\[
\begin{aligned}
\frac{\mathrm{d}x}{\mathrm{d}t} &= (\alpha - \beta y(t)) x(t), \\
\frac{\mathrm{d}y}{\mathrm{d}t} &= (\delta x(t) - \gamma) y(t),
\end{aligned}
\]

where \(x(t)\) and \(y(t)\) denote the populations of prey and predator at time \(t\), respectively, and \(\alpha, \beta, \gamma, \delta\) are positive parameters. In the absence of predators, the prey population \(x\) would increase exponentially at rate \(\alpha\) (with dimensions of time\(^{-1}\)). However, the predators kill some prey at a rate \(\beta\) (prey predator\(^{-1}\) time\(^{-1}\)), which enables the predator population to increase at rate \(\delta\) (predators prey\(^{-1}\) time\(^{-1}\)). Finally, predators are removed by natural mortality at rate \(\gamma\) (time\(^{-1}\)).
We implement the Lotka–Volterra model and simulate it with parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\).
# Define Lotka–Volterra model.
function lotka_volterra(du, u, p, t)
    # Model parameters.
    α, β, γ, δ = p
    # Current state.
    x, y = u

    # Evaluate differential equations.
    du[1] = (α - β * y) * x # prey
    du[2] = (δ * x - γ) * y # predator

    return nothing
end

# Define initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 10.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# Plot simulation.
plot(solve(prob, Tsit5()))
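As a quick sanity check of the right-hand side, the short sketch below evaluates the derivatives at the initial condition \(x(0) = y(0) = 1\) with the parameter values given above. It uses only base Julia and repeats the function definition so that it runs on its own; the names `p_check` and `du_check` are introduced here purely for illustration.

```julia
# Same in-place right-hand side as in the tutorial, repeated so the check is self-contained.
function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x  # prey
    du[2] = (δ * x - γ) * y  # predator
    return nothing
end

# Illustrative names (not part of the tutorial): parameters and output buffer.
p_check = [1.5, 1.0, 3.0, 1.0]
du_check = zeros(2)

# Evaluate the derivatives at x = y = 1 and t = 0.
lotka_volterra(du_check, [1.0, 1.0], p_check, 0.0)

# Prey: (1.5 - 1 * 1) * 1 = 0.5, predators: (1 * 1 - 3) * 1 = -2.0
println(du_check)
```

With these values the prey population initially grows while the predator population declines.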
We generate noisy observations to use for the parameter estimation tasks in this tutorial. With the saveat argument to the differential equation solver, we specify that the solution is stored only at intervals of 0.1 time units.
To make the example more realistic, we generate data as random Poisson counts based on the “true” population densities of predator and prey from the simulation. Poisson-distributed data are common in ecology (for instance, counts of animals detected by a camera trap). We’ll assume that the rate \(\lambda\), which parameterizes the Poisson distribution, is proportional to the underlying animal densities via a constant factor \(q = 1.7\).
sol = solve(prob, Tsit5(); saveat=0.1)
q = 1.7
odedata = rand.(Poisson.(q * Array(sol)))

# Plot simulation and noisy observations.
plot(sol, label=["Prey" "Predator"])
scatter!(sol.t, odedata'; color=[1 2], label="")
An even more realistic example could be fitted to the famous hare-and-lynx system using the long-term trapping records of the Hudson’s Bay Company. A Stan implementation of this problem with slightly different priors can be found here. For this tutorial, though, we will stick with simulated data.
Direct Handling of Bayesian Estimation with Turing
DifferentialEquations.jl is the main Julia package for numerically solving differential equations. Its functionality is completely interoperable with Turing.jl, which means that we can directly simulate differential equations inside a Turing @model.
For the purposes of this tutorial, we choose priors for the parameters that are quite close to the ground truth. As justification, we can imagine we have preexisting estimates for the biological rates. Practically, this helps us to illustrate the results without needing to run overly long MCMC chains.
Note we also have to take special care with the ODE solver. For certain parameter combinations, the numerical solver may predict animal densities that are just barely below zero. This causes errors with the Poisson distribution, which needs a non-negative mean \(\lambda\). To avoid this happening, we tell the solver to aim for small absolute and relative errors (abstol=1e-6, reltol=1e-6). We also add a fudge factor ϵ = 1e-5 to the predicted data. Since ϵ is greater than the solver’s tolerance, it should overcome any remaining numerical error, making sure all predicted values are positive. At the same time, it is so small compared to the data that it should have a negligible effect on inference. If this approach doesn’t work, there are some more ideas to try here. In the case of continuous observations (e.g. data derived from modelling chemical reactions), it is sufficient to use a normal distribution with the mean as the data point and an appropriately chosen variance (which can itself also be a parameter with a prior distribution).
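The model described above can be sketched as follows. This is a hedged reconstruction, not necessarily the exact code that produced the output shown next: the name `fitlv` is an assumption chosen by analogy with `fitlv2` used later in this tutorial, the priors are copied from that later model, and the sampler settings (default `NUTS()`, 3 chains of 1000 samples) are inferred from the chain summary. The setup is repeated so that the block runs on its own.

```julia
using Turing
using DifferentialEquations
using Distributions
using Random
Random.seed!(14)

# Lotka–Volterra right-hand side, as defined earlier.
function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x  # prey
    du[2] = (δ * x - γ) * y  # predator
    return nothing
end

# Ground-truth problem and Poisson observations, as generated earlier (q = 1.7).
prob = ODEProblem(lotka_volterra, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])
sol = solve(prob, Tsit5(); saveat=0.1)
odedata = rand.(Poisson.(1.7 * Array(sol)))

# Assumed model name; priors mirror the `fitlv2` model shown later.
@model function fitlv(data, prob)
    # Prior distributions, centred close to the ground truth.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate the ODE with tight tolerances.
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1, abstol=1e-6, reltol=1e-6)

    # Fudge factor that keeps every Poisson mean strictly positive.
    ϵ = 1e-5

    # Poisson likelihood for the prey and predator counts at each saved time point.
    for i in eachindex(predicted.u)
        data[:, i] ~ arraydist(Poisson.(q .* predicted.u[i] .+ ϵ))
    end

    return nothing
end

model = fitlv(odedata, prob)

# Sample 3 independent chains of 1000 draws each.
chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress=false)
```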
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.2
Chains MCMC chain (1000×19×3 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 3
Samples per chain = 1000
Wall duration = 56.44 seconds
Compute duration = 52.61 seconds
parameters = α, β, γ, δ, q
internals = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, lp, logprior, loglikelihood
Use `describe(chains)` for summary statistics and quantiles.
The estimated parameters are close to the parameter values the observations were generated with. We can also check visually that the chains have converged.
plot(chain)
Data retrodiction
In Bayesian analysis it is often useful to retrodict the data, i.e. generate simulated data using samples from the posterior distribution, and compare to the original data (see for instance section 3.3.2 - model checking of McElreath’s book “Statistical Rethinking”). Here, we solve the ODE for 300 randomly picked posterior samples in the chain. We plot the ensemble of solutions to check if the solution resembles the data. The 300 retrodicted time courses from the posterior are plotted in gray, the noisy observations are shown as blue and red dots, and the green and purple lines are the ODE solution that was used to generate the data.
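A retrodiction plot of this kind can be sketched as below. The helper name `plot_retrodiction` and its keyword `n` are hypothetical, introduced only for this illustration, and the colors are an assumption that may differ from the figure described above. The function expects a chain containing the parameters `α`, `β`, `γ`, and `δ`, the `ODEProblem`, the noise-free solution, and the observed counts.

```julia
using Turing                 # provides `sample` for chains via MCMCChains
using DifferentialEquations
using StatsPlots

# Hypothetical helper: overlay ODE solutions for `n` posterior draws on the data.
function plot_retrodiction(chain, prob, sol, data; n=300)
    plt = plot(; legend=false)

    # Draw n posterior samples of the ODE parameters without replacement.
    posterior_samples = sample(chain[[:α, :β, :γ, :δ]], n; replace=false)

    # Solve the ODE once per posterior draw and plot it as a faint gray line.
    for p in eachrow(Array(posterior_samples))
        sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
        plot!(plt, sol_p; alpha=0.1, color="#BBBBBB")
    end

    # Overlay the noise-free solution and the noisy observations.
    plot!(plt, sol; color=[1 2], linewidth=1)
    scatter!(plt, sol.t, data'; color=[1 2])

    return plt
end
```

With the objects from this tutorial, it would be called as `plot_retrodiction(chain, prob, sol, odedata)`.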
We can see that, even though we added quite a bit of noise to the data, the posterior distribution reproduces the “true” ODE solution quite accurately.
Lotka–Volterra model without data of prey
One can also perform parameter inference for a Lotka–Volterra model with incomplete data. For instance, let us suppose we have only observations of the predators but not of the prey. We can fit the model only to the \(y\) variable of the system without providing any data for \(x\):
@model function fitlv2(data::AbstractVector, prob)
    # Prior distributions.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate Lotka–Volterra model but save only the second state of the system (predators).
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1, save_idxs=2, abstol=1e-6, reltol=1e-6)
    ϵ = 1e-5

    # Observations of the predators.
    data ~ arraydist(Poisson.(q .* predicted.u .+ ϵ))

    return nothing
end

model2 = fitlv2(odedata[2, :], prob)

# Sample 3 independent chains.
chain2 = sample(model2, NUTS(0.45), MCMCSerial(), 5000, 3; progress=false)
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.2
Chains MCMC chain (5000×19×3 Array{Float64, 3}):
Iterations = 1001:1:6000
Number of chains = 3
Samples per chain = 5000
Wall duration = 29.56 seconds
Compute duration = 29.09 seconds
parameters = α, β, γ, δ, q
internals = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, lp, logprior, loglikelihood
Use `describe(chains)` for summary statistics and quantiles.
Again we inspect the trajectories of 300 randomly selected posterior samples.
Note that here the observations of the prey (blue dots) were not used in the parameter estimation! Yet, the model can predict the values of \(x\) relatively accurately, albeit with a wider distribution of solutions, reflecting the greater uncertainty in the prediction of the \(x\) values.
Inference of Delay Differential Equations
Here we show an example of inference with another type of differential equation: a delay differential equation (DDE). DDEs are differential equations where derivatives are functions of values at an earlier point in time. This is useful to model a delayed effect, such as the incubation time of a virus.
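Here is a delayed version of the Lotka–Volterra system:

\[
\begin{aligned}
\frac{\mathrm{d}x}{\mathrm{d}t} &= \alpha x(t - \tau) - \beta y(t) x(t), \\
\frac{\mathrm{d}y}{\mathrm{d}t} &= -\gamma y(t) + \delta x(t) y(t),
\end{aligned}
\]
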
where \(\tau\) is a (positive) delay and \(x(t-\tau)\) is the variable \(x\) at an earlier time point \(t - \tau\).
The initial-value problem of the delayed system can be implemented as a DDEProblem. As described in the DDE example, here the function h is the history function that can be used to obtain a state at an earlier time point. Again we use parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\). The delay is fixed at \(\tau = 1\), which appears in the code as the call h(p, t - 1; idxs=1). Moreover, we assume \(x(t) = 1\) for \(t < 0\).
function delay_lotka_volterra(du, u, h, p, t)
    # Model parameters.
    α, β, γ, δ = p

    # Current state.
    x, y = u
    # Evaluate differential equations
    du[1] = α * h(p, t - 1; idxs=1) - β * x * y
    du[2] = -γ * y + δ * x * y

    return nothing
end

# Define initial-value problem.
p = (1.5, 1.0, 3.0, 1.0)
u0 = [1.0; 1.0]
tspan = (0.0, 10.0)
h(p, t; idxs::Int) = 1.0
prob_dde = DDEProblem(delay_lotka_volterra, u0, h, tspan, p);
We generate observations by sampling from the corresponding Poisson distributions derived from the simulation results:
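A sketch of this data-generation step, together with a matching Turing model, is given below. It is a hedged reconstruction: the names `sol_dde`, `ddedata`, `fitlv_dde`, and `model_dde` are assumptions, the priors are copied from the `fitlv2` model above, the `MethodOfSteps(Tsit5())` solver is one reasonable choice for a constant-delay problem rather than a confirmed setting, and the sampler settings (3 chains of 300 samples) are inferred from the chain summary that follows. The DDE setup is repeated so the block runs on its own.

```julia
using Turing
using DifferentialEquations
using Distributions
using Random
Random.seed!(14)

# Delayed Lotka–Volterra system and constant history function, as defined above.
function delay_lotka_volterra(du, u, h, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = α * h(p, t - 1; idxs=1) - β * x * y
    du[2] = -γ * y + δ * x * y
    return nothing
end
h(p, t; idxs::Int) = 1.0
prob_dde = DDEProblem(delay_lotka_volterra, [1.0; 1.0], h, (0.0, 10.0), (1.5, 1.0, 3.0, 1.0))

# Simulate the DDE and draw Poisson counts with the same scaling q = 1.7 as before.
q = 1.7
sol_dde = solve(prob_dde, MethodOfSteps(Tsit5()); saveat=0.1)
ddedata = rand.(Poisson.(q .* Array(sol_dde)))

# Assumed model name; same structure as the ODE model, but solved as a DDE.
@model function fitlv_dde(data, prob)
    # Prior distributions.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate the delayed system with tight tolerances.
    p = [α, β, γ, δ]
    predicted = solve(prob, MethodOfSteps(Tsit5()); p=p, saveat=0.1, abstol=1e-6, reltol=1e-6)
    ϵ = 1e-5

    # Poisson likelihood for the counts at each saved time point.
    for i in eachindex(predicted.u)
        data[:, i] ~ arraydist(Poisson.(q .* predicted.u[i] .+ ϵ))
    end

    return nothing
end

model_dde = fitlv_dde(ddedata, prob_dde)

# Sample 3 independent chains of 300 draws each.
chain_dde = sample(model_dde, NUTS(), MCMCSerial(), 300, 3; progress=false)
```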
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.05
Chains MCMC chain (300×19×3 Array{Float64, 3}):
Iterations = 151:1:450
Number of chains = 3
Samples per chain = 300
Wall duration = 12.58 seconds
Compute duration = 12.34 seconds
parameters = α, β, γ, δ, q
internals = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, lp, logprior, loglikelihood
Use `describe(chains)` for summary statistics and quantiles.
plot(chain_dde)
Finally, we plot trajectories of 300 randomly selected samples from the posterior. Again, the dots indicate our observations, the colored lines are the “true” simulations without noise, and the gray lines are trajectories from the posterior samples.
The fit is pretty good even though the data were quite noisy to start with.
Scaling to Large Models: Adjoint Sensitivities
DifferentialEquations.jl’s efficiency for large stiff models has been shown in multiple benchmarks. To learn more about how to optimize solving performance for stiff problems you can take a look at the docs.
Sensitivity analysis is provided by the SciMLSensitivity.jl package, which forms part of SciML’s differential equation suite. The model sensitivities are the derivatives of the solution with respect to the parameters. Specifically, the local sensitivity of the solution to a parameter is defined by how much the solution would change if the parameter were changed by a small amount. Sensitivity analysis provides a cheap way to calculate the gradient of the solution which can be used in parameter estimation and other optimization tasks. The sensitivity analysis methods in SciMLSensitivity.jl are based on automatic differentiation (AD), and are compatible with many of Julia’s AD backends. More details on the mathematical theory that underpins these methods can be found in the SciMLSensitivity documentation.
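To make the definition concrete, here is a short worked form of the forward (local) sensitivity equations. The notation is introduced here for illustration and is not taken from the SciMLSensitivity documentation: \(u(t)\) is the state, \(f(u, p, t)\) the right-hand side of the ODE, and \(s_j(t)\) the sensitivity of the solution to the \(j\)-th parameter. Differentiating \(\mathrm{d}u/\mathrm{d}t = f(u, p, t)\) with respect to \(p_j\) gives

\[
s_j(t) = \frac{\partial u(t)}{\partial p_j},
\qquad
\frac{\mathrm{d}s_j}{\mathrm{d}t} = \frac{\partial f}{\partial u}\, s_j + \frac{\partial f}{\partial p_j},
\qquad
s_j(0) = 0 \ \text{if } u(0) \text{ does not depend on } p_j .
\]

For the Lotka–Volterra model with \(u = (x, y)\) and \(p_1 = \alpha\), the two ingredients are

\[
\frac{\partial f}{\partial u} =
\begin{pmatrix}
\alpha - \beta y & -\beta x \\
\delta y & \delta x - \gamma
\end{pmatrix},
\qquad
\frac{\partial f}{\partial \alpha} =
\begin{pmatrix}
x \\ 0
\end{pmatrix}.
\]

Adjoint methods arrive at the same gradients of a scalar loss by solving a single backward-in-time problem instead of one sensitivity system per parameter, which is why they tend to be preferred when there are many parameters.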
To enable sensitivity analysis, you will need to import SciMLSensitivity, and also use one of the AD backends that is compatible with SciMLSensitivity.jl when sampling. For example, if we wanted to use Mooncake.jl, we could run:
import Mooncake
import SciMLSensitivity

# Define the AD backend to use
adtype = AutoMooncake()

# Sample a single chain with 1000 samples using Mooncake
sample(model, NUTS(; adtype=adtype), 1000; progress=false)
┌ Info: Found initial step size
└ ϵ = 0.2
Chains MCMC chain (1000×19×1 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 1
Samples per chain = 1000
Wall duration = 279.98 seconds
Compute duration = 279.98 seconds
parameters = α, β, γ, δ, q
internals = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, lp, logprior, loglikelihood
Use `describe(chains)` for summary statistics and quantiles.
In this case, SciMLSensitivity will automatically choose an appropriate sensitivity analysis algorithm for you. You can also manually specify an algorithm by providing the sensealg keyword argument to the solve function; the existing algorithms are covered in this page of the SciMLSensitivity docs.
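As a minimal sketch of the manual route, the block below passes `sensealg` to `solve` for the Lotka–Volterra problem. The choice of `InterpolatingAdjoint()` with its default options is an assumption made for illustration, not a recommendation from this tutorial; which algorithm works best depends on the problem and on the AD backend. Inside a Turing model, the same keyword would be added to the `solve` call that produces `predicted`.

```julia
using DifferentialEquations
import SciMLSensitivity

# Lotka–Volterra right-hand side, as defined earlier.
function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x  # prey
    du[2] = (δ * x - γ) * y  # predator
    return nothing
end

prob = ODEProblem(lotka_volterra, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Request a specific sensitivity algorithm (illustrative choice).
# The keyword matters when this solve is differentiated by a compatible AD backend;
# the forward solution itself is computed as usual.
sol = solve(
    prob,
    Tsit5();
    saveat=0.1,
    sensealg=SciMLSensitivity.InterpolatingAdjoint(),
)
```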
For more examples of adjoint usage on large parameter models, consult the DiffEqFlux documentation.