using Turing
using DifferentialEquations
# Load StatsPlots for visualizations and diagnostics.
using StatsPlots
using LinearAlgebra
using Distributions
# Set a seed for reproducibility.
using Random
Random.seed!(14);
Bayesian Differential Equations
A basic scientific problem is to mathematically model a system of interest, then compare this model to the observable reality around us. Such models often involve dynamical systems of differential equations. In practice, these equations often have unknown parameters we would like to estimate. The “forward problem” of simulation consists of solving the differential equations for a given set of parameters, while the “inverse problem,” also known as parameter estimation, is the process of utilizing data to determine these model parameters. Bayesian inference provides a robust approach to parameter estimation with quantified uncertainty.
The Lotka–Volterra Model
The Lotka–Volterra equations, also known as the predator–prey equations, are a pair of first-order nonlinear differential equations. These differential equations are frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey. The populations change through time according to the pair of equations
\[ \begin{aligned} \frac{\mathrm{d}x}{\mathrm{d}t} &= (\alpha - \beta y(t))x(t), \\ \frac{\mathrm{d}y}{\mathrm{d}t} &= (\delta x(t) - \gamma)y(t), \end{aligned} \]
where \(x(t)\) and \(y(t)\) denote the populations of prey and predator at time \(t\), respectively, and \(\alpha, \beta, \gamma, \delta\) are positive parameters. In the absence of predators, the prey population \(x\) would increase exponentially at rate \(\alpha\) (with dimensions of \(\text{time}^{-1}\)). However, the predators kill some prey at a rate \(\beta\) (\(\text{prey}\ \text{predator}^{-1}\ \text{time}^{-1}\)), which enables the predator population to increase at rate \(\delta\) (\(\text{predators}\ \text{prey}^{-1}\ \text{time}^{-1}\)). Finally, predators are removed by natural mortality at rate \(\gamma\) (\(\text{time}^{-1}\)).
We implement the Lotka–Volterra model and simulate it with parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\).
# Define Lotka–Volterra model.
function lotka_volterra(du, u, p, t)
    # Model parameters.
    α, β, γ, δ = p
    # Current state.
    x, y = u

    # Evaluate differential equations.
    du[1] = (α - β * y) * x # prey
    du[2] = (δ * x - γ) * y # predator

    return nothing
end

# Define initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 10.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)
# Plot simulation.
plot(solve(prob, Tsit5()))
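Setting both right-hand sides of the Lotka–Volterra equations to zero gives, besides the trivial point \((0, 0)\), the coexistence equilibrium \((x^*, y^*) = (\gamma/\delta, \alpha/\beta)\), around which the simulated populations cycle. The short sketch below, in plain Julia with no packages, evaluates the model's right-hand side at that point for the parameters used above. It repeats the definition of lotka_volterra so that it runs on its own.

```julia
# Right-hand side of the Lotka–Volterra equations (same form as in the tutorial).
function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x  # prey
    du[2] = (δ * x - γ) * y  # predator
    return nothing
end

p = [1.5, 1.0, 3.0, 1.0]
α, β, γ, δ = p

# Coexistence equilibrium: prey growth balances predation,
# and predator growth balances natural mortality.
ustar = [γ / δ, α / β]

# Evaluate the derivatives at the equilibrium (time does not enter the model).
du = zeros(2)
lotka_volterra(du, ustar, p, 0.0)
println("equilibrium = ", ustar, ", derivatives there = ", du)
```

With \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) the equilibrium is \((3, 1.5)\), and both derivatives vanish there exactly.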
We generate noisy observations to use for the parameter estimation tasks in this tutorial. With the saveat argument to the differential equation solver, we specify that the solution is stored only every 0.1 time units.
To make the example more realistic, we generate data as random Poisson counts based on the “true” population densities of predator and prey from the simulation. Poisson-distributed data are common in ecology (for instance, counts of animals detected by a camera trap). We’ll assume that the rate \(\lambda\), which parameterizes the Poisson distribution, is proportional to the underlying animal densities via a constant factor \(q = 1.7\).
sol = solve(prob, Tsit5(); saveat=0.1)
q = 1.7
odedata = rand.(Poisson.(q * Array(sol)))
# Plot simulation and noisy observations.
plot(sol, label=["Prey" "Predator"])
scatter!(sol.t, odedata'; color=[1 2], label="")
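The observation model can be checked in isolation. The sketch below uses only Distributions.jl. It draws many Poisson counts for a single hypothetical population density x = 2.0 and confirms that their average is close to the rate \(\lambda = q x\). The density value, the random seed, and the number of draws are arbitrary choices made for illustration.

```julia
using Distributions, Random, Statistics

rng = Random.MersenneTwister(42)

q = 1.7    # proportionality factor between density and Poisson rate, as in the tutorial
x = 2.0    # hypothetical "true" population density at one time point
λ = q * x  # Poisson rate of the observed counts

# Draw many independent counts at this rate.
counts = rand(rng, Poisson(λ), 100_000)

println("λ = ", λ, ", sample mean = ", mean(counts), ", sample variance = ", var(counts))
```

For a Poisson distribution the mean and the variance both equal \(\lambda\), so the absolute noise in the counts grows with the underlying population density.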
An even more realistic example could be fitted to the famous hare-and-lynx system using the long-term trapping records of the Hudson’s Bay Company. A Stan implementation of this problem with slightly different priors can be found here. For this tutorial, though, we will stick with simulated data.
Direct Handling of Bayesian Estimation with Turing
DifferentialEquations.jl is the main Julia package for numerically solving differential equations. Its functionality is completely interoperable with Turing.jl, which means that we can directly simulate differential equations inside a Turing @model.
For the purposes of this tutorial, we choose priors for the parameters that are quite close to the ground truth. As justification, we can imagine we have preexisting estimates for the biological rates. Practically, this helps us to illustrate the results without needing to run overly long MCMC chains.
Note that we also have to take special care with the ODE solver. For certain parameter combinations, the numerical solver may predict animal densities that are just barely below zero. This causes errors with the Poisson distribution, which needs a non-negative mean \(\lambda\). To avoid this, we tell the solver to aim for small absolute and relative errors (abstol=1e-6, reltol=1e-6). We also add a fudge factor ϵ = 1e-5 to the predicted data. Since ϵ is greater than the solver’s tolerance, it should overcome any remaining numerical error, making sure all predicted values are positive. At the same time, it is so small compared to the data that it should have a negligible effect on inference. If this approach doesn’t work, there are some more ideas to try here.
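The effect of the fudge factor can be seen with Distributions.jl alone. The sketch below uses a hypothetical vector of solver output whose last entry is a tiny negative overshoot of the kind described above. It assumes nothing beyond the Poisson constructor rejecting a negative mean.

```julia
using Distributions

q = 1.7
ϵ = 1e-5

# Hypothetical solver output at three time points; the last value is a tiny
# negative overshoot of the kind a numerical solver can produce near zero.
predicted = [2.0, 0.5, -1e-7]

# Without the offset, constructing a Poisson distribution with a negative mean fails.
fails_without_offset = try
    Poisson.(q .* predicted)
    false
catch
    true
end

# With the offset, every mean is strictly positive and the distributions are valid.
dists = Poisson.(q .* predicted .+ ϵ)
println("fails without offset: ", fails_without_offset, "; means with offset: ", mean.(dists))
```

The offset shifts each Poisson mean by only \(10^{-5}\), which is negligible next to counts of order one and above.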
@model function fitlv(data, prob)
    # Prior distributions.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate Lotka–Volterra model.
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1, abstol=1e-6, reltol=1e-6)
    ϵ = 1e-5

    # Observations.
    for i in eachindex(predicted)
        data[:, i] ~ arraydist(Poisson.(q .* predicted[i] .+ ϵ))
    end
return nothing
end
model = fitlv(odedata, prob)

# Sample 3 independent chains with forward-mode automatic differentiation (the default).
chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.2

Chains MCMC chain (1000×17×3 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 3
Samples per chain = 1000
Wall duration     = 49.27 seconds
Compute duration  = 46.69 seconds
parameters        = α, β, γ, δ, q
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           α    1.5257    0.0611    0.0018   1107.0992    936.2641    1.0028   ⋯
           β    0.9674    0.0674    0.0018   1357.8551   1415.4968    1.0011   ⋯
           γ    3.0136    0.1441    0.0040   1285.9869   1138.0021    1.0019   ⋯
           δ    0.9799    0.0802    0.0025   1076.4832    958.9411    1.0028   ⋯
           q    1.6761    0.0983    0.0027   1298.3271   1141.0664    1.0033   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    1.4076    1.4846    1.5249    1.5665    1.6458
           β    0.8434    0.9212    0.9649    1.0122    1.1025
           γ    2.7366    2.9182    3.0111    3.1076    3.3108
           δ    0.8321    0.9244    0.9781    1.0314    1.1473
           q    1.4832    1.6118    1.6724    1.7422    1.8719
The estimated parameters are close to the parameter values the observations were generated with. We can also check visually that the chains have converged.
plot(chain)
Data retrodiction
In Bayesian analysis it is often useful to retrodict the data, i.e. generate simulated data using samples from the posterior distribution, and compare to the original data (see for instance section 3.3.2, “Model checking”, of McElreath’s book “Statistical Rethinking”). Here, we solve the ODE for 300 randomly picked posterior samples in the chain. We plot the ensemble of solutions to check whether it resembles the data. The 300 retrodicted time courses from the posterior are plotted in gray, the noisy observations are shown as blue and red dots, and the lines in the same two colors are the ODE solution that was used to generate the data.
plot(; legend=false)
posterior_samples = sample(chain[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end
# Plot simulation and noisy observations.
plot!(sol; color=[1 2], linewidth=1)
scatter!(sol.t, odedata'; color=[1 2])
We can see that, even though we added quite a bit of noise to the data, the posterior distribution reproduces the “true” ODE solution quite accurately.
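The plot above shows the posterior spread of the noise-free ODE solutions. To retrodict data in the stricter sense described earlier, each posterior draw can also be pushed through the Poisson observation model. The sketch below does this with a hypothetical helper, retrodict_counts. To keep it self-contained, the two rows of draws are made-up placeholder values near the ground truth that stand in for posterior samples. With the tutorial's chain, the rows could instead come from something like Array(sample(chain[[:α, :β, :γ, :δ, :q]], 300; replace=false)), mirroring the plotting code above. The sketch assumes OrdinaryDiffEq.jl, the ODE solver package that DifferentialEquations.jl re-exports.

```julia
using OrdinaryDiffEq, Distributions, Random

function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x  # prey
    du[2] = (δ * x - γ) * y  # predator
    return nothing
end

prob = ODEProblem(lotka_volterra, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Hypothetical helper: one replicated count dataset per row (α, β, γ, δ, q) of `draws`.
function retrodict_counts(rng, prob, draws::AbstractMatrix; ϵ=1e-5)
    return map(eachrow(draws)) do θ
        sol = solve(prob, Tsit5(); p=θ[1:4], saveat=0.1, abstol=1e-6, reltol=1e-6)
        λ = θ[5] .* Array(sol) .+ ϵ     # Poisson rates, 2 × (number of saved times)
        rand.(Ref(rng), Poisson.(λ))    # one simulated count per state and time point
    end
end

rng = Random.MersenneTwister(1)

# Placeholder "posterior draws" near the ground truth (columns: α, β, γ, δ, q).
draws = [1.50 1.00 3.00 1.00 1.70;
         1.52 0.97 3.01 0.98 1.68]

replicates = retrodict_counts(rng, prob, draws)
println(size.(replicates))
```

Each replicate has the same shape as odedata, so it can be scattered over the observations in the same way to compare the spread of simulated and observed counts.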
Lotka–Volterra model without data of prey
One can also perform parameter inference for a Lotka–Volterra model with incomplete data. For instance, let us suppose we have only observations of the predators but not of the prey. We can fit the model only to the \(y\) variable of the system without providing any data for \(x\):
@model function fitlv2(data::AbstractVector, prob)
    # Prior distributions.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate Lotka–Volterra model but save only the second state of the system (predators).
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1, save_idxs=2, abstol=1e-6, reltol=1e-6)
    ϵ = 1e-5

    # Observations of the predators.
    data ~ arraydist(Poisson.(q .* predicted.u .+ ϵ))
return nothing
end
model2 = fitlv2(odedata[2, :], prob)

# Sample 3 independent chains (the first NUTS argument, 0.45, is the target acceptance rate).
chain2 = sample(model2, NUTS(0.45), MCMCSerial(), 5000, 3; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.025
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.05

Chains MCMC chain (5000×17×3 Array{Float64, 3}):

Iterations        = 1001:1:6000
Number of chains  = 3
Samples per chain = 5000
Wall duration     = 28.94 seconds
Compute duration  = 28.57 seconds
parameters        = α, β, γ, δ, q
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64     ⋯

           α    1.5718    0.0894    0.0080    152.4329    112.4395    1.0344     ⋯
           β    0.9900    0.1035    0.0071    212.4749    500.9032    1.0331     ⋯
           γ    3.0055    0.1490    0.0114    181.6808    150.3449    1.0293     ⋯
           δ    0.9180    0.1204    0.0098    157.9919    112.0723    1.0314     ⋯
           q    1.6494    0.1631    0.0115    200.3701    289.9853    1.0198     ⋯
                                                                  1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    1.4162    1.5117    1.5638    1.6228    1.7840
           β    0.7967    0.9173    0.9897    1.0561    1.2007
           γ    2.6919    2.9107    3.0099    3.1089    3.2791
           δ    0.6684    0.8376    0.9203    0.9953    1.1584
           q    1.3284    1.5393    1.6491    1.7579    1.9751
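In fitlv2, the keyword save_idxs=2 tells the solver to store only the second state component (the predators). As a result, predicted.u is a plain vector of numbers that lines up with odedata[2, :]. The sketch below checks that this agrees with the second row of a full solution. It assumes OrdinaryDiffEq.jl, the ODE solver package that DifferentialEquations.jl re-exports.

```julia
using OrdinaryDiffEq

function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x  # prey
    du[2] = (δ * x - γ) * y  # predator
    return nothing
end

prob = ODEProblem(lotka_volterra, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Full solution: both prey and predator at every saved time.
full = solve(prob, Tsit5(); saveat=0.1, abstol=1e-6, reltol=1e-6)

# Predator-only solution: only the second state component is stored.
pred_only = solve(prob, Tsit5(); saveat=0.1, save_idxs=2, abstol=1e-6, reltol=1e-6)

println(length(pred_only.u), " saved predator values; max difference to full solution = ",
        maximum(abs.(pred_only.u .- Array(full)[2, :])))
```

Saving a single component does not change which steps the solver takes, because error control still uses the full state. Only what is stored changes.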
Again we inspect the trajectories of 300 randomly selected posterior samples.
plot(; legend=false)
posterior_samples = sample(chain2[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end
# Plot simulation and noisy observations.
plot!(sol; color=[1 2], linewidth=1)
scatter!(sol.t, odedata'; color=[1 2])
Note that here the observations of the prey (blue dots) were not used in the parameter estimation! Yet, the model can predict the values of \(x\) relatively accurately, albeit with a wider distribution of solutions, reflecting the greater uncertainty in the prediction of the \(x\) values.
Inference of Delay Differential Equations
Here we show an example of inference with another type of differential equation: a delay differential equation (DDE). DDEs are differential equations where derivatives are functions of values at an earlier point in time. This is useful to model a delayed effect, such as the incubation time of a virus.
Here is a delayed version of the Lotka–Volterra system:
\[ \begin{aligned} \frac{\mathrm{d}x}{\mathrm{d}t} &= \alpha x(t-\tau) - \beta y(t) x(t),\\ \frac{\mathrm{d}y}{\mathrm{d}t} &= - \gamma y(t) + \delta x(t) y(t), \end{aligned} \]
where \(\tau\) is a (positive) delay and \(x(t-\tau)\) is the variable \(x\) at an earlier time point \(t - \tau\).
The initial-value problem of the delayed system can be implemented as a DDEProblem. As described in the DDE example, here the function h is the history function that can be used to obtain a state at an earlier time point. Again we use parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\). Moreover, we assume \(x(t) = 1\) for \(t < 0\) and use a delay of \(\tau = 1\).
function delay_lotka_volterra(du, u, h, p, t)
# Model parameters.
    α, β, γ, δ = p

    # Current state.
    x, y = u
    # Evaluate differential equations
    du[1] = α * h(p, t - 1; idxs=1) - β * x * y
    du[2] = -γ * y + δ * x * y
return nothing
end
# Define initial-value problem.
p = (1.5, 1.0, 3.0, 1.0)
u0 = [1.0; 1.0]
tspan = (0.0, 10.0)
h(p, t; idxs::Int) = 1.0
prob_dde = DDEProblem(delay_lotka_volterra, u0, h, tspan, p);
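To show where the history function enters, here is a deliberately simple sketch in plain Julia. It uses a fixed-step explicit Euler scheme for the delayed system with \(\tau = 1\). The scheme reads \(x(t-\tau)\) from the trajectory computed so far and falls back to the constant history \(x(t) = 1\) whenever \(t - \tau < 0\). This only illustrates the idea. It is not how MethodOfSteps in DelayDiffEq.jl works internally, since that method wraps an adaptive ODE solver such as Tsit5 and evaluates the history from the interpolated solution. The function name euler_delay_lv and the step size are hypothetical choices.

```julia
# Hypothetical toy integrator: explicit Euler for the delayed Lotka–Volterra system
# with a constant history x(t) = 1 for t < 0.
function euler_delay_lv(p; τ=1.0, u0=(1.0, 1.0), tmax=10.0, dt=1e-3)
    α, β, γ, δ = p
    n = round(Int, tmax / dt)
    lag = round(Int, τ / dt)        # the delay expressed as a number of steps
    x = zeros(n + 1)
    y = zeros(n + 1)
    x[1], y[1] = u0                 # x[k], y[k] approximate the state at t = (k - 1) * dt
    for k in 1:n
        # x(t - τ): taken from the stored trajectory, or from the history before t = 0.
        xlag = k - lag >= 1 ? x[k - lag] : 1.0
        x[k + 1] = x[k] + dt * (α * xlag - β * x[k] * y[k])
        y[k + 1] = y[k] + dt * (-γ * y[k] + δ * x[k] * y[k])
    end
    return range(0, tmax; length=n + 1), x, y
end

t, x, y = euler_delay_lv((1.5, 1.0, 3.0, 1.0))
println("x(10) ≈ ", x[end], ", y(10) ≈ ", y[end])
```

Only the prey equation looks into the past, which matches the single call h(p, t - 1; idxs=1) in delay_lotka_volterra.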
We generate observations by sampling from the corresponding Poisson distributions derived from the simulation results:
sol_dde = solve(prob_dde; saveat=0.1)
ddedata = rand.(Poisson.(q .* Array(sol_dde)))
# Plot simulation and noisy observations.
plot(sol_dde)
scatter!(sol_dde.t, ddedata'; color=[1 2], label="")
Now we define the Turing model for the Lotka–Volterra model with a delay, and sample 3 independent chains.
@model function fitlv_dde(data, prob)
# Prior distributions.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate Lotka–Volterra model.
    p = [α, β, γ, δ]
    predicted = solve(prob, MethodOfSteps(Tsit5()); p=p, saveat=0.1, abstol=1e-6, reltol=1e-6)
    ϵ = 1e-5

    # Observations.
    for i in eachindex(predicted)
        data[:, i] ~ arraydist(Poisson.(q .* predicted[i] .+ ϵ))
    end
end
model_dde = fitlv_dde(ddedata, prob_dde)
chain_dde = sample(model_dde, NUTS(), MCMCSerial(), 300, 3; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.2
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.2

Chains MCMC chain (300×17×3 Array{Float64, 3}):

Iterations        = 151:1:450
Number of chains  = 3
Samples per chain = 300
Wall duration     = 12.3 seconds
Compute duration  = 12.05 seconds
parameters        = α, β, γ, δ, q
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64     ⋯

           α    1.5849    0.1120    0.0058    384.1059    268.6046    1.0046     ⋯
           β    1.0382    0.0925    0.0037    683.5960    541.1522    1.0011     ⋯
           γ    3.0831    0.1463    0.0061    574.0471    402.3024    1.0021     ⋯
           δ    0.9885    0.0814    0.0042    376.4300    384.6211    1.0086     ⋯
           q    1.6828    0.1178    0.0053    490.7867    566.5574    1.0035     ⋯
                                                                  1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    1.3825    1.5112    1.5745    1.6549    1.8302
           β    0.8799    0.9734    1.0326    1.0943    1.2372
           γ    2.7909    2.9800    3.0853    3.1868    3.3632
           δ    0.8319    0.9342    0.9890    1.0410    1.1479
           q    1.4649    1.6006    1.6831    1.7656    1.9028
plot(chain_dde)
Finally, we plot trajectories of 300 randomly selected samples from the posterior. Again, the dots indicate our observations, the colored lines are the “true” simulations without noise, and the gray lines are trajectories from the posterior samples.
plot(; legend=false)
posterior_samples = sample(chain_dde[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob_dde, MethodOfSteps(Tsit5()); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end
# Plot simulation and noisy observations.
plot!(sol_dde; color=[1 2], linewidth=1)
scatter!(sol_dde.t, ddedata'; color=[1 2])
The fit is pretty good even though the data was quite noisy to start.
Scaling to Large Models: Adjoint Sensitivities
DifferentialEquations.jl’s efficiency for large stiff models has been shown in multiple benchmarks. To learn more about how to optimize solving performance for stiff problems you can take a look at the docs.
Sensitivity analysis is provided by the SciMLSensitivity.jl package, which forms part of SciML’s differential equation suite. The model sensitivities are the derivatives of the solution with respect to the parameters. Specifically, the local sensitivity of the solution to a parameter is defined by how much the solution would change if the parameter were changed by a small amount. Sensitivity analysis provides a cheap way to calculate the gradient of the solution which can be used in parameter estimation and other optimization tasks. The sensitivity analysis methods in SciMLSensitivity.jl are based on automatic differentiation (AD), and are compatible with many of Julia’s AD backends. More details on the mathematical theory that underpins these methods can be found in the SciMLSensitivity documentation.
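For a small model like this one, local sensitivities can also be computed by differentiating straight through the solver with forward-mode AD. This is essentially what happens with the forward-mode default used for the first sampling run above. The sketch below assumes OrdinaryDiffEq.jl and ForwardDiff.jl are installed and does not use SciMLSensitivity.jl. It computes the Jacobian of the final state \(u(10)\) with respect to \((\alpha, \beta, \gamma, \delta)\) and compares it with central finite differences. The helper final_state, the tolerances, and the step h are illustrative choices.

```julia
using OrdinaryDiffEq, ForwardDiff

function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x  # prey
    du[2] = (δ * x - γ) * y  # predator
    return nothing
end

prob = ODEProblem(lotka_volterra, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Hypothetical helper: state (x, y) at the final time as a function of the parameters.
function final_state(p)
    sol = solve(prob, Tsit5(); p=p, abstol=1e-10, reltol=1e-10, save_everystep=false)
    return sol.u[end]
end

p0 = [1.5, 1.0, 3.0, 1.0]

# Forward-mode AD: dual numbers are propagated through every solver step.
J_ad = ForwardDiff.jacobian(final_state, p0)

# Central finite differences, one parameter at a time, for comparison.
h = 1e-4
J_fd = zeros(2, 4)
for j in 1:4
    e = zeros(4)
    e[j] = h
    J_fd[:, j] = (final_state(p0 .+ e) .- final_state(p0 .- e)) ./ (2h)
end

println("AD Jacobian: ", J_ad)
println("max abs difference to finite differences: ", maximum(abs.(J_ad .- J_fd)))
```

Forward mode costs roughly one extra solution per parameter, which is cheap here. For models with many parameters, this cost is what makes the adjoint methods in SciMLSensitivity.jl attractive.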
To enable sensitivity analysis, you will need to import SciMLSensitivity, and also use one of the AD backends that is compatible with SciMLSensitivity.jl when sampling. For example, if we wanted to use Mooncake.jl, we could run:
import Mooncake
import SciMLSensitivity
# Define the AD backend to use
adtype = AutoMooncake(; config=nothing)
# Sample a single chain with 1000 samples using Mooncake
sample(model, NUTS(; adtype=adtype), 1000; progress=false)
┌ Info: Found initial step size
└   ϵ = 0.05

Chains MCMC chain (1000×17×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 290.5 seconds
Compute duration  = 290.5 seconds
parameters        = α, β, γ, δ, q
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64     ⋯

           α    1.5273    0.0547    0.0027    404.4859    422.6494    1.0016     ⋯
           β    0.9669    0.0673    0.0029    546.1516    624.4842    1.0010     ⋯
           γ    3.0064    0.1295    0.0058    495.4367    692.5238    1.0012     ⋯
           δ    0.9776    0.0727    0.0038    371.9380    526.1526    1.0007     ⋯
           q    1.6749    0.0985    0.0043    532.7613    566.3005    1.0031     ⋯
                                                                  1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    1.4233    1.4876    1.5264    1.5653    1.6266
           β    0.8414    0.9206    0.9650    1.0103    1.1068
           γ    2.7681    2.9194    3.0033    3.0966    3.2566
           δ    0.8502    0.9259    0.9714    1.0292    1.1346
           q    1.4899    1.6077    1.6701    1.7408    1.8712
In this case, SciMLSensitivity will automatically choose an appropriate sensitivity analysis algorithm for you. You can also manually specify an algorithm by providing the sensealg keyword argument to the solve function; the available algorithms are covered in this page of the SciMLSensitivity docs.
For more examples of adjoint usage on large parameter models, consult the DiffEqFlux documentation.