Getting started: a simple Mixture of Gaussians example
Suppose we have a mixture of Gaussians, e.g. something like
using Distributions
target_distribution = MixtureModel(
Normal,
[(-3, 1.5), (3, 1.5), (20, 1.5)], # parameters
[0.5, 0.3, 0.2] # weights
)
MixtureModel{Distributions.Normal}(K = 3)
components[1] (prior = 0.5000): Distributions.Normal{Float64}(μ=-3.0, σ=1.5)
components[2] (prior = 0.3000): Distributions.Normal{Float64}(μ=3.0, σ=1.5)
components[3] (prior = 0.2000): Distributions.Normal{Float64}(μ=20.0, σ=1.5)
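Written out, the density of this mixture is just the weighted sum of the three component densities. The following restates the definition above, with $\mathcal{N}(x; \mu, \sigma^2)$ denoting the normal density (notation introduced here for illustration):
\[p(x) = 0.5 \, \mathcal{N}(x; -3, 1.5^2) + 0.3 \, \mathcal{N}(x; 3, 1.5^2) + 0.2 \, \mathcal{N}(x; 20, 1.5^2)\]
Note that the third mode sits $(20 - 3) / 1.5 \approx 11.3$ standard deviations away from its nearest neighbour.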
This is a simple 1-dimensional distribution, so let's visualize it:
using StatsPlots
figsize = (800, 400)
plot(target_distribution; components=false, label=nothing, size=figsize)
We can convert a Distribution from Distributions.jl into something we can pass to sample for many different samplers by implementing the LogDensityProblems.jl interface:
using LogDensityProblems: LogDensityProblems
struct DistributionLogDensity{D}
d::D
end
LogDensityProblems.logdensity(d::DistributionLogDensity, x) = loglikelihood(d.d, x)
LogDensityProblems.dimension(d::DistributionLogDensity) = length(d.d)
LogDensityProblems.capabilities(::Type{<:DistributionLogDensity}) = LogDensityProblems.LogDensityOrder{0}()
# Wrap our target distribution.
target_model = DistributionLogDensity(target_distribution)
Main.DistributionLogDensity{Distributions.MixtureModel{Distributions.Univariate, Distributions.Continuous, Distributions.Normal, Distributions.Categorical{Float64, Vector{Float64}}}}(MixtureModel{Distributions.Normal}(K = 3)
components[1] (prior = 0.5000): Distributions.Normal{Float64}(μ=-3.0, σ=1.5)
components[2] (prior = 0.3000): Distributions.Normal{Float64}(μ=3.0, σ=1.5)
components[3] (prior = 0.2000): Distributions.Normal{Float64}(μ=20.0, σ=1.5)
)
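To make explicit what the wrapped log density evaluates to for a length-1 vector x, here is a minimal sketch in plain Julia. The helper names normal_logpdf and mixture_logdensity are hypothetical and not part of any package. Under that assumption, the result should agree with LogDensityProblems.logdensity(target_model, x) up to floating-point error.

```julia
# Hand-written log density of the same 3-component mixture, using only Base Julia.
# `normal_logpdf` and `mixture_logdensity` are hypothetical helper names.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

function mixture_logdensity(x::AbstractVector{<:Real})
    μs = (-3.0, 3.0, 20.0)   # component means
    σ = 1.5                  # shared standard deviation
    ws = (0.5, 0.3, 0.2)     # mixture weights
    # Log of each weighted component density at the single coordinate x[1].
    terms = map((w, μ) -> log(w) + normal_logpdf(x[1], μ, σ), ws, μs)
    # Log-sum-exp for numerical stability.
    m = maximum(terms)
    return m + log(sum(t -> exp(t - m), terms))
end

mixture_logdensity([0.0])
```

The log-sum-exp step matters far from the modes, where the individual component densities underflow to zero if exponentiated directly.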
Metropolis-Hastings (AdvancedMH.jl)
Immediately one might reach for a standard sampler, e.g. a random-walk Metropolis-Hastings (RWMH) from AdvancedMH.jl, and start sampling using sample:
using AdvancedMH, MCMCChains, LinearAlgebra
using StableRNGs
rng = StableRNG(42) # To ensure reproducibility across devices.
sampler = RWMH(MvNormal(zeros(1), I))
num_iterations = 10_000
chain = sample(
rng,
target_model, sampler, num_iterations;
chain_type=MCMCChains.Chains,
param_names=["x"]
)
Chains MCMC chain (10000×2×1 Array{Float64, 3}):
Iterations = 1:1:10000
Number of chains = 1
Samples per chain = 10000
parameters = x
internals = lp
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat e ⋯
Symbol Float64 Float64 Float64 Float64 Float64 Float64 ⋯
x -0.5600 3.3066 0.2702 172.8186 678.7867 1.0137 ⋯
1 column omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
x -5.6359 -3.2669 -1.4454 2.5327 5.3961
plot(chain; size=figsize)
This doesn't look quite like what we're expecting.
plot(target_distribution; components=false, linewidth=2)
density!(chain)
plot!(size=figsize)
Notice how chain has zero probability mass in the right-most component of the mixture (the one centered at 20)!
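To see why a unit-step random walk gets stuck, the following minimal sketch hand-rolls a random-walk Metropolis-Hastings sampler in plain Julia. It is not AdvancedMH.jl's implementation, and the names log_target and rwmh_sketch are hypothetical. Since all components share σ = 1.5, the common normalizing constant is dropped from the log density.

```julia
using Random

# Unnormalized log density of the mixture (hypothetical helper).
log_target(x) = log(
    0.5 * exp(-0.5 * ((x + 3) / 1.5)^2) +
    0.3 * exp(-0.5 * ((x - 3) / 1.5)^2) +
    0.2 * exp(-0.5 * ((x - 20) / 1.5)^2)
)

function rwmh_sketch(rng, logdensity, x0, n)
    xs = Vector{Float64}(undef, n)
    x, lp = x0, logdensity(x0)
    for i in 1:n
        # Propose a standard-normal step around the current point.
        x_new = x + randn(rng)
        lp_new = logdensity(x_new)
        # Accept with probability min(1, π(x_new) / π(x)).
        if log(rand(rng)) < lp_new - lp
            x, lp = x_new, lp_new
        end
        xs[i] = x
    end
    return xs
end

xs = rwmh_sketch(MersenneTwister(42), log_target, 0.0, 10_000)
extrema(xs)
```

To reach the mode at 20, the walk would have to pass through a region where the density is many orders of magnitude smaller than at the modes near -3 and 3, so proposals in that direction are almost always rejected.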
Let's instead try to use a tempered version of RWMH. But before we do that, we need to make sure that AdvancedMH.jl is compatible with MCMCTempering.jl.
To do that we need to implement two methods. First we need to tell MCMCTempering how to extract the parameters, and potentially the log-probabilities, from an AdvancedMH.Transition:
MCMCTempering.getparams_and_logprob (Function)
getparams_and_logprob([model, ]state)
Return a vector of parameters from the state.
See also: setparams_and_logprob!!.
And similarly, we need a way to update the parameters and the log-probabilities of an AdvancedMH.Transition:
MCMCTempering.setparams_and_logprob!! (Function)
setparams_and_logprob!!([model, ]state, params)
Set the parameters in the state to params, possibly mutating if it makes sense.
See also: getparams_and_logprob.
Luckily, implementing these is quite easy:
using MCMCTempering
MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp
function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, lp)
return AdvancedMH.Transition(params, lp)
end
Now that this is done, we can wrap sampler in a MCMCTempering.TemperedSampler:
inverse_temperatures = 0.90 .^ (0:20)
sampler_tempered = TemperedSampler(sampler, inverse_temperatures)
TemperedSampler{AdvancedMH.MetropolisHastings{AdvancedMH.RandomWalkProposal{false, Distributions.IsoNormal}}, Vector{Float64}, ReversibleSwap, Nothing}(AdvancedMH.MetropolisHastings{AdvancedMH.RandomWalkProposal{false, Distributions.IsoNormal}}(AdvancedMH.RandomWalkProposal{false, Distributions.IsoNormal}(IsoNormal(
dim: 1
μ: [0.0]
Σ: [1.0;;]
)
)), [1.0, 0.9, 0.81, 0.7290000000000001, 0.6561, 0.5904900000000001, 0.531441, 0.4782969000000001, 0.4304672100000001, 0.3874204890000001 … 0.31381059609000006, 0.2824295364810001, 0.2541865828329001, 0.2287679245496101, 0.20589113209464907, 0.18530201888518416, 0.16677181699666577, 0.15009463529699918, 0.13508517176729928, 0.12157665459056935], ReversibleSwap(), false, nothing)
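As a quick look at the temperature ladder itself, the following sketch uses only Base Julia. The variable names n_levels, β_min and widening are introduced here for illustration. It relies on the fact that a single Gaussian with standard deviation σ raised to the power β is proportional to a Gaussian with standard deviation σ / sqrt(β). For the full mixture this holds only approximately, component by component.

```julia
inverse_temperatures = 0.90 .^ (0:20)

n_levels = length(inverse_temperatures)   # 21 levels, from β = 1 down to β = 0.9^20
β_min = last(inverse_temperatures)        # ≈ 0.1216, the last entry printed above
widening = 1 / sqrt(β_min)                # ≈ 2.87: each mode is roughly this much wider at β_min
```

The first level (β = 1) is the untempered target, while the flatter levels make it easier for the random walk to move between modes.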
aaaaand sample!
chain_tempered = sample(
rng, target_model, sampler_tempered, num_iterations;
chain_type=MCMCChains.Chains,
param_names=["x"]
)
Let's see how this looks
plot(chain_tempered)
plot!(size=figsize)
plot(target_distribution; components=false, linewidth=2)
density!(chain)
density!(chain_tempered)
plot!(size=figsize)
Neato; we've indeed captured the target distribution much better!
We can even inspect all of the tempered chains if we so desire:
chain_tempered_all = sample(
rng,
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains}, # Different!
param_names=["x"]
);
plot(target_distribution; components=false, linewidth=2)
density!(chain)
# Tempered ones.
for chain_tempered in chain_tempered_all[2:end]
density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all))))
end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)
HMC (AdvancedHMC.jl)
We can do the same with AdvancedHMC.jl.
using AdvancedHMC: AdvancedHMC
using ForwardDiff: ForwardDiff # for automatic differentiation of the logdensity
# Creation of the sampler.
sampler = AdvancedHMC.HMC(0.1, 8)
sampler_tempered = MCMCTempering.TemperedSampler(sampler, inverse_temperatures)
# Sample!
num_iterations = 5_000
chain = sample(
rng,
target_model, sampler, num_iterations;
chain_type=MCMCChains.Chains,
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
)
plot(chain, size=figsize)
Then if we want to make it work with MCMCTempering, we define the same methods as before:
# Provides a convenient way of "mutating" (read: reconstructing) types with different values
# for specified fields; see usage below.
using Setfield: Setfield
function MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState)
t = state.transition
return t.z.θ, t.z.ℓπ.value
end
function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, logprob)
# NOTE: Need to recompute the gradient because it might be used in the next integration step.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, params, state.transition.z.r;
ℓκ=state.transition.z.ℓκ
)
end
And then, just as before, we can sample:
chain_tempered_all = sample(
StableRNG(42),
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains},
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
);
plot(target_distribution; components=false, linewidth=2)
density!(chain)
# Tempered ones.
for chain_tempered in chain_tempered_all[2:end]
density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all))))
end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)
Works like a charm!
But we're recomputing both the logdensity and the gradient of the logdensity upon every call to MCMCTempering.setparams_and_logprob!! above! This seems wholly unnecessary in the tempering case, since
\[\pi_{\beta_1}(x) = \pi(x)^{\beta_1} = \big( \pi(x)^{\beta_2} \big)^{\beta_1 / \beta_2} = \pi_{\beta_2}(x)^{\beta_1 / \beta_2}\]
i.e. if model in the above is tempered with $\beta_1$ and the params are coming from a model tempered with $\beta_2$, we could just compute the log-probability as
(β_1 / β_2) * logprob
and similarly for the gradient! Luckily, it's possible to tell MCMCTempering that this should be done by overloading the MCMCTempering.state_from method. In particular, we'll specify that when we're working with two models of type MCMCTempering.TemperedLogDensityProblem and two states of type AdvancedHMC.HMCState, we can simply rescale the logdensity and gradient already stored in the source state to get the quantities we want, thus avoiding unnecessary computations:
MCMCTempering.state_from (Function)
state_from(model_source, state_target, state_source)
state_from(model_source, model_target, state_target, state_source)
Return a new state similar to state_target but updated from state_source, which could be a different type of state.
using AbstractMCMC: AbstractMCMC
function MCMCTempering.state_from(
# AdvancedHMC.jl works with `LogDensityModel`, and by default `AbstractMCMC` will wrap
# the input model with `LogDensityModel`, thus assuming it implements the
# LogDensityProblems.jl-interface.
model::AbstractMCMC.LogDensityModel{<:MCMCTempering.TemperedLogDensityProblem},
model_from::AbstractMCMC.LogDensityModel{<:MCMCTempering.TemperedLogDensityProblem},
state::AdvancedHMC.HMCState,
state_from::AdvancedHMC.HMCState,
)
# We'll need the momentum and the kinetic energy from `z`.
z = state.transition.z
# From this, we'll need everything else.
z_from = state_from.transition.z
params_from = z_from.θ
logprob_from = z_from.ℓπ.value
gradient_from = z_from.ℓπ.gradient
# `logprob_from` is actually `β_from * actual_logprob`, and we want `β * actual_logprob`, so
# we can compute the "new" logprob as `(β / β_from) * logprob_from`.
beta = model.logdensity.beta
beta_from = model_from.logdensity.beta
delta_beta = beta / beta_from
logprob_new = delta_beta * logprob_from
gradient_new = delta_beta .* gradient_from
# Construct `PhasePoint`. Note that we keep `r` and `ℓκ` from the original state.
return Setfield.@set state.transition.z = AdvancedHMC.PhasePoint(
params_from,
z.r,
AdvancedHMC.DualValue(logprob_new, gradient_new),
z.ℓκ
)
end
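The rescaling used above can be sanity-checked numerically without any of the sampling machinery. The sketch below uses a standard normal as a stand-in target. The names logdensity_untempered and grad_untempered, and the values of β, β_from and x, are chosen purely for illustration.

```julia
# Untempered log density (up to a constant) and its gradient for a standard normal.
logdensity_untempered(x) = -0.5 * x^2
grad_untempered(x) = -x

β, β_from = 0.81, 0.9
x = 1.7

# What the source state already stores: quantities tempered with β_from.
logprob_from = β_from * logdensity_untempered(x)
gradient_from = β_from * grad_untempered(x)

# Rescale instead of re-evaluating the log density and its gradient.
logprob_new = (β / β_from) * logprob_from
gradient_new = (β / β_from) * gradient_from

# Both should match a direct evaluation at inverse temperature β.
isapprox(logprob_new, β * logdensity_untempered(x)) && isapprox(gradient_new, β * grad_untempered(x))
```

The rescaled values agree with direct evaluation up to floating-point error, which is why skipping the recomputation in state_from is safe for tempered models.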
For a general model we'd also have to do the same for MCMCTempering.compute_logdensities if we want to completely eliminate unnecessary computations, but for AbstractMCMC.LogDensityModel{<:MCMCTempering.TemperedLogDensityProblem} this is already implemented in MCMCTempering.
Now we can do the same but slightly faster:
chain_tempered_all = sample(
StableRNG(42),
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains},
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
);
plot(target_distribution; components=false, linewidth=2)
density!(chain)
# Tempered ones.
for chain_tempered in chain_tempered_all[2:end]
density!(chain_tempered, color="green", alpha=inv(sqrt(length(chain_tempered_all))))
end
density!(chain_tempered_all[1], color="green", size=figsize)
plot!(size=figsize)