Generated Quantities

Often, the most natural parameterization of a model is not the one that samples most efficiently. Consider the following implementation of Neal’s funnel (Neal, 2003), which has been reparameterized for efficient sampling:

using Turing

@model function Neal()
    # Raw draws
    y_raw ~ Normal(0, 1)
    x_raw ~ arraydist([Normal(0, 1) for i in 1:9])

    # Transform:
    y = 3 * y_raw
    x = exp.(y ./ 2) .* x_raw

    # Return:
    return [x; y]
end
Neal (generic function with 2 methods)
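For reference, the change of variables behind this model can be written out explicitly. The block below is a sketch based on the usual statement of Neal's funnel, with the constants taken from the code above (scale 3 for y, nine components of x). The symbols \tilde{y} and \tilde{x}_i are notation introduced here for y_raw and x_raw[i], and the second argument of \mathcal{N} is a variance.

```latex
\begin{aligned}
\text{natural form:}\quad
  & y \sim \mathcal{N}(0,\, 3^2), \qquad
    x_i \mid y \sim \mathcal{N}\!\left(0,\, e^{y}\right), \quad i = 1, \dots, 9 \\
\text{reparameterized:}\quad
  & \tilde{y} \sim \mathcal{N}(0, 1), \qquad
    \tilde{x}_i \sim \mathcal{N}(0, 1), \\
  & y = 3\,\tilde{y}, \qquad
    x_i = e^{y/2}\,\tilde{x}_i
\end{aligned}
```

Since e^{y/2} is the standard deviation of x_i given y, scaling a standard normal draw by it reproduces the natural form. The sampler only ever sees independent standard normals, and the funnel shape appears after the transform.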

In this case, the random variables exposed in the chain (x_raw, y_raw) are not in a helpful form — what we’re after are the deterministically transformed variables x and y.

More generally, there are often quantities in our models that we might be interested in viewing, but which are not explicitly present in our chain.

We can generate draws of these variables (in this case, x and y) by returning them from the model with a return statement and then calling generated_quantities(model, chain). This function returns an array with one entry per draw in the chain, where each entry is the value of the model's return statement evaluated at that draw.
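Conceptually, this amounts to re-evaluating the model's return value once for every stored draw. The plain-Julia sketch below imitates that idea outside of Turing. The names funnel_transform and raw_draws are hypothetical and introduced only for illustration, and this is not a description of how generated_quantities is implemented internally.

```julia
using Random

# Hypothetical helper: the deterministic part of the model above,
# mapping one raw draw to the returned vector [x; y].
function funnel_transform(y_raw, x_raw)
    y = 3 * y_raw
    x = exp(y / 2) .* x_raw
    return [x; y]
end

# Stand-in for the stored chain: five raw draws (illustrative data only).
rng = MersenneTwister(1)
raw_draws = [(y_raw = randn(rng), x_raw = randn(rng, 9)) for _ in 1:5]

# One returned value per draw, analogous to one entry per posterior sample.
derived = [funnel_transform(d.y_raw, d.x_raw) for d in raw_draws]
```

Each element of derived is a 10-element vector with the nine x values first and y last, matching the layout of the return statement.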

For example, in the above reparametrization, we sample from our model:

chain = sample(Neal(), NUTS(), 1000; progress=false)
┌ Info: Found initial step size
└   ϵ = 1.6
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 7.7 seconds
Compute duration  = 7.7 seconds
parameters        = y_raw, x_raw[1], x_raw[2], x_raw[3], x_raw[4], x_raw[5], x_raw[6], x_raw[7], x_raw[8], x_raw[9]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat    ⋯
      Symbol   Float64   Float64   Float64     Float64    Float64   Float64    ⋯

       y_raw   -0.0266    1.0365    0.0302   1178.4275   738.0395    1.0016    ⋯
    x_raw[1]    0.0226    0.9054    0.0260   1216.8562   740.9310    1.0030    ⋯
    x_raw[2]    0.0142    0.9815    0.0259   1451.3248   719.8028    1.0013    ⋯
    x_raw[3]   -0.0218    0.9990    0.0292   1178.5383   500.0843    1.0018    ⋯
    x_raw[4]    0.0051    1.0298    0.0274   1409.3547   797.6074    0.9995    ⋯
    x_raw[5]    0.0084    0.9969    0.0290   1191.7394   613.9673    1.0019    ⋯
    x_raw[6]    0.0034    1.0223    0.0326    985.1660   643.0315    0.9992    ⋯
    x_raw[7]    0.0151    0.9708    0.0292   1072.3654   674.5299    1.0016    ⋯
    x_raw[8]    0.0549    0.9618    0.0297   1055.5728   636.2819    1.0033    ⋯
    x_raw[9]   -0.0092    1.0109    0.0293   1184.3514   782.6089    1.0020    ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

       y_raw   -2.0042   -0.7204   -0.0233    0.6565    1.9917
    x_raw[1]   -1.6619   -0.5895    0.0244    0.6384    1.7528
    x_raw[2]   -1.9753   -0.6036   -0.0214    0.6376    2.0844
    x_raw[3]   -1.9642   -0.6934   -0.0258    0.6486    2.0028
    x_raw[4]   -2.0200   -0.7175    0.0014    0.7193    2.0043
    x_raw[5]   -2.0129   -0.6410    0.0173    0.7030    1.9352
    x_raw[6]   -2.0584   -0.6739    0.0128    0.7097    1.8516
    x_raw[7]   -1.8415   -0.6244    0.0103    0.6643    1.8364
    x_raw[8]   -1.8285   -0.6335    0.0222    0.7226    2.0290
    x_raw[9]   -2.0033   -0.7039    0.0005    0.6568    1.9939

Notice that only x_raw and y_raw are stored in the chain; x and y are not, because they do not appear on the left-hand side of a tilde-statement.

To get x and y, we can then call:

generated_quantities(Neal(), chain)
1000×1 Matrix{Vector{Float64}}:
 [1.1333440506074053, -9.716589782014701, 0.1220728541108832, -0.37761915827653186, -4.587557603951678, -2.4645747067100867, 9.355886421954846, 4.062817155086687, -1.3677492569316994, 3.710185656219256]
 [0.030124403641561926, 0.08917331742765595, -0.0033660177301165816, -0.06798533929987156, 0.06170142141429877, 0.020659730026155745, 0.09752272188669779, 0.05047039946332551, -0.08060160102182413, -4.919286557818886]
 [-0.03760116658765476, 0.07938605785714685, -0.01788868336924059, -0.06937054905303298, 0.11124786990470661, 0.03411436063986355, 0.0826051128212285, 0.09861874846826409, -0.0411837897009423, -4.819602216406156]
 [-0.09112583382986994, -0.01167882521206143, 0.5363993260631748, 0.910399584760842, -0.9029638607203172, -0.22367948214969177, -1.3080939206725954, 0.17092873967481634, -0.36378468267234204, -0.6590729196137506]
 [0.4135669155821186, -1.6795650440482024, -1.3875847541299784, 1.599202797546467, -0.40959948434792637, 1.7428171369687748, -1.4361703539142636, 1.0026571552838672, 0.6429211096712572, 0.7647902751227772]
 [0.0725945500686418, 0.6154550753653741, 0.37673214207237266, -0.5817391202030268, 0.1296243886089064, -0.6743660879463471, 0.5052642387139707, -0.38649250600403495, -0.3289389907558548, -1.0365134182313707]
 [0.050393379285307335, -0.2202531675396907, -0.7202015877439335, -0.6851911014976367, 0.12803831975123983, -1.125178852676807, 0.4704758722923616, 0.06738788808517007, -0.2313756004234985, -0.6651075676309219]
 [0.050393379285307335, -0.2202531675396907, -0.7202015877439335, -0.6851911014976367, 0.12803831975123983, -1.125178852676807, 0.4704758722923616, 0.06738788808517007, -0.2313756004234985, -0.6651075676309219]
 [0.050393379285307335, -0.2202531675396907, -0.7202015877439335, -0.6851911014976367, 0.12803831975123983, -1.125178852676807, 0.4704758722923616, 0.06738788808517007, -0.2313756004234985, -0.6651075676309219]
 [-0.605710810715987, -0.7582195585884293, 2.379119795521339, 1.0507849915561212, 0.675037640856558, 3.0955272115564876, 0.5977540004577713, -2.064542665699323, 4.082573022573238, 1.4025407011124487]
 ⋮
 [0.014882644329769277, -0.010796321586528414, -0.004331704574606611, -0.03417922378483356, 0.03842548379005371, -0.04892400495633194, -0.011117214709147477, 0.019810320018036127, 0.03372837695180565, -6.900886160927007]
 [-0.038452314543193006, 0.04962214052395659, -0.07363078441458971, 0.13288381765934232, -0.011785383271629118, -0.08430828057245937, 0.006880344142046319, 0.17885029272441996, 0.04941234715277995, -5.107012222754822]
 [-4.5783404256983715, -1.7336365149823723, 3.1836901087562666, -2.078367943586303, 2.668761278122588, 1.7635708732073432, 0.9952091194139531, 1.4565127515902476, 0.9587127892965208, 2.5750261987390406]
 [0.0455875992003101, 0.11008139119668042, -0.14471052570440288, 0.293868881626534, -0.1404053598225744, -0.01948886683706396, -0.1206578622716115, 0.020715585195668268, -0.06441293594476513, -3.229453790356666]
 [-2.588161686714333, -0.5827521147783218, 1.0080185852114196, -0.5214692931848705, 1.9047407718468543, -0.2921544351999813, -0.8967138489607168, 2.4296816733298483, 2.2777393052007895, 0.7557929364275997]
 [1.0295540720224519, 0.2373634420568055, -0.2851396227807864, 0.054715968441750175, -0.5621026184518748, 0.14388477701864028, 0.29867132112709155, -0.7294579181938533, -0.806865253711678, -1.5879061295904555]
 [-43.19130912629576, -27.04957888114907, 18.289840320764938, -2.879533077549633, 11.922722378947919, 6.913290603870871, -5.479712788652368, 33.346940140972414, 23.557754232387534, 5.5566727954444115]
 [-0.015044307083566357, 0.30653439859881776, 0.183274162414087, 0.22956595056336526, -0.15605804509848784, 0.5430880420418422, -0.08697261605634724, -0.014455347852813214, -0.12600762908414748, -2.1359868096502295]
 [3.889761467362819, -12.866917683022812, -6.135591336706601, -7.850478073971375, 1.237692784604818, -1.5019723155426972, 3.8951470692111982, 2.669987801055438, 7.174841196979786, 3.7012899144195543]

Each element of this matrix is a vector holding the values of x1, x2, ..., x9, y for one posterior sample. The rows index the iterations, and the single column corresponds to the one chain we sampled.

In this case, it might be useful to reorganize our output into a matrix for plotting:

reparam_chain = reduce(hcat, generated_quantities(Neal(), chain))'
1000×10 adjoint(::Matrix{Float64}) with eltype Float64:
   1.13334     -9.71659     0.122073    …   4.06282    -1.36775     3.71019
   0.0301244    0.0891733  -0.00336602      0.0504704  -0.0806016  -4.91929
  -0.0376012    0.0793861  -0.0178887       0.0986187  -0.0411838  -4.8196
  -0.0911258   -0.0116788   0.536399        0.170929   -0.363785   -0.659073
   0.413567    -1.67957    -1.38758         1.00266     0.642921    0.76479
   0.0725946    0.615455    0.376732    …  -0.386493   -0.328939   -1.03651
   0.0503934   -0.220253   -0.720202        0.0673879  -0.231376   -0.665108
   0.0503934   -0.220253   -0.720202        0.0673879  -0.231376   -0.665108
   0.0503934   -0.220253   -0.720202        0.0673879  -0.231376   -0.665108
  -0.605711    -0.75822     2.37912        -2.06454     4.08257     1.40254
   ⋮                                    ⋱                          
   0.0148826   -0.0107963  -0.0043317       0.0198103   0.0337284  -6.90089
  -0.0384523    0.0496221  -0.0736308       0.17885     0.0494123  -5.10701
  -4.57834     -1.73364     3.18369         1.45651     0.958713    2.57503
   0.0455876    0.110081   -0.144711        0.0207156  -0.0644129  -3.22945
  -2.58816     -0.582752    1.00802     …   2.42968     2.27774     0.755793
   1.02955      0.237363   -0.28514        -0.729458   -0.806865   -1.58791
 -43.1913     -27.0496     18.2898         33.3469     23.5578      5.55667
  -0.0150443    0.306534    0.183274       -0.0144553  -0.126008   -2.13599
   3.88976    -12.8669     -6.13559         2.66999     7.17484     3.70129
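The reshaping idiom used here can be checked on a small example. The sketch below uses made-up data (draws is a hypothetical stand-in for the generated quantities) to show that reduce(hcat, ...) places each vector in a column, so the trailing ' turns each draw into a row. It also shows permutedims as one way to get a plain Matrix instead of a lazy adjoint wrapper.

```julia
# Two made-up "draws", each with three values (illustrative only).
draws = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

# hcat puts each draw in a column (3×2); the adjoint makes each draw a row (2×3).
lazy = reduce(hcat, draws)'

# permutedims materializes the same layout as an ordinary Matrix{Float64}.
dense = permutedims(reduce(hcat, draws))
```

For real-valued data the adjoint and the transpose coincide, so both forms hold the same numbers. The dense version can be preferable when a downstream function expects a concrete Matrix.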

from which we can recover a vector of our samples:

x1_samples = reparam_chain[:, 1]
y_samples = reparam_chain[:, 10]
1000-element Vector{Float64}:
  3.710185656219256
 -4.919286557818886
 -4.819602216406156
 -0.6590729196137506
  0.7647902751227772
 -1.0365134182313707
 -0.6651075676309219
 -0.6651075676309219
 -0.6651075676309219
  1.4025407011124487
  ⋮
 -6.900886160927007
 -5.107012222754822
  2.5750261987390406
 -3.229453790356666
  0.7557929364275997
 -1.5879061295904555
  5.5566727954444115
 -2.1359868096502295
  3.7012899144195543
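Because Neal() conditions on no observations, the chain targets the distribution defined by the model itself, so the recovered y values should be roughly centered at 0 with a standard deviation near 3. The sketch below produces reference values for such a check. It assumes independent draws and uses only the standard library, and y_ref is a name introduced here for illustration.

```julia
using Random, Statistics

# Independent reference draws of y = 3 * y_raw with y_raw ~ Normal(0, 1).
rng = MersenneTwister(42)
y_ref = 3 .* randn(rng, 100_000)

# Reference summary to compare against mean(y_samples) and std(y_samples).
ref_mean = mean(y_ref)
ref_std = std(y_ref)
```

With only 1000 correlated MCMC draws, the values computed from y_samples will match these only approximately.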