Getting started: a simple Mixture of Gaussians example

Suppose we have a mixture of Gaussians, for example:

using Distributions
target_distribution = MixtureModel(
    Normal,
    [(-3, 1.5), (3, 1.5), (20, 1.5)],  # parameters
    [0.5, 0.3, 0.2]                    # weights
)
MixtureModel{Distributions.Normal}(K = 3)
components[1] (prior = 0.5000): Distributions.Normal{Float64}(μ=-3.0, σ=1.5)
components[2] (prior = 0.3000): Distributions.Normal{Float64}(μ=3.0, σ=1.5)
components[3] (prior = 0.2000): Distributions.Normal{Float64}(μ=20.0, σ=1.5)

This is a simple 1-dimensional distribution, so let's visualize it:

using StatsPlots
figsize = (800, 400)
plot(target_distribution; components=false, label=nothing, size=figsize)

To pass a Distribution from Distributions.jl to sample, for many different samplers, we can wrap it in a type that implements the LogDensityProblems.jl interface:

using LogDensityProblems: LogDensityProblems

struct DistributionLogDensity{D}
    d::D
end

LogDensityProblems.logdensity(d::DistributionLogDensity, x) = loglikelihood(d.d, x)
LogDensityProblems.dimension(d::DistributionLogDensity) = length(d.d)
LogDensityProblems.capabilities(::Type{<:DistributionLogDensity}) = LogDensityProblems.LogDensityOrder{0}()

# Wrap our target distribution.
target_model = DistributionLogDensity(target_distribution)
Main.DistributionLogDensity{Distributions.MixtureModel{Distributions.Univariate, Distributions.Continuous, Distributions.Normal, Distributions.Categorical{Float64, Vector{Float64}}}}(MixtureModel{Distributions.Normal}(K = 3)
components[1] (prior = 0.5000): Distributions.Normal{Float64}(μ=-3.0, σ=1.5)
components[2] (prior = 0.3000): Distributions.Normal{Float64}(μ=3.0, σ=1.5)
components[3] (prior = 0.2000): Distributions.Normal{Float64}(μ=20.0, σ=1.5)
)
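For intuition, here's a minimal, dependency-free sketch of the quantity logdensity returns for this target, $\log \pi(x) = \log \sum_k w_k \mathcal{N}(x; \mu_k, \sigma_k^2)$, computed with a numerically stable log-sum-exp. The helper names (normal_logpdf, logsumexp, mixture_logpdf) are our own and not part of any package; the parameters are copied from target_distribution above.

```julia
# A minimal sketch of the mixture log-density using only Base Julia.
# Parameters copied from `target_distribution`; all names here are our own.
MEANS = [-3.0, 3.0, 20.0]
STDS = [1.5, 1.5, 1.5]
WEIGHTS = [0.5, 0.3, 0.2]

# Log-density of the univariate normal N(μ, σ²) evaluated at x.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Numerically stable log(sum(exp.(v))): factor out the maximum first.
function logsumexp(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

# log π(x) = log Σ_k w_k N(x; μ_k, σ_k²)
function mixture_logpdf(x)
    terms = [log(w) + normal_logpdf(x, μ, σ) for (w, μ, σ) in zip(WEIGHTS, MEANS, STDS)]
    return logsumexp(terms)
end
```

Working in log-space matters here: far from all three modes, the individual terms exp(log w + log N) underflow, while the log-sum-exp form stays finite.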

Metropolis-Hastings (AdvancedMH.jl)

A natural first attempt is a standard sampler, e.g. random-walk Metropolis-Hastings (RWMH) from AdvancedMH.jl, which we can run using sample:

using AdvancedMH, MCMCChains, LinearAlgebra

using StableRNGs
rng = StableRNG(42) # To ensure reproducibility across devices.

sampler = RWMH(MvNormal(zeros(1), I))
num_iterations = 10_000
chain = sample(
    rng,
    target_model, sampler, num_iterations;
    chain_type=MCMCChains.Chains,
    param_names=["x"]
)
Chains MCMC chain (10000×2×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
parameters        = x
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           x   -0.5600    3.3066    0.2702   172.8186   678.7867    1.0137     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x   -5.6359   -3.2669   -1.4454    2.5327    5.3961
plot(chain; size=figsize)

This doesn't look quite like what we'd expect. Let's compare the chain's density estimate with the target density:

plot(target_distribution; components=false, linewidth=2)
density!(chain)
plot!(size=figsize)

Notice how the chain has zero probability mass in the right-most component of the mixture, the one centered at 20: its 97.5% quantile is only about 5.4!
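A rough way to see why: with the unit-variance random-walk proposal, getting from the mode at 3 to the mode at 20 means passing through the valley around 11.5, where the density is many orders of magnitude smaller. The sketch below (plain Julia, our own helper names, not AdvancedMH.jl code) just compares the log-density at those two points; it's a back-of-the-envelope argument, not an analysis of the actual chain.

```julia
# Same mixture log-density as the target (plain Julia; names are our own).
MEANS = [-3.0, 3.0, 20.0]; STDS = [1.5, 1.5, 1.5]; WEIGHTS = [0.5, 0.3, 0.2]
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)
logsumexp(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))
mixture_logpdf(x) = logsumexp([log(w) + normal_logpdf(x, μ, σ) for (w, μ, σ) in zip(WEIGHTS, MEANS, STDS)])

# Log-density drop between the mode at 3 and the valley midway to the mode at 20.
x_mode, x_valley = 3.0, 11.5
log_ratio = mixture_logpdf(x_valley) - mixture_logpdf(x_mode)  # ≈ -15.5

# Metropolis acceptance probability (symmetric proposal) of a move from
# x_mode to x_valley, if such a move were ever proposed.
accept_prob = min(1.0, exp(log_ratio))
```

An acceptance probability on the order of 1e-7 means that within 10,000 iterations the random walk is very unlikely to ever cross over.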

Let's instead try to use a tempered version of RWMH. But before we do that, we need to make sure that AdvancedMH.jl is compatible with MCMCTempering.jl.

To do that we need to implement two methods. First, MCMCTempering.getparams_and_logprob tells MCMCTempering how to extract the parameters, and potentially the log-probabilities, from an AdvancedMH.Transition. Second, MCMCTempering.setparams_and_logprob!! tells it how to update the parameters and the log-probabilities of an AdvancedMH.Transition.

Luckily, implementing these is quite easy:

using MCMCTempering

MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp
function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, lp)
    return AdvancedMH.Transition(params, lp)
end
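If the !! suffix looks odd: by a common Julia convention (used by, e.g., BangBang.jl), a function ending in !! may either mutate its argument or return a new object, so callers must always use the return value. That's why the AdvancedMH version above simply constructs a fresh Transition. Here's a toy sketch of the same get/set pair on a hypothetical ToyTransition type, which is not part of AdvancedMH.jl or MCMCTempering.jl:

```julia
# Hypothetical stand-in for a sampler's transition type (NOT AdvancedMH.Transition).
struct ToyTransition{P,L}
    params::P
    lp::L
end

# Plays the role of `getparams_and_logprob`: read out parameters and log-probability.
toy_getparams_and_logprob(t::ToyTransition) = (t.params, t.lp)

# Plays the role of `setparams_and_logprob!!`. The struct is immutable, so we
# return a new instance; the `!!` tells callers to use the returned value.
toy_setparams_and_logprob!!(t::ToyTransition, params, lp) = ToyTransition(params, lp)
```

Usage: t2 = toy_setparams_and_logprob!!(t, [3.0], -2.5) leaves t untouched, and only t2 carries the new values.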

Now that this is done, we can wrap sampler in an MCMCTempering.TemperedSampler, here using the 21 inverse temperatures $\beta = 0.9^k$ for $k = 0, 1, \dots, 20$:

inverse_temperatures = 0.90 .^ (0:20)
sampler_tempered = TemperedSampler(sampler, inverse_temperatures)
TemperedSampler{AdvancedMH.MetropolisHastings{AdvancedMH.RandomWalkProposal{false, Distributions.IsoNormal}}, Vector{Float64}, ReversibleSwap, Nothing}(AdvancedMH.MetropolisHastings{AdvancedMH.RandomWalkProposal{false, Distributions.IsoNormal}}(AdvancedMH.RandomWalkProposal{false, Distributions.IsoNormal}(IsoNormal(
dim: 1
μ: [0.0]
Σ: [1.0;;]
)
)), [1.0, 0.9, 0.81, 0.7290000000000001, 0.6561, 0.5904900000000001, 0.531441, 0.4782969000000001, 0.4304672100000001, 0.3874204890000001  …  0.3138105960900001, 0.2824295364810001, 0.2541865828329001, 0.22876792454961006, 0.20589113209464907, 0.18530201888518416, 0.16677181699666577, 0.15009463529699918, 0.13508517176729928, 0.12157665459056935], ReversibleSwap(), false, nothing)
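To see what this buys us: a chain at inverse temperature $\beta$ targets the flattened density $\pi(x)^\beta$, so every log-density difference, including the valley between the modes, shrinks by a factor of $\beta$. A plain-Julia sketch of this (our own helper names; the point 11.5 is just the midpoint between the modes at 3 and 20, chosen for illustration):

```julia
# Same mixture log-density as the target (plain Julia; names are our own).
MEANS = [-3.0, 3.0, 20.0]; STDS = [1.5, 1.5, 1.5]; WEIGHTS = [0.5, 0.3, 0.2]
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)
logsumexp(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))
mixture_logpdf(x) = logsumexp([log(w) + normal_logpdf(x, μ, σ) for (w, μ, σ) in zip(WEIGHTS, MEANS, STDS)])

# Same inverse temperatures as above: β = 0.9^k for k = 0, ..., 20.
inverse_temperatures = 0.90 .^ (0:20)

# Unnormalized tempered log-density: log π_β(x) = β * log π(x).
tempered_logpdf(x, β) = β * mixture_logpdf(x)

# Depth of the valley between the modes at 3 and 20 (in nats), for each β.
barrier(β) = tempered_logpdf(3.0, β) - tempered_logpdf(11.5, β)
barriers = barrier.(inverse_temperatures)
```

The barrier is roughly 15.5 nats at $\beta = 1$ but under 2 nats for the hottest chain ($\beta \approx 0.12$), which a random walk can cross routinely.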

aaaaand sample!

chain_tempered = sample(
    rng, target_model, sampler_tempered, num_iterations;
    chain_type=MCMCChains.Chains,
    param_names=["x"]
)
Chains MCMC chain (10000×2×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
parameters        = x
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk    ess_tail      rhat    ⋯
      Symbol   Float64   Float64   Float64    Float64     Float64   Float64    ⋯

           x    2.0678    7.8427    0.4084   645.8715   1244.0113    1.0027    ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x   -5.4862   -3.1902   -1.1180    3.5764   21.3183

Let's see how this looks:

plot(chain_tempered)
plot!(size=figsize)
plot(target_distribution; components=false, linewidth=2)
density!(chain)
density!(chain_tempered)
plot!(size=figsize)

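Neato; we've indeed captured the target distribution much better! In particular, the tempered chain now also reaches the right-most component: its 97.5% quantile is about 21.3, compared to about 5.4 for the untempered chain.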
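Why does this help? The hotter chains cross the valley easily, and the swap step lets the states they find travel down to the $\beta = 1$ chain. Below is a sketch of the swap acceptance rule commonly used in parallel tempering, $\log \alpha = \min\big(0, (\beta_i - \beta_j)(\log \pi(x_j) - \log \pi(x_i))\big)$, in plain Julia with our own names. It is not MCMCTempering's ReversibleSwap code, which may, e.g., choose and schedule the proposed pairs differently.

```julia
using Random

# Log acceptance probability for exchanging the states of chains i and j:
#   log α = min(0, (β_i - β_j) * (log π(x_j) - log π(x_i)))
swap_logaccept(β_i, β_j, logπ_i, logπ_j) = min(0.0, (β_i - β_j) * (logπ_j - logπ_i))

# Propose to swap xs[i] and xs[j]; swaps in place and returns true if accepted.
function try_swap!(rng::AbstractRNG, xs, βs, logπ, i, j)
    logα = swap_logaccept(βs[i], βs[j], logπ(xs[i]), logπ(xs[j]))
    if log(rand(rng)) < logα
        xs[i], xs[j] = xs[j], xs[i]
        return true
    end
    return false
end
```

If the hotter chain $j$ holds a state with higher density than the colder chain $i$, the swap is always accepted, which is exactly how a mode discovered at high temperature reaches the target chain.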

We can even inspect all of the tempered chains, one per inverse temperature, if we so desire:

chain_tempered_all = sample(
    rng,
    target_model, sampler_tempered, num_iterations;
    chain_type=Vector{MCMCChains.Chains},  # Different!
    param_names=["x"]
);
21-element Vector{MCMCChains.Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(), Tuple{}}}}:
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 ⋮
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
plot(target_distribution; components=false, linewidth=2)
density!(chain)
# Tempered ones.
for chain_tempered in chain_tempered_all[2:end]
    density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all))))
end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)

HMC (AdvancedHMC.jl)

We can also do this with AdvancedHMC.jl.

using AdvancedHMC: AdvancedHMC
using ForwardDiff: ForwardDiff # for automatic differentiation of the logdensity

# Creation of the sampler.
metric = AdvancedHMC.DiagEuclideanMetric(1)
integrator = AdvancedHMC.Leapfrog(0.1)
proposal = AdvancedHMC.StaticTrajectory(integrator, 8)
sampler = AdvancedHMC.HMCSampler(proposal, metric)
sampler_tempered = MCMCTempering.TemperedSampler(sampler, inverse_temperatures)

# Sample!
num_iterations = 5_000
chain = sample(
    rng,
    target_model, sampler, num_iterations;
    chain_type=MCMCChains.Chains,
    param_names=["x"],
)
plot(chain, size=figsize)
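Roughly speaking, Leapfrog(0.1) and StaticTrajectory(integrator, 8) above correspond to a step size of 0.1 and 8 leapfrog steps per proposal. Here's a minimal plain-Julia sketch of leapfrog integration, assuming the unit diagonal metric that DiagEuclideanMetric(1) starts with; it is not AdvancedHMC.jl's implementation. Note that each step evaluates the gradient of the log-density at the new position, which is why any state handed back to the sampler must carry a gradient that matches its parameters.

```julia
# Leapfrog integration for H(θ, r) = -log π(θ) + r^2 / 2 (unit metric).
# `grad_logπ` is the gradient of the log-density; names here are our own.
function leapfrog(θ, r, ϵ, grad_logπ)
    r += (ϵ / 2) * grad_logπ(θ)   # half step for the momentum
    θ += ϵ * r                    # full step for the position
    r += (ϵ / 2) * grad_logπ(θ)   # second half step for the momentum
    return θ, r
end

# A fixed number of leapfrog steps, analogous to a static trajectory length.
function trajectory(θ, r, ϵ, n_steps, grad_logπ)
    for _ in 1:n_steps
        θ, r = leapfrog(θ, r, ϵ, grad_logπ)
    end
    return θ, r
end
```

Usage: for a standard normal target, trajectory(1.0, 0.0, 0.1, 8, θ -> -θ) approximately conserves $H$, and running it back from the negated end momentum returns to the start (leapfrog is time-reversible).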

To make it work with MCMCTempering, we define the same two methods as before:

# Provides a convenient way of "mutating" (read: reconstructing) types with different values
# for specified fields; see usage below.
using Setfield: Setfield

function MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState)
    t = state.transition
    return t.z.θ, t.z.ℓπ.value
end

function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, logprob)
    # NOTE: Need to recompute the gradient because it might be used in the next integration step.
    hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
    return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
        hamiltonian, params, state.transition.z.r;
        ℓκ=state.transition.z.ℓκ
    )
end

And then, just as before, we can sample:

chain_tempered_all = sample(
    StableRNG(42),
    target_model, sampler_tempered, num_iterations;
    chain_type=Vector{MCMCChains.Chains},
    param_names=["x"]
);
21-element Vector{MCMCChains.Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(), Tuple{}}}}:
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 ⋮
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
plot(target_distribution; components=false, linewidth=2)
density!(chain)
# Tempered ones.
for chain_tempered in chain_tempered_all[2:end]
    density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all))))
end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)

Works like a charm!

But we're recomputing both the logdensity and the gradient of the logdensity upon every call to MCMCTempering.setparams_and_logprob!! above! This seems wholly unnecessary in the tempering case, since

\[\pi_{\beta_1}(x) = \pi(x)^{\beta_1} = \big( \pi(x)^{\beta_2} \big)^{\beta_1 / \beta_2} = \pi_{\beta_2}(x)^{\beta_1 / \beta_2}\]

i.e. if model in the above is tempered with $\beta_1$ and the params are coming from a model tempered with $\beta_2$, we could just compute the new log-density as

(β_1 / β_2) * logprob

and similarly for the gradient! Luckily, it's possible to tell MCMCTempering to do this by overloading the MCMCTempering.state_from method. In particular, we'll specify that when we're working with two models of type MCMCTempering.TemperedLogDensityProblem and two states of type AdvancedHMC.HMCState, we can simply rescale the logdensity and gradient already stored in the source state to get the quantities we want, thus avoiding unnecessary computations:

MCMCTempering.state_from (Function)

state_from(model_source, state_target, state_source)
state_from(model_source, model_target, state_target, state_source)

Return a new state similar to state_target but updated from state_source, which could be a different type of state.

using AbstractMCMC: AbstractMCMC

function MCMCTempering.state_from(
    # AdvancedHMC.jl works with `LogDensityModel`, and by default `AbstractMCMC` will wrap
    # the input model with `LogDensityModel`, thus assuming it implements the
    # LogDensityProblems.jl interface.
    model::AbstractMCMC.LogDensityModel{<:MCMCTempering.TemperedLogDensityProblem},
    model_from::AbstractMCMC.LogDensityModel{<:MCMCTempering.TemperedLogDensityProblem},
    state::AdvancedHMC.HMCState,
    state_from::AdvancedHMC.HMCState,
)
    # We'll need the momentum `r` and the kinetic energy `ℓκ` from `z`.
    z = state.transition.z
    # From this, we'll need everything else.
    z_from = state_from.transition.z
    params_from = z_from.θ
    logprob_from = z_from.ℓπ.value
    gradient_from = z_from.ℓπ.gradient

    # `logprob_from` is actually `β_from * actual_logprob`, and we want `β * actual_logprob`, so
    # we can compute the "new" logprob as `(β / β_from) * logprob_from`.
    beta = model.logdensity.beta
    beta_from = model_from.logdensity.beta
    delta_beta = beta / beta_from
    logprob_new = delta_beta * logprob_from
    gradient_new = delta_beta .* gradient_from

    # Construct `PhasePoint`. Note that we keep `r` and `ℓκ` from the original state.
    return Setfield.@set state.transition.z = AdvancedHMC.PhasePoint(
        params_from,
        z.r,
        AdvancedHMC.DualValue(logprob_new, gradient_new),
        z.ℓκ
    )
end
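The rescaling in state_from relies only on the identity $\beta \log \pi(x) = (\beta / \beta_{\text{from}}) \, \beta_{\text{from}} \log \pi(x)$, applied to both the log-density and its gradient. Here's a quick plain-Julia sanity check of the direction of that ratio, using our own helper names, an arbitrary (hypothetical) pair of inverse temperatures and an arbitrary point $x$:

```julia
# Plain-Julia check of the rescaling used in `state_from`; names are our own.
MEANS = [-3.0, 3.0, 20.0]; STDS = [1.5, 1.5, 1.5]; WEIGHTS = [0.5, 0.3, 0.2]
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)
logsumexp(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))

component_logterms(x) = [log(w) + normal_logpdf(x, μ, σ) for (w, μ, σ) in zip(WEIGHTS, MEANS, STDS)]
mixture_logpdf(x) = logsumexp(component_logterms(x))

# d/dx log π(x) = Σ_k r_k(x) * (μ_k - x) / σ_k², with responsibilities r_k(x).
function mixture_grad(x)
    terms = component_logterms(x)
    resp = exp.(terms .- logsumexp(terms))
    return sum(resp .* (MEANS .- x) ./ STDS .^ 2)
end

# Hypothetical inverse temperatures and evaluation point.
β, β_from = 0.9^3, 0.9^7
x = 1.7

# What the source state carries: tempered log-density and gradient at β_from.
logprob_from = β_from * mixture_logpdf(x)
gradient_from = β_from * mixture_grad(x)

# Rescale as in `state_from`: multiply by β / β_from.
delta_beta = β / β_from
logprob_new = delta_beta * logprob_from
gradient_new = delta_beta * gradient_from
```

Both rescaled quantities match what a fresh evaluation at $\beta$ would give, without calling the log-density or its gradient again.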
Note: for a general model we'd also have to do the same for MCMCTempering.compute_logdensities if we want to completely eliminate unnecessary computations, but for AbstractMCMC.LogDensityModel{<:MCMCTempering.TemperedLogDensityProblem} this is already implemented in MCMCTempering.

Now we can do the same but slightly faster:

chain_tempered_all = sample(
    StableRNG(42),
    target_model, sampler_tempered, num_iterations;
    chain_type=Vector{MCMCChains.Chains},
    param_names=["x"]
);
21-element Vector{MCMCChains.Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(), Tuple{}}}}:
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 ⋮
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
 MCMC chain (10000×2×1 Array{Float64, 3})
plot(target_distribution; components=false, linewidth=2)
density!(chain)
# Tempered ones.
for chain_tempered in chain_tempered_all[2:end]
    density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all))))
end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)