Standard MCMC sampling methods return values of the parameters of the model. However, it is often also useful to generate new data points using the model, given a distribution of the parameters. Turing.jl allows you to do this using the predict function, along with conditioning syntax.
Consider the following simple model, where we observe some normally-distributed data X and want to learn about its mean m.
```julia
using Turing

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end
```
Notice first how we have not specified X as an argument to the model. This allows us to use Turing’s conditioning syntax to specify whether we want to provide observed data or not.
If you want to specify X as an argument to the model, then to mark it as being unobserved, you have to instantiate the model again with X = missing or X = fill(missing, N). Whether you use missing or fill(missing, N) depends on whether X is treated as a single distribution (e.g. with filldist or product_distribution), or as multiple independent distributions (e.g. with .~ or a for loop over eachindex(X)). This is rather finicky, so we recommend using the current approach: conditioning and deconditioning X as a whole should work regardless of how X is defined in the model.
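The code that generates the data and conditions the model is not shown on this page. The following is a minimal sketch of that step. It assumes five observations drawn from a normal distribution with mean 3 (the "true mean" mentioned at the end of this page) and unit standard deviation. The names true_m and X_obs and the Xoshiro seeds are illustrative choices, not the original code. The model is conditioned with the | syntax and then sampled with NUTS.

```julia
using Turing
using Random: Xoshiro
using Statistics: mean

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Illustrative synthetic data: N draws from Normal(true_m, 1).
N = 5
true_m = 3.0
X_obs = rand(Xoshiro(468), Normal(true_m), N)

# Condition the model on the observed data using the | syntax.
model = f(N) | (; X = X_obs)

# Sample from the posterior distribution of m with NUTS.
chain = sample(Xoshiro(468), model, NUTS(), 1000; progress=false)
posterior_mean = mean(chain[@varname(m)])
```

The prior on m is Normal(0, 1) and each observation has unit variance, so the exact posterior mean is sum(X_obs) / (N + 1). This gives a quick sanity check on the sampled value.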
In the run shown on this page, the posterior mean of m comes out at approximately 2.955.
chain[@varname(m)] now contains samples from the posterior distribution of m. If we use these samples of the parameters to generate new data points, we obtain the posterior predictive distribution. Statistically, this is defined as
\[ p(\tilde{x} | \mathbf{X}) = \int p(\tilde{x} | \theta) p(\theta | \mathbf{X}) d\theta, \]
where \(\tilde{x}\) are the new data which you wish to draw, \(\theta\) are the model parameters, and \(\mathbf{X}\) are the observed data. \(p(\tilde{x} | \theta)\) is the distribution of the new data given the parameters, which is specified in the Turing.jl model (the X ~ ... line); and \(p(\theta | \mathbf{X})\) is the posterior distribution, as given by the Markov chain.
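The chain holds samples \(\theta^{(1)}, \dots, \theta^{(S)}\) from \(p(\theta | \mathbf{X})\), where \(S\) is a symbol introduced here for the number of samples. The integral can therefore be approximated by a Monte Carlo average:
\[ p(\tilde{x} | \mathbf{X}) \approx \frac{1}{S} \sum_{s=1}^{S} p(\tilde{x} | \theta^{(s)}). \]
Equivalently, drawing one \(\tilde{x}^{(s)} \sim p(\tilde{x} | \theta^{(s)})\) for each posterior sample \(\theta^{(s)}\) yields draws from the posterior predictive distribution. This is the standard Monte Carlo argument behind generating new data sample by sample from a chain.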
To obtain samples of \(\tilde{x}\), we need to first remove the observed data from the model (or ‘decondition’ the model). This means that when the model is evaluated, it will sample a new value for X. If you don’t decondition the model, then X will remain fixed to the observed data, and no new samples will be generated.
If you only want to decondition a single variable X, you can use decondition(model, @varname(X)).
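The following is a minimal sketch of both forms of deconditioning. The data X_obs and the seeds are illustrative assumptions, and the model setup is repeated so that the block runs on its own.

```julia
using Turing
using Random: Xoshiro

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Illustrative setup: condition the model on N = 5 synthetic observations.
N = 5
X_obs = rand(Xoshiro(468), Normal(3.0), N)
model = f(N) | (; X = X_obs)

# Remove all conditioned values, so that X is sampled again on evaluation.
predictive_model = decondition(model)

# Alternatively, remove the conditioning on X only.
predictive_model_X = decondition(model, @varname(X))

# Evaluating the deconditioned model draws new values for m and X,
# rather than returning the observed data for X.
draw = rand(Xoshiro(1), predictive_model)
```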
To demonstrate how this deconditioned model can generate new data, we can fix the value of m to be its mean and evaluate the model:
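The following is a sketch of this step. The data and seeds are illustrative assumptions, and the setup is repeated so the block runs on its own. fix pins m to a given value, so only X is random when the model is evaluated.

```julia
using Turing
using Random: Xoshiro
using Statistics: mean

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Illustrative setup: synthetic data, posterior sampling, deconditioning.
N = 5
X_obs = rand(Xoshiro(468), Normal(3.0), N)
model = f(N) | (; X = X_obs)
chain = sample(Xoshiro(468), model, NUTS(), 1000; progress=false)
predictive_model = decondition(model)

# Fix m at its posterior mean, then evaluate the model once.
# Since m is fixed, the result contains only a single new draw of X.
m_hat = mean(chain[@varname(m)])
one_draw = rand(Xoshiro(1), fix(predictive_model, (; m = m_hat)))
```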
VarNamedTuple
└─ X => [2.0718161169884626, 4.06889638439319, 3.3000141500889035, 1.9389220696198166, 1.1670343415997424]
This has given us a single sample of X given the mean value of m. Of course, to take our Bayesian uncertainty into account, we want to use the full posterior distribution of m, not just its mean. To do so, we use predict, which effectively does the same as above but for every sample in the chain:
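The following is a sketch of the predict call. The data and seeds are illustrative assumptions, the name pp_chain is hypothetical, and the setup is repeated so the block runs on its own. The rng argument is optional.

```julia
using Turing
using Random: Xoshiro
using Statistics: mean

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Illustrative setup: synthetic data, posterior sampling, deconditioning.
N = 5
X_obs = rand(Xoshiro(468), Normal(3.0), N)
model = f(N) | (; X = X_obs)
chain = sample(Xoshiro(468), model, NUTS(), 1000; progress=false)
predictive_model = decondition(model)

# For every posterior sample of m in chain, draw a new X from the
# deconditioned model. By default the result holds both m and the predicted X.
pp_chain = predict(Xoshiro(468), predictive_model, chain)
```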
```text
FlexiChain (1000 iterations, 1 chain)
  ↓ iter  = 501:1500
  → chain = 1:1

  Parameters (2) ── AbstractPPL.VarName
    Float64          m
    Vector{Float64}  X

  Extras (14)
    Int64    n_steps, tree_depth
    Bool     is_accept, numerical_error
    Float64  acceptance_rate, log_density, hamiltonian_energy,
             hamiltonian_energy_error, max_hamiltonian_energy_error, step_size,
             nom_step_size, logprior, loglikelihood, logjoint
```
predict, like many other Julia functions that involve randomness, takes an optional rng as its first argument. Passing a seeded rng controls the generation of new X samples and makes your results reproducible.
predict itself returns a FlexiChain object, which contains all the original parameters plus the newly predicted variables. If you only want the newly predicted variables, use predict(rng, predictive_model, chain; include_all=false).
We can visualise the predictive distribution by combining all the samples and making a density plot:
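The following is a sketch of such a density plot using StatsPlots.jl, which is an assumption since any plotting package would do. It assumes that indexing a chain with @varname(X) returns one Vector{Float64} of length N per iteration, as the chain summary above suggests, and it flattens all of them into a single vector. The data, seeds and variable names are illustrative, and the setup is repeated so the block runs on its own.

```julia
using Turing
using StatsPlots: density
using Random: Xoshiro

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Illustrative setup: synthetic data, posterior sampling, prediction.
N = 5
X_obs = rand(Xoshiro(468), Normal(3.0), N)
model = f(N) | (; X = X_obs)
chain = sample(Xoshiro(468), model, NUTS(), 1000; progress=false)
predictive_model = decondition(model)
pp_chain = predict(Xoshiro(468), predictive_model, chain)

# Assumption: pp_chain[@varname(X)] holds one predicted vector per iteration.
# Concatenate them into one long vector of predicted data points.
X_pred = reduce(vcat, vec(pp_chain[@varname(X)]))
density(X_pred; label="Posterior predictive", xlabel="X")
```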
Depending on your data, you may naturally want to create different visualisations. For example, perhaps X contains some time-series data, in which case you can plot each prediction individually as a line against time.
Alternatively, if we use the prior distribution of the parameters \(p(\theta)\), we obtain the prior predictive distribution:
\[ p(\tilde{x}) = \int p(\tilde{x} | \theta) p(\theta) d\theta. \]
In exactly the same way as above, you can sample from the prior distribution of the conditioned model and then pass the resulting chain to predict:
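The following is a sketch of this two-step approach. The data and seeds are illustrative assumptions, the names prior_chain and prior_pred are hypothetical, and the setup is repeated so the block runs on its own.

```julia
using Turing
using Random: Xoshiro
using Statistics: mean

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Illustrative setup: conditioned model and its deconditioned counterpart.
N = 5
X_obs = rand(Xoshiro(468), Normal(3.0), N)
model = f(N) | (; X = X_obs)
predictive_model = decondition(model)

# Sample m from its prior; Prior() does not use the conditioned data.
prior_chain = sample(Xoshiro(468), model, Prior(), 1000; progress=false)

# Draw a new X for every prior sample of m.
prior_pred = predict(Xoshiro(468), predictive_model, prior_chain)
```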
```text
FlexiChain (1000 iterations, 1 chain)
  ↓ iter  = 1:1000
  → chain = 1:1

  Parameters (2) ── AbstractPPL.VarName
    Float64          m
    Vector{Float64}  X

  Extras (3)
    Float64  logprior, loglikelihood, logjoint
```
In fact, there is a simpler way: you can sample directly from the deconditioned model using Turing's Prior sampler. This generates prior samples for both the parameters and the new data in a single call.
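The following is a sketch of the single-call version. The data and seeds are illustrative assumptions, the name prior_pred_direct is hypothetical, and the setup is repeated so the block runs on its own.

```julia
using Turing
using Random: Xoshiro
using Statistics: mean

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Illustrative setup: conditioned model and its deconditioned counterpart.
N = 5
X_obs = rand(Xoshiro(468), Normal(3.0), N)
model = f(N) | (; X = X_obs)
predictive_model = decondition(model)

# A single call draws m from its prior and then X given m.
prior_pred_direct = sample(Xoshiro(468), predictive_model, Prior(), 1000; progress=false)
```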
```text
FlexiChain (1000 iterations, 1 chain)
  ↓ iter  = 1:1000
  → chain = 1:1

  Parameters (2) ── AbstractPPL.VarName
    Float64          m
    Vector{Float64}  X

  Extras (3)
    Float64  logprior, loglikelihood, logjoint
```
We can visualise the prior predictive distribution in the same way as before. Let’s compare the two predictive distributions:
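The following is a sketch of the comparison, again using StatsPlots.jl as an assumed plotting package. It assumes, as before, that indexing a chain with @varname(X) returns one Vector{Float64} per iteration. The data, seeds and helper name flatten_X are illustrative, and the setup is repeated so the block runs on its own. The block also prints a numerical summary of each distribution's centre and spread next to the plot.

```julia
using Turing
using StatsPlots: density, density!
using Random: Xoshiro
using Statistics: mean, std

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Illustrative setup: synthetic data, posterior sampling, deconditioning.
N = 5
X_obs = rand(Xoshiro(468), Normal(3.0), N)
model = f(N) | (; X = X_obs)
chain = sample(Xoshiro(468), model, NUTS(), 1000; progress=false)
predictive_model = decondition(model)

post_pred = predict(Xoshiro(468), predictive_model, chain)
prior_pred = sample(Xoshiro(468), predictive_model, Prior(), 1000; progress=false)

# Hypothetical helper: flatten the per-iteration X vectors into one vector.
flatten_X(c) = reduce(vcat, vec(c[@varname(X)]))
prior_X = flatten_X(prior_pred)
post_X = flatten_X(post_pred)

density(prior_X; label="Prior predictive", xlabel="X")
density!(post_X; label="Posterior predictive")

# Centre and spread of each predictive distribution.
println((prior_mean = mean(prior_X), prior_sd = std(prior_X),
         post_mean = mean(post_X), post_sd = std(post_X)))
```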
We can see here that the prior predictive distribution is:
1. wider than the posterior predictive distribution; and
2. centred on the prior mean of m (which is 0), rather than the posterior mean (which is close to the true mean of 3).
Both differences arise because the posterior predictive distribution has been informed by the observed data.
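For this particular model, both predictive distributions are also available in closed form through standard normal-normal conjugacy, which makes the two observations above precise. The prior is \(m \sim \mathcal{N}(0, 1)\) and the observations are \(X_j | m \sim \mathcal{N}(m, 1)\) for \(j = 1, \dots, N\). Writing the second argument of \(\mathcal{N}\) as a variance, each element \(\tilde{x}_i\) of a new data vector satisfies
\[ m | \mathbf{X} \sim \mathcal{N}\!\left(\frac{1}{N+1}\sum_{j=1}^{N} X_j,\ \frac{1}{N+1}\right), \qquad \tilde{x}_i | \mathbf{X} \sim \mathcal{N}\!\left(\frac{1}{N+1}\sum_{j=1}^{N} X_j,\ 1 + \frac{1}{N+1}\right), \qquad \tilde{x}_i \sim \mathcal{N}(0,\ 2). \]
With N = 5, each new data point has variance 7/6 under the posterior predictive distribution but 2 under the prior predictive distribution. The posterior predictive mean is also pulled away from 0 towards the data.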