Bayesian Differential Equations

A basic scientific problem is to mathematically model a system of interest, then compare this model to the observable reality around us. Such models often involve dynamical systems of differential equations. In practice, these equations often have unknown parameters that we would like to estimate. The “forward problem” of simulation consists of solving the differential equations for a given set of parameters, while the “inverse problem,” also known as parameter estimation, is the process of using data to determine these model parameters. Bayesian inference provides a robust approach to parameter estimation with quantified uncertainty.

using Turing
using DifferentialEquations
# Load StatsPlots for visualizations and diagnostics.
using StatsPlots
using LinearAlgebra
using Distributions
# Set a seed for reproducibility.
using Random
Random.seed!(14);

The Lotka–Volterra Model

The Lotka–Volterra equations, also known as the predator–prey equations, are a pair of first-order nonlinear differential equations. These differential equations are frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey. The populations change through time according to the pair of equations

\[ \begin{aligned} \frac{\mathrm{d}x}{\mathrm{d}t} &= (\alpha - \beta y(t))x(t), \\ \frac{\mathrm{d}y}{\mathrm{d}t} &= (\delta x(t) - \gamma)y(t), \end{aligned} \]

where \(x(t)\) and \(y(t)\) denote the populations of prey and predator at time \(t\), respectively, and \(\alpha, \beta, \gamma, \delta\) are positive parameters. In the absence of predators, the prey population \(x\) would increase exponentially at rate \(\alpha\) (with dimensions of time\(^{-1}\)). However, the predators kill some prey at a rate \(\beta\) (prey predator\(^{-1}\) time\(^{-1}\)), which enables the predator population to increase at rate \(\delta\) (predators prey\(^{-1}\) time\(^{-1}\)). Finally, predators are removed by natural mortality at rate \(\gamma\) (time\(^{-1}\)).

We implement the Lotka–Volterra model and simulate it with parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\).

# Define Lotka–Volterra model.
function lotka_volterra(du, u, p, t)
    # Model parameters.
    α, β, γ, δ = p
    # Current state.
    x, y = u

    # Evaluate differential equations.
    du[1] = (α - β * y) * x # prey
    du[2] = (δ * x - γ) * y # predator

    return nothing
end

# Define initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 10.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# Plot simulation.
plot(solve(prob, Tsit5()))
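
As a quick sanity check of the equations above, the following minimal sketch evaluates the right-hand side by hand at two states: the initial condition \(x(0) = y(0) = 1\), and the coexistence equilibrium \((\gamma/\delta, \alpha/\beta)\) obtained by setting both derivatives to zero. The helper lv_rhs is a hypothetical name introduced only for this illustration; it is not part of the tutorial's code.

```julia
# Out-of-place right-hand side of the Lotka–Volterra equations.
# `lv_rhs` is a hypothetical helper used only for this check.
lv_rhs(x, y; α=1.5, β=1.0, γ=3.0, δ=1.0) = ((α - β * y) * x, (δ * x - γ) * y)

# At the initial condition x(0) = y(0) = 1.
dx0, dy0 = lv_rhs(1.0, 1.0)

# At the coexistence equilibrium (x*, y*) = (γ/δ, α/β) = (3, 1.5).
x_eq, y_eq = 3.0 / 1.0, 1.5 / 1.0
dx_eq, dy_eq = lv_rhs(x_eq, y_eq)

println((dx0, dy0))
println((dx_eq, dy_eq))
```

With the parameters used here, the prey initially grow (derivative 0.5) while the predators decline (derivative -2), and both derivatives vanish at the equilibrium. The simulated trajectory starts away from this equilibrium, which is why the populations oscillate around it.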

We generate noisy observations to use for the parameter estimation tasks in this tutorial. With the saveat argument to the differential equation solver, we specify that the solution is stored only at intervals of 0.1 time units.

To make the example more realistic, we generate data as random Poisson counts based on the “true” population densities of predator and prey from the simulation. Poisson-distributed data are common in ecology (for instance, counts of animals detected by a camera trap). We’ll assume that the rate \(\lambda\), which parameterizes the Poisson distribution, is proportional to the underlying animal densities via a constant factor \(q = 1.7\).
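
Written out, the observation model described above reads as follows. The symbols \(X_k\) and \(Y_k\) (observed prey and predator counts) and \(t_k\) (observation times) are notation introduced here for illustration only.

\[ \begin{aligned} X_k &\sim \mathrm{Poisson}\bigl(q\,x(t_k)\bigr), \\ Y_k &\sim \mathrm{Poisson}\bigl(q\,y(t_k)\bigr), \end{aligned} \qquad t_k = 0, 0.1, \ldots, 10. \]

Because a Poisson distribution has equal mean and variance, each count has mean and variance \(q\,x(t_k)\) or \(q\,y(t_k)\), so the relative noise is largest when the populations are small.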

sol = solve(prob, Tsit5(); saveat=0.1)
q = 1.7
odedata = rand.(Poisson.(q * Array(sol)))

# Plot simulation and noisy observations.
plot(sol, label=["Prey" "Predator"])
scatter!(sol.t, odedata'; color=[1 2], label="")

An even more realistic example could be fitted to the famous hare-and-lynx system using the long-term trapping records of the Hudson’s Bay Company. A Stan implementation of this problem with slightly different priors can be found here. For this tutorial, though, we will stick with simulated data.

Direct Handling of Bayesian Estimation with Turing

DifferentialEquations.jl is the main Julia package for numerically solving differential equations. Its functionality is completely interoperable with Turing.jl, which means that we can directly simulate differential equations inside a Turing @model.

For the purposes of this tutorial, we choose priors for the parameters that are quite close to the ground truth. As justification, we can imagine we have preexisting estimates for the biological rates. Practically, this helps us to illustrate the results without needing to run overly long MCMC chains.

Note that we also have to take special care with the ODE solver. For certain parameter combinations, the numerical solver may predict animal densities that are just barely below zero. This causes errors with the Poisson distribution, which needs a non-negative mean \(\lambda\). To prevent this, we tell the solver to aim for small absolute and relative errors (abstol=1e-6, reltol=1e-6). We also add a fudge factor ϵ = 1e-5 to the predicted data. Since ϵ is greater than the solver’s tolerance, it should overcome any remaining numerical error, making sure all predicted values are positive. At the same time, it is so small compared to the data that it should have a negligible effect on inference. If this approach doesn’t work, there are some more ideas to try here.
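
The following minimal sketch illustrates the failure mode and the effect of the offset. The value of undershoot is a made-up stand-in for a solver result that dips slightly below zero; it is not taken from an actual run.

```julia
using Distributions

# Hypothetical solver output that undershoots zero by a tiny amount.
undershoot = -1e-9
ϵ = 1e-5  # same offset as in the tutorial's model

# Constructing a Poisson distribution with a negative rate throws an error.
negative_rate_fails = try
    Poisson(undershoot)
    false
catch
    true
end

# With the offset added, the rate is positive and the log-density is finite.
shifted = Poisson(undershoot + ϵ)
logp = logpdf(shifted, 0)

println(negative_rate_fails)
println(isfinite(logp))
```

The offset only needs to exceed the size of the numerical undershoot, which is why it is chosen slightly larger than the solver tolerances.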

@model function fitlv(data, prob)
    # Prior distributions.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate Lotka–Volterra model. 
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1, abstol=1e-6, reltol=1e-6)
    ϵ = 1e-5
    
    # Observations.
    for i in eachindex(predicted)
        data[:, i] ~ arraydist(Poisson.(q .* predicted[i] .+ ϵ))
    end

    return nothing
end

model = fitlv(odedata, prob)

# Sample 3 independent chains with forward-mode automatic differentiation (the default).
chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress=false)
Info: Found initial step size
  ϵ = 0.05
Info: Found initial step size
  ϵ = 0.05
Info: Found initial step size
  ϵ = 0.2
Chains MCMC chain (1000×17×3 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 3
Samples per chain = 1000
Wall duration     = 49.27 seconds
Compute duration  = 46.69 seconds
parameters        = α, β, γ, δ, q
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           α    1.5257    0.0611    0.0018   1107.0992    936.2641    1.0028   ⋯
           β    0.9674    0.0674    0.0018   1357.8551   1415.4968    1.0011   ⋯
           γ    3.0136    0.1441    0.0040   1285.9869   1138.0021    1.0019   ⋯
           δ    0.9799    0.0802    0.0025   1076.4832    958.9411    1.0028   ⋯
           q    1.6761    0.0983    0.0027   1298.3271   1141.0664    1.0033   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    1.4076    1.4846    1.5249    1.5665    1.6458
           β    0.8434    0.9212    0.9649    1.0122    1.1025
           γ    2.7366    2.9182    3.0111    3.1076    3.3108
           δ    0.8321    0.9244    0.9781    1.0314    1.1473
           q    1.4832    1.6118    1.6724    1.7422    1.8719

The estimated parameters are close to the parameter values the observations were generated with. We can also check visually that the chains have converged.

plot(chain)

Data retrodiction

In Bayesian analysis it is often useful to retrodict the data, i.e., to generate simulated data using samples from the posterior distribution and compare them with the original data (see, for instance, Section 3.3.2, “Model checking”, of McElreath’s book “Statistical Rethinking”). Here, we solve the ODE for 300 randomly picked posterior samples in the chain and plot the ensemble of solutions to check whether it resembles the data. The 300 retrodicted time courses from the posterior are plotted in gray, the noisy observations are shown as blue and red dots, and the solid lines in matching colors are the ODE solution that was used to generate the data.

plot(; legend=false)
posterior_samples = sample(chain[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end

# Plot simulation and noisy observations.
plot!(sol; color=[1 2], linewidth=1)
scatter!(sol.t, odedata'; color=[1 2])

We can see that, even though we added quite a bit of noise to the data, the posterior distribution reproduces the “true” ODE solution quite accurately.

Lotka–Volterra model without data of prey

One can also perform parameter inference for a Lotka–Volterra model with incomplete data. For instance, let us suppose we have only observations of the predators but not of the prey. We can fit the model only to the \(y\) variable of the system without providing any data for \(x\):

@model function fitlv2(data::AbstractVector, prob)
    # Prior distributions.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate Lotka–Volterra model but save only the second state of the system (predators).
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1, save_idxs=2, abstol=1e-6, reltol=1e-6)
    ϵ = 1e-5

    # Observations of the predators.
    data ~ arraydist(Poisson.(q .* predicted.u .+ ϵ))

    return nothing
end

model2 = fitlv2(odedata[2, :], prob)

# Sample 3 independent chains.
chain2 = sample(model2, NUTS(0.45), MCMCSerial(), 5000, 3; progress=false)
Info: Found initial step size
  ϵ = 0.025
Info: Found initial step size
  ϵ = 0.05
Info: Found initial step size
  ϵ = 0.05
Chains MCMC chain (5000×17×3 Array{Float64, 3}):

Iterations        = 1001:1:6000
Number of chains  = 3
Samples per chain = 5000
Wall duration     = 28.94 seconds
Compute duration  = 28.57 seconds
parameters        = α, β, γ, δ, q
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           α    1.5718    0.0894    0.0080   152.4329   112.4395    1.0344     ⋯
           β    0.9900    0.1035    0.0071   212.4749   500.9032    1.0331     ⋯
           γ    3.0055    0.1490    0.0114   181.6808   150.3449    1.0293     ⋯
           δ    0.9180    0.1204    0.0098   157.9919   112.0723    1.0314     ⋯
           q    1.6494    0.1631    0.0115   200.3701   289.9853    1.0198     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    1.4162    1.5117    1.5638    1.6228    1.7840
           β    0.7967    0.9173    0.9897    1.0561    1.2007
           γ    2.6919    2.9107    3.0099    3.1089    3.2791
           δ    0.6684    0.8376    0.9203    0.9953    1.1584
           q    1.3284    1.5393    1.6491    1.7579    1.9751

Again we inspect the trajectories of 300 randomly selected posterior samples.

plot(; legend=false)
posterior_samples = sample(chain2[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end

# Plot simulation and noisy observations.
plot!(sol; color=[1 2], linewidth=1)
scatter!(sol.t, odedata'; color=[1 2])

Note that here the observations of the prey (blue dots) were not used in the parameter estimation! Yet, the model can predict the values of \(x\) relatively accurately, albeit with a wider distribution of solutions, reflecting the greater uncertainty in the prediction of the \(x\) values.

Inference of Delay Differential Equations

Here we show an example of inference with another type of differential equation: a delay differential equation (DDE). DDEs are differential equations where derivatives are functions of values at an earlier point in time. This is useful to model a delayed effect, such as the incubation time of a virus.

Here is a delayed version of the Lotka–Volterra system:

\[ \begin{aligned} \frac{\mathrm{d}x}{\mathrm{d}t} &= \alpha x(t-\tau) - \beta y(t) x(t),\\ \frac{\mathrm{d}y}{\mathrm{d}t} &= - \gamma y(t) + \delta x(t) y(t), \end{aligned} \]

where \(\tau\) is a (positive) delay and \(x(t-\tau)\) is the variable \(x\) at an earlier time point \(t - \tau\).

The initial-value problem of the delayed system can be implemented as a DDEProblem. As described in the DDE example, here the function h is the history function that can be used to obtain a state at an earlier time point. Again we use parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\). Moreover, we fix the delay at \(\tau = 1\) (the t - 1 in the code below) and assume \(x(t) = 1\) for \(t < 0\).

function delay_lotka_volterra(du, u, h, p, t)
    # Model parameters.
    α, β, γ, δ = p

    # Current state.
    x, y = u
    # Evaluate differential equations
    du[1] = α * h(p, t - 1; idxs=1) - β * x * y
    du[2] = -γ * y + δ * x * y

    return nothing
end

# Define initial-value problem.
p = (1.5, 1.0, 3.0, 1.0)
u0 = [1.0; 1.0]
tspan = (0.0, 10.0)
h(p, t; idxs::Int) = 1.0
prob_dde = DDEProblem(delay_lotka_volterra, u0, h, tspan, p);

We generate observations by sampling from the corresponding Poisson distributions derived from the simulation results:

sol_dde = solve(prob_dde; saveat=0.1)
ddedata = rand.(Poisson.(q .* Array(sol_dde)))

# Plot simulation and noisy observations.
plot(sol_dde)
scatter!(sol_dde.t, ddedata'; color=[1 2], label="")

Now we define the Turing model for the Lotka–Volterra model with a delay, and sample 3 independent chains.

@model function fitlv_dde(data, prob)
    # Prior distributions.
    α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
    q ~ truncated(Normal(1.7, 0.2); lower=0, upper=3)

    # Simulate Lotka–Volterra model.
    p = [α, β, γ, δ]
    predicted = solve(prob, MethodOfSteps(Tsit5()); p=p, saveat=0.1, abstol=1e-6, reltol=1e-6)
    ϵ = 1e-5

    # Observations.
    for i in eachindex(predicted)
        data[:, i] ~ arraydist(Poisson.(q .* predicted[i] .+ ϵ))
    end
end

model_dde = fitlv_dde(ddedata, prob_dde)

chain_dde = sample(model_dde, NUTS(), MCMCSerial(), 300, 3; progress=false)
Info: Found initial step size
  ϵ = 0.2
Info: Found initial step size
  ϵ = 0.05
Info: Found initial step size
  ϵ = 0.2
Chains MCMC chain (300×17×3 Array{Float64, 3}):

Iterations        = 151:1:450
Number of chains  = 3
Samples per chain = 300
Wall duration     = 12.3 seconds
Compute duration  = 12.05 seconds
parameters        = α, β, γ, δ, q
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           α    1.5849    0.1120    0.0058   384.1059   268.6046    1.0046     ⋯
           β    1.0382    0.0925    0.0037   683.5960   541.1522    1.0011     ⋯
           γ    3.0831    0.1463    0.0061   574.0471   402.3024    1.0021     ⋯
           δ    0.9885    0.0814    0.0042   376.4300   384.6211    1.0086     ⋯
           q    1.6828    0.1178    0.0053   490.7867   566.5574    1.0035     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    1.3825    1.5112    1.5745    1.6549    1.8302
           β    0.8799    0.9734    1.0326    1.0943    1.2372
           γ    2.7909    2.9800    3.0853    3.1868    3.3632
           δ    0.8319    0.9342    0.9890    1.0410    1.1479
           q    1.4649    1.6006    1.6831    1.7656    1.9028

plot(chain_dde)

Finally, we plot trajectories of 300 randomly selected samples from the posterior. Again, the dots indicate our observations, the colored lines are the “true” simulations without noise, and the gray lines are trajectories from the posterior samples.

plot(; legend=false)
posterior_samples = sample(chain_dde[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob_dde, MethodOfSteps(Tsit5()); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end

# Plot simulation and noisy observations.
plot!(sol_dde; color=[1 2], linewidth=1)
scatter!(sol_dde.t, ddedata'; color=[1 2])

The fit is pretty good even though the data was quite noisy to start.

Scaling to Large Models: Adjoint Sensitivities

DifferentialEquations.jl’s efficiency for large stiff models has been shown in multiple benchmarks. To learn more about how to optimize solving performance for stiff problems you can take a look at the docs.

Sensitivity analysis is provided by the SciMLSensitivity.jl package, which forms part of SciML’s differential equation suite. The model sensitivities are the derivatives of the solution with respect to the parameters. Specifically, the local sensitivity of the solution to a parameter is defined by how much the solution would change if the parameter were changed by a small amount. Sensitivity analysis provides a cheap way to calculate the gradient of the solution, which can then be used in parameter estimation and other optimization tasks. The sensitivity analysis methods in SciMLSensitivity.jl are based on automatic differentiation (AD) and are compatible with many of Julia’s AD backends. More details on the mathematical theory that underpins these methods can be found in the SciMLSensitivity documentation.
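
To make the definition of a local sensitivity concrete, the following minimal sketch uses the prey-only case mentioned earlier (exponential growth at rate α), where the solution and its derivative with respect to α are known in closed form. The function names are hypothetical, and the finite-difference check is only an illustration; it is not how SciMLSensitivity.jl computes sensitivities.

```julia
# Prey-only dynamics dx/dt = α x have the closed-form solution x(t) = x0 * exp(α t).
prey_only(α, t; x0=1.0) = x0 * exp(α * t)

# Analytic local sensitivity ∂x/∂α = t * x0 * exp(α t).
sensitivity_exact(α, t; x0=1.0) = t * x0 * exp(α * t)

# Central finite-difference approximation: perturb α by a small amount h
# and measure how much the solution changes.
function sensitivity_fd(α, t; h=1e-6)
    return (prey_only(α + h, t) - prey_only(α - h, t)) / (2h)
end

α, t = 1.5, 2.0
exact = sensitivity_exact(α, t)
approx = sensitivity_fd(α, t)

println(isapprox(exact, approx; rtol=1e-6))
```

Finite differences need extra solves for every parameter, whereas AD-based sensitivity methods obtain the same derivatives more efficiently and without the choice of a step size h.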

To enable sensitivity analysis, you will need to import SciMLSensitivity, and also use one of the AD backends that is compatible with SciMLSensitivity.jl when sampling. For example, if we wanted to use Mooncake.jl, we could run:

import Mooncake
import SciMLSensitivity

# Define the AD backend to use
adtype = AutoMooncake(; config=nothing)

# Sample a single chain with 1000 samples using Mooncake
sample(model, NUTS(; adtype=adtype), 1000; progress=false)
Info: Found initial step size
  ϵ = 0.05
Chains MCMC chain (1000×17×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 290.5 seconds
Compute duration  = 290.5 seconds
parameters        = α, β, γ, δ, q
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           α    1.5273    0.0547    0.0027   404.4859   422.6494    1.0016     ⋯
           β    0.9669    0.0673    0.0029   546.1516   624.4842    1.0010     ⋯
           γ    3.0064    0.1295    0.0058   495.4367   692.5238    1.0012     ⋯
           δ    0.9776    0.0727    0.0038   371.9380   526.1526    1.0007     ⋯
           q    1.6749    0.0985    0.0043   532.7613   566.3005    1.0031     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    1.4233    1.4876    1.5264    1.5653    1.6266
           β    0.8414    0.9206    0.9650    1.0103    1.1068
           γ    2.7681    2.9194    3.0033    3.0966    3.2566
           δ    0.8502    0.9259    0.9714    1.0292    1.1346
           q    1.4899    1.6077    1.6701    1.7408    1.8712

In this case, SciMLSensitivity will automatically choose an appropriate sensitivity analysis algorithm for you. You can also manually specify an algorithm by passing the sensealg keyword argument to the solve function; the available algorithms are covered on this page of the SciMLSensitivity docs.
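
As a minimal sketch of what passing sensealg looks like, the snippet below solves the Lotka–Volterra problem with an explicitly requested algorithm. The choice of InterpolatingAdjoint is an assumption made purely for illustration, not a recommendation for this model; consult the SciMLSensitivity documentation for which algorithm suits a given problem and AD backend.

```julia
using DifferentialEquations
import SciMLSensitivity

function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x # prey
    du[2] = (δ * x - γ) * y # predator
    return nothing
end

prob = ODEProblem(lotka_volterra, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Request a specific sensitivity algorithm (illustrative choice). The keyword
# only takes effect when an AD backend differentiates through `solve`;
# a plain forward solve like this one is unaffected by it.
sol_adj = solve(
    prob, Tsit5();
    saveat=0.1, abstol=1e-6, reltol=1e-6,
    sensealg=SciMLSensitivity.InterpolatingAdjoint(),
)
```

Inside a Turing model, the same keyword would be added to the existing solve call, so that the chosen algorithm is used whenever the sampler's AD backend computes gradients.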

For more examples of adjoint usage on large parameter models, consult the DiffEqFlux documentation.
