Gaussian Shells

This example uses Models.GaussianShells to explore the classic Gaussian shells model: a two-dimensional likelihood concentrated on two thin, ring-shaped shells.

Setup

For this example, you'll need to add the following packages:

julia> ]add Distributions MCMCChains Measurements NestedSamplers StatsBase StatsPlots

Define model

using NestedSamplers

# returns the model and its reference (true) log-evidence
model, logz = Models.GaussianShells()
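To make the model concrete, here is a minimal sketch of a Gaussian shells log-likelihood written out by hand. The centers (±3.5, 0) and radius 2 are assumptions chosen to match the dashed lines drawn in the density plot later on, and the width 0.1 is likewise assumed; none of these are read from Models.GaussianShells, whose exact constants and normalization may differ. The names logaddexp, log_ring, and shells_loglike are hypothetical.

```julia
using LinearAlgebra  # for norm

# ASSUMED constants for illustration; not taken from NestedSamplers.
const SHELL_CENTERS = ([-3.5, 0.0], [3.5, 0.0])
const SHELL_RADIUS = 2.0
const SHELL_WIDTH = 0.1

# Numerically stable log(exp(a) + exp(b)).
logaddexp(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

# Log-density of one Gaussian ring: a normal distribution in the distance
# from `center`, peaked at radius `r` with width `w`.
function log_ring(x, center; r=SHELL_RADIUS, w=SHELL_WIDTH)
    d = norm(x .- center)
    return -(d - r)^2 / (2w^2) - log(sqrt(2π * w^2))
end

# The shells likelihood is the sum of both rings, combined in log space.
shells_loglike(x) = logaddexp(log_ring(x, SHELL_CENTERS[1]), log_ring(x, SHELL_CENTERS[2]))
```

A point on a ring, such as [5.5, 0.0], should score far higher than the origin, and the surface is symmetric under x → -x.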

Let's evaluate the likelihood over a grid of both parameters to see what the likelihood surface looks like:

using StatsPlots

x = range(-6, 6, length=1000)
y = range(-6, 6, length=1000)
# evaluate on the grid; rows index y and columns index x, the layout heatmap expects
logf = [model.loglike([xi, yi]) for yi in y, xi in x]
heatmap(
    x, y, exp.(logf),
    aspect_ratio=1,
    xlims=extrema(x),
    ylims=extrema(y),
    xlabel="x",
    ylabel="y",
    size=(400, 400)
)
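Because the problem is only two-dimensional, the grid of log-likelihood values computed above can also give a brute-force evidence estimate. Assuming a uniform prior over the plotted square [-6, 6]², the evidence is the average likelihood over that square, which an evenly spaced grid approximates. The sketch below computes log(mean(exp.(logf))) without overflow; grid_logz is a hypothetical helper, not part of NestedSamplers.

```julia
# Hypothetical helper: log of the mean of exp.(logf), computed stably by
# factoring out the maximum before exponentiating. For a uniform prior on
# the gridded rectangle, this approximates the log-evidence log Z.
function grid_logz(logf::AbstractArray{<:Real})
    m = maximum(logf)
    isfinite(m) || return float(m)  # e.g. all entries -Inf: nothing to rescale
    return m + log(sum(exp.(logf .- m)) / length(logf))
end
```

In the session above, grid_logz(logf) could be compared against logz. The 1000 × 1000 grid has a spacing of about 0.012, so this should be reasonable as long as the shells are much wider than that, but expect some discretization error.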

Sample

using MCMCChains
using StatsBase
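# Nested(2, 1000): 2 parameters, 1000 active (live) points
# bounds default to multi-ellipsoid; proposals use the default rejection sampler
sampler = Nested(2, 1000)
# dlogz is the stopping tolerance on the estimated remaining log-evidence
chain, state = sample(model, sampler; dlogz=0.05, param_names=["x", "y"])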
# resample the weighted chain into equally weighted posterior draws
chain_resampled = sample(chain, Weights(vec(chain[:weights])), length(chain));
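The resampling step works because each dead point of a nested sampling run carries an importance weight proportional to L_i ΔX_i, its likelihood times the prior volume it accounts for. Below is a minimal sketch of that idea, assuming the common deterministic approximation X_i ≈ exp(-i / nactive) and ignoring the final live points that a full implementation also folds in, followed by a hand-written weighted resampler similar in spirit to the StatsBase call above. posterior_weights and weighted_resample are hypothetical names; NestedSamplers computes chain[:weights] internally and may do so differently.

```julia
using Random

# Hypothetical sketch: normalized posterior weights of dead points, given
# their log-likelihoods in removal order and the number of live points.
function posterior_weights(logl::AbstractVector{<:Real}, nactive::Integer)
    n = length(logl)
    logX = -(0:n) ./ nactive  # logX[k + 1] = log X_k, with X_0 = 1
    # log(ΔX_i) = log(X_{i-1} - X_i), computed in log space
    logdX = [logX[i] + log1p(-exp(logX[i + 1] - logX[i])) for i in 1:n]
    logw = logl .+ logdX      # unnormalized log-weights log(L_i ΔX_i)
    w = exp.(logw .- maximum(logw))
    return w ./ sum(w)        # weights sum to 1
end

# Hypothetical helper: draw `n` indices with replacement, each index chosen
# with probability proportional to its weight in `w`.
function weighted_resample(rng::AbstractRNG, w::AbstractVector{<:Real}, n::Integer)
    c = cumsum(w)
    c = c ./ c[end]           # cumulative distribution, last entry exactly 1
    # first index whose cumulative weight exceeds u; never a zero-weight index
    return [searchsortedlast(c, rand(rng)) + 1 for _ in 1:n]
end
```

With a constant likelihood the weights simply follow the shrinking volumes ΔX_i, so successive weights fall by a factor of exp(1 / nactive).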

Results

chain_resampled
Chains MCMC chain (7072×3×1 Array{Float64, 3}):

Iterations        = 1:7072
Number of chains  = 1
Samples per chain = 7072
parameters        = y, x
internals         = weights

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

           x    0.0980    3.8160     0.0454    0.0487   6822.1849    1.0001
           y   -0.0012    1.4043     0.0167    0.0153   7340.2303    1.0002

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x   -5.4828   -3.3625    1.4146    3.7535    5.4787
           y   -2.0464   -1.4215    0.0821    1.3428    2.0243
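Instead of resampling, posterior summaries can also be computed directly from the weighted chain. The sketch below computes a weighted mean and a (population) weighted standard deviation; weighted_mean_std is a hypothetical helper.

```julia
# Hypothetical helper: weighted mean and population standard deviation.
function weighted_mean_std(x::AbstractVector{<:Real}, w::AbstractVector{<:Real})
    length(x) == length(w) || throw(DimensionMismatch("x and w must have the same length"))
    p = w ./ sum(w)                      # normalize weights to sum to 1
    μ = sum(p .* x)                      # weighted mean
    σ = sqrt(sum(p .* (x .- μ) .^ 2))    # weighted standard deviation
    return μ, σ
end
```

Applied as weighted_mean_std(vec(chain[:x]), vec(chain[:weights])), it should give values close to the mean and std columns of the resampled summary above, up to resampling noise.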
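# joint and marginal densities of the equally weighted (resampled) draws
marginalkde(vec(chain_resampled[:x]), vec(chain_resampled[:y]))
density(chain_resampled)
# dashed lines mark the extent of the two shells along each axis
vline!([-5.5, -1.5, 1.5, 5.5], c=:black, ls=:dash, sp=1)
vline!([-2, 2], c=:black, ls=:dash, sp=2)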
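Assuming the two shells carry equal weight, they are mirror images about x = 0, so roughly half of the posterior mass should sit on each side. A quick check is the weighted fraction of samples with x > 0, sketched here with a hypothetical helper mode_fraction.

```julia
# Hypothetical helper: weighted fraction of samples with x > 0, i.e. the
# share of posterior mass assigned to the right-hand shell.
function mode_fraction(x::AbstractVector{<:Real}, w::AbstractVector{<:Real})
    length(x) == length(w) || throw(DimensionMismatch("x and w must have the same length"))
    return sum(w[x .> 0]) / sum(w)
end
```

For example, mode_fraction(vec(chain[:x]), vec(chain[:weights])). A value far from 0.5 would suggest that one mode was under-explored.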
using Measurements
# combine the estimate and its uncertainty into a single Measurement
logz_est = state.logz ± state.logzerr
logz_diff = logz_est - logz
println("logz: ", logz, "\nestimate: ", logz_est, "\ndiff: ", logz_diff)
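For intuition about where state.logz and state.logzerr come from, here is a minimal sketch of the classic nested sampling evidence accumulation: each dead point contributes L_i ΔX_i with X_i ≈ exp(-i / nactive), and the uncertainty is approximated by sqrt(H / nactive), where H is the information (the KL divergence from prior to posterior), updated incrementally as in Skilling's original algorithm. The helper remaining_dlogz shows a stopping measure of the kind that dlogz=0.05 in the sample call bounds: how much log Z would grow if all remaining prior volume sat at the largest live likelihood. All names are hypothetical, and NestedSamplers' actual implementation may differ in details, such as how the final live points are added.

```julia
# Numerically stable log(exp(a) + exp(b)).
logaddexp(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

# Hypothetical sketch: `logl` holds dead-point log-likelihoods in removal
# (ascending) order; prior volume is approximated as X_i ≈ exp(-i / nactive).
function nested_logz(logl::AbstractVector{<:Real}, nactive::Integer)
    logz = -floatmax(Float64)  # stand-in for log(0) that avoids 0 * -Inf = NaN
    h = 0.0                    # information H, updated incrementally
    logX_prev = 0.0            # log X_0 = log(1)
    for (i, l) in enumerate(logl)
        logX = -i / nactive
        logdX = logX_prev + log1p(-exp(logX - logX_prev))  # log(X_{i-1} - X_i)
        logwt = l + logdX                                   # log(L_i ΔX_i)
        logz_new = logaddexp(logz, logwt)
        h = exp(logwt - logz_new) * l +
            exp(logz - logz_new) * (h + logz) - logz_new
        logz = logz_new
        logX_prev = logX
    end
    # clamp guards against tiny negative H from rounding
    return logz, sqrt(max(h, 0.0) / nactive)
end

# Growth in log Z if the remaining volume exp(logX) were all at the largest
# live log-likelihood; sampling stops once this falls below a tolerance.
remaining_dlogz(logz, logl_max, logX) = logaddexp(logz, logl_max + logX) - logz
```

With a constant likelihood L = 1 and 10 live points, 50 iterations cover a prior volume of 1 - exp(-5), so the sketch returns log(1 - exp(-5)) as the log-evidence.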