using Turing
using DifferentialEquations
# Load StatsPlots for visualizations and diagnostics.
using StatsPlots
using LinearAlgebra
# Set a seed for reproducibility.
using Random
Random.seed!(14);
Bayesian Estimation of Differential Equations
Most of the scientific community deals with the basic problem of trying to mathematically model the reality around them, and this often involves dynamical systems. A common way to model these complex dynamical systems is through differential equations. Differential equation models often have parameters that cannot be measured directly. The popular “forward problem” of simulation consists of solving the differential equations for a given set of parameters. The “inverse problem” to simulation, known as parameter estimation, is the process of using data to determine these model parameters. Bayesian inference provides a robust approach to parameter estimation with quantified uncertainty.
The Lotka-Volterra Model
The Lotka–Volterra equations, also known as the predator–prey equations, are a pair of first-order nonlinear differential equations. These differential equations are frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey. The populations change through time according to the pair of equations
\[ \begin{aligned} \frac{\mathrm{d}x}{\mathrm{d}t} &= (\alpha - \beta y(t))x(t), \\ \frac{\mathrm{d}y}{\mathrm{d}t} &= (\delta x(t) - \gamma)y(t) \end{aligned} \]
where \(x(t)\) and \(y(t)\) denote the populations of prey and predator at time \(t\), respectively, and \(\alpha, \beta, \gamma, \delta\) are positive parameters.
We implement the Lotka-Volterra model and simulate it with parameters \(\alpha = 1.5\), \(\beta = 1\), \(\gamma = 3\), and \(\delta = 1\) and initial conditions \(x(0) = y(0) = 1\).
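As a quick check of the equations, the plain-Julia sketch below evaluates the right-hand side at the parameter values and initial conditions above. It also evaluates it at the non-trivial equilibrium \((x^*, y^*) = (\gamma/\delta, \alpha/\beta)\), where both populations stay constant. The helper `lv_rhs` is hypothetical and is not part of the tutorial's code.

```julia
# Hypothetical out-of-place helper: right-hand side of the Lotka-Volterra
# system, returned as a tuple (dx/dt, dy/dt).
lv_rhs(x, y, α, β, γ, δ) = ((α - β * y) * x, (δ * x - γ) * y)

α, β, γ, δ = 1.5, 1.0, 3.0, 1.0

# Rates of change at the initial condition x(0) = y(0) = 1.
dx0, dy0 = lv_rhs(1.0, 1.0, α, β, γ, δ)
println((dx0, dy0))  # (0.5, -2.0)

# Non-trivial equilibrium: both right-hand sides vanish when
# x = γ / δ and y = α / β.
xstar, ystar = γ / δ, α / β
println((xstar, ystar))                    # (3.0, 1.5)
println(lv_rhs(xstar, ystar, α, β, γ, δ))  # (0.0, 0.0)
```

At \(x(0) = y(0) = 1\) the prey grows and the predator declines. This is because the prey population starts below the level \(\gamma/\delta = 3\) needed to sustain the predators.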
# Define Lotka-Volterra model.
function lotka_volterra(du, u, p, t)
    # Model parameters.
    α, β, γ, δ = p
    # Current state.
    x, y = u

    # Evaluate differential equations.
    du[1] = (α - β * y) * x # prey
    du[2] = (δ * x - γ) * y # predator

    return nothing
end
# Define initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 10.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)
# Plot simulation.
plot(solve(prob, Tsit5()))
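To illustrate what `solve` computes, the sketch below integrates the same system with a hand-written, fixed-step classical Runge-Kutta (RK4) scheme. This method is swapped in for illustration only. `Tsit5` is an adaptive explicit Runge-Kutta method (Tsitouras 5/4) that chooses its own step sizes.

As a sanity check, the sketch uses the quantity \(V(x, y) = \delta x - \gamma \log x + \beta y - \alpha \log y\), which is constant along exact Lotka-Volterra trajectories. A small drift in \(V\) therefore indicates an accurate numerical solution. The names `rk4`, `f` and `V` are hypothetical.

```julia
# Hypothetical fixed-step RK4 integrator, an illustrative stand-in for the
# adaptive Tsit5 solver. Returns the vector of states at t0, t0 + dt, ..., t1.
function rk4(f, u0, tspan, dt)
    t0, t1 = tspan
    n = round(Int, (t1 - t0) / dt)
    u = copy(u0)
    us = [copy(u0)]
    t = t0
    for _ in 1:n
        k1 = f(u, t)
        k2 = f(u .+ dt / 2 .* k1, t + dt / 2)
        k3 = f(u .+ dt / 2 .* k2, t + dt / 2)
        k4 = f(u .+ dt .* k3, t + dt)
        u = u .+ dt / 6 .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
        push!(us, u)
    end
    return us
end

# Out-of-place Lotka-Volterra right-hand side with the tutorial's parameters.
α, β, γ, δ = 1.5, 1.0, 3.0, 1.0
f(u, t) = [(α - β * u[2]) * u[1], (δ * u[1] - γ) * u[2]]

# Quantity conserved along exact Lotka-Volterra trajectories.
V(u) = δ * u[1] - γ * log(u[1]) + β * u[2] - α * log(u[2])

us = rk4(f, [1.0, 1.0], (0.0, 10.0), 0.001)
drift = maximum(abs(V(u) - V(us[1])) for u in us)
println(length(us))  # 10001 states, including the initial one
println(drift)       # small: the numerical solution nearly conserves V
```

An adaptive solver such as `Tsit5` reaches comparable accuracy with far fewer steps. This is one reason the tutorial relies on DifferentialEquations rather than a fixed-step scheme.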
We generate noisy observations to use for the parameter estimation tasks in this tutorial. With the saveat argument we specify that the solution is stored only at intervals of 0.1 time units. To make the example more realistic, we add normally distributed random noise with standard deviation 0.8 to the simulation.
sol = solve(prob, Tsit5(); saveat=0.1)
odedata = Array(sol) + 0.8 * randn(size(Array(sol)))
# Plot simulation and noisy observations.
plot(sol; alpha=0.3)
scatter!(sol.t, odedata'; color=[1 2], label="")
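The shapes involved are easy to lose track of, so the plain-Julia sketch below mirrors the noise step. It assumes that `saveat=0.1` over `tspan = (0.0, 10.0)` saves the solution on the grid `0.0:0.1:10.0`. It also assumes that `Array(sol)` stores one row per species and one column per saved time point. The matrix `clean` is a hypothetical placeholder standing in for `Array(sol)`.

```julia
using Random, Statistics

Random.seed!(14)

# Time grid assumed to correspond to saveat=0.1 over tspan = (0.0, 10.0).
ts = 0.0:0.1:10.0
println(length(ts))  # 101

# Hypothetical placeholder with the same 2 × 101 shape as Array(sol):
# row 1 = prey, row 2 = predator, one column per saved time point.
clean = ones(2, length(ts))

# Additive Gaussian noise with standard deviation 0.8, as in the tutorial.
noisy = clean + 0.8 * randn(size(clean))

# The empirical standard deviation of the added noise should be close to 0.8.
residuals = vec(noisy - clean)
println(std(residuals))
```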
Alternatively, we can use real-world data from Hudson’s Bay Company records (a Stan implementation with slightly different priors can be found here: https://mc-stan.org/users/documentation/case-studies/lotka-volterra-predator-prey.html).
Direct Handling of Bayesian Estimation with Turing
Previously, functions in Turing and DifferentialEquations were not inter-composable, so Bayesian inference of differential equations had to be handled by another package called DiffEqBayes.jl. Note that DiffEqBayes also works with CmdStan.jl, Turing.jl, DynamicHMC.jl and ApproxBayes.jl; see the DiffEqBayes docs for more info.
Nowadays, however, Turing and DifferentialEquations are completely composable, and we can simulate differential equations directly inside a Turing @model. Therefore, we write the Lotka-Volterra parameter estimation problem using the Turing @model macro as below:
@model function fitlv(data, prob)
    # Prior distributions.
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)

    # Simulate Lotka-Volterra model.
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1)

    # Observations: each column of data is the noisy (prey, predator) state
    # at one saved time point.
    for i in 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
    end

    return nothing
end
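Each statement `data[:, i] ~ MvNormal(predicted[i], σ^2 * I)` adds the log-density of an isotropic Gaussian to the model's log joint density. Because the covariance \(\sigma^2 I\) is diagonal, this equals a sum of independent univariate normal log-densities, one per species.

The plain-Julia sketch below checks that identity. The helpers `isotropic_logpdf` and `normal_logpdf` and all the numbers are hypothetical. Turing evaluates these terms through Distributions.jl, not through code like this.

```julia
# Log-density of an isotropic Gaussian N(μ, σ² I), written out by hand.
function isotropic_logpdf(x, μ, σ)
    n = length(x)
    return -n / 2 * log(2π * σ^2) - sum(abs2, x - μ) / (2σ^2)
end

# Log-density of a univariate normal N(μ, σ²).
normal_logpdf(x, μ, σ) = -log(2π * σ^2) / 2 - (x - μ)^2 / (2σ^2)

# Hypothetical observation and prediction at one time point (prey, predator).
x = [1.3, 0.4]
μ = [1.0, 1.0]
σ = 0.8

joint = isotropic_logpdf(x, μ, σ)
separate = sum(normal_logpdf.(x, μ, σ))
println(joint ≈ separate)  # true

# Summing over columns (time points), as the loop in fitlv does,
# gives the full log-likelihood.
data = [1.3 0.9; 0.4 1.2]
pred = [1.0 1.0; 1.0 1.0]
loglik = sum(isotropic_logpdf(data[:, i], pred[:, i], σ) for i in axes(data, 2))
println(loglik ≈ sum(normal_logpdf.(data, pred, σ)))  # true
```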
model = fitlv(odedata, prob)

# Sample 3 independent chains with forward-mode automatic differentiation (the default).
chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress=false)
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.00625
┌ Info: Found initial step size
└ ϵ = 0.025
Chains MCMC chain (1000×17×3 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 3
Samples per chain = 1000
Wall duration = 48.83 seconds
Compute duration = 46.37 seconds
parameters = σ, α, β, γ, δ
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat e ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
σ 1.1964 0.5944 0.2414 9.5554 87.3809 1.6648 ⋯
α 1.3860 0.1393 0.0491 9.6636 121.3594 1.6593 ⋯
β 0.9531 0.1008 0.0282 17.0274 87.8577 1.2932 ⋯
γ 2.4557 0.9485 0.3834 9.5906 77.5248 1.6677 ⋯
δ 0.8848 0.2329 0.0918 9.5681 67.9895 1.6663 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
σ 0.7089 0.7649 0.8044 1.9637 2.1821
α 1.0306 1.2962 1.4378 1.4845 1.5525
β 0.7139 0.9096 0.9782 1.0185 1.1029
γ 1.0079 1.1756 3.0197 3.1607 3.3771
δ 0.4823 0.6015 1.0072 1.0569 1.1407
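The posterior medians of α, β, γ and δ (1.44, 0.98, 3.02 and 1.01) are close to the values the observations were generated with (1.5, 1, 3 and 1), and the median of σ (0.80) matches the noise level. In this run, however, the chains have not converged: the R-hat values are well above 1 (up to about 1.67) and the bulk ESS is below 20. The bimodal quantiles of σ, γ and δ suggest that one of the three chains is stuck in a different mode. We can inspect the chains visually with trace and density plots.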
plot(chain)
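The rhat column in the summary statistics compares between-chain and within-chain variance. Values near 1 indicate that the chains agree; values well above 1 indicate that they explore different regions.

The sketch below implements the classic (non-split, non-rank-normalized) Gelman-Rubin R-hat on synthetic chains. It assumes that the rhat reported above comes from a more robust variant (split chains with rank normalization in recent versions of MCMCDiagnosticTools.jl), so the numbers are not directly comparable. The function `basic_rhat` and the data are hypothetical.

```julia
using Random, Statistics

# Hypothetical helper: classic Gelman-Rubin R-hat for one parameter.
# `chains` is an n × m matrix holding n draws from each of m chains.
function basic_rhat(chains::AbstractMatrix)
    n, m = size(chains)
    chain_means = vec(mean(chains; dims=1))
    chain_vars = vec(var(chains; dims=1))
    B = n * var(chain_means)   # between-chain variance
    W = mean(chain_vars)       # within-chain variance
    var_plus = (n - 1) / n * W + B / n
    return sqrt(var_plus / W)
end

Random.seed!(1)
n = 1000

# Three chains sampling the same distribution: R-hat should be close to 1.
mixed = randn(n, 3)

# Two chains near 3.0 and one stuck near 1.0, loosely mimicking the γ quantiles above.
stuck = hcat(3.0 .+ 0.1 .* randn(n), 3.0 .+ 0.1 .* randn(n), 1.0 .+ 0.1 .* randn(n))

println(basic_rhat(mixed))  # close to 1
println(basic_rhat(stuck))  # far above 1
```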
Data retrodiction
In Bayesian analysis it is often useful to retrodict the data, i.e. generate simulated data using samples from the posterior distribution, and compare it to the original data (see for instance section 3.3.2, "Model checking", of McElreath’s book “Statistical Rethinking”). Here, we solve the ODE for 300 randomly picked posterior samples in the chain and plot the ensemble of solutions to check whether it resembles the data. The 300 retrodicted time courses from the posterior are plotted in gray. The noisy observations are shown as dots, and the ODE solution that was used to generate the data is shown as solid lines. Each species uses the same color for its observations and its solution.
plot(; legend=false)
posterior_samples = sample(chain[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end
# Plot simulation and noisy observations.
plot!(sol; color=[1 2], linewidth=1)
scatter!(sol.t, odedata'; color=[1 2])
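Instead of (or in addition to) drawing 300 gray lines, the retrodicted trajectories can be summarized by pointwise quantile bands. The plain-Julia sketch below shows both ingredients:

- drawing 300 distinct posterior indices with `randperm`, mirroring the effect of `replace=false`;
- computing 5%, 50% and 95% bands across an ensemble.

The `ensemble` matrix here is synthetic and hypothetical. In the tutorial it would be filled from one species' values in each `sol_p`.

```julia
using Random, Statistics

Random.seed!(42)

# Sampling without replacement: 300 distinct indices out of the
# 3000 posterior draws (1000 samples × 3 chains) in `chain`.
idx = randperm(3000)[1:300]

# Hypothetical stand-in for the retrodicted trajectories of one species:
# rows are the 101 saved time points, columns are the 300 posterior draws.
ts = 0.0:0.1:10.0
ensemble = [2 + sin(t) + 0.2 * randn() for t in ts, j in eachindex(idx)]

# Pointwise 5%, 50% and 95% quantiles across draws at each time point.
qs = (0.05, 0.5, 0.95)
bands = [quantile(ensemble[k, :], q) for k in axes(ensemble, 1), q in qs]
println(size(bands))  # (101, 3)
```

The lower and upper columns of `bands` could then be drawn as a shaded band around the median, for example with the ribbon attribute in Plots.jl.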