Bijectors in MCMC

All of the above has been a purely mathematical discussion of how distributions can be transformed. Now we turn to their implementation in Julia, specifically using the Bijectors.jl package.

Bijectors.jl

import Random
Random.seed!(468);

using Distributions: Normal, LogNormal, logpdf
using Statistics: mean, var
using Plots: histogram

A bijection between two sets (Wikipedia) is, essentially, a one-to-one correspondence between the elements of these sets. That is to say, if we have two sets \(X\) and \(Y\), then a bijection maps each element of \(X\) to a unique element of \(Y\), and every element of \(Y\) is reached by exactly one element of \(X\). To return to our univariate example, where we transformed \(x\) to \(y\) using \(y = \exp(x)\), the exponential function is a bijection: every value of \(x\) maps to one unique value of \(y\), and every positive \(y\) comes from exactly one \(x\), namely \(\log(y)\). The input set (the domain) is \((-\infty, \infty)\), and the output set (the codomain) is \((0, \infty)\). (Here, \((a, b)\) denotes the open interval from \(a\) to \(b\), excluding \(a\) and \(b\) themselves.)

Since bijections are a one-to-one mapping between elements, we can also reverse the direction of this mapping to create an inverse function. In the case of \(y = \exp(x)\), the inverse function is \(x = \log(y)\).
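As a quick illustration using only Base Julia, we can check numerically that \(\log\) undoes \(\exp\) on a few points of the real line, and that \(\exp\) undoes \(\log\) on a few positive reals. This is a sanity check on sample points (up to floating-point rounding), not a proof that the map is bijective.

```julia
# Points in the domain (-∞, ∞) and their images under exp in the codomain (0, ∞).
xs = [-3.0, -0.5, 0.0, 1.2, 4.0]
ys = exp.(xs)

# A few points in (0, ∞), used to check the other direction.
pos = [0.1, 2.0, 50.0]

println(all(ys .> 0))          # exp never leaves (0, ∞)
println(log.(ys) ≈ xs)         # log(exp(x)) recovers x
println(exp.(log.(pos)) ≈ pos) # exp(log(y)) recovers y
```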

Note

Technically, the bijections in Bijectors.jl are functions \(f: X \to Y\) for which:

  • \(f\) is continuously differentiable, i.e. the derivative \(\mathrm{d}f(x)/\mathrm{d}x\) exists and is continuous (over the domain of interest \(X\));
  • The inverse \(f^{-1}: Y \to X\) is also continuously differentiable (over its own domain, i.e. \(Y\)).

The technical mathematical term for this is a diffeomorphism (Wikipedia), but we call them ‘bijectors’.

When thinking about continuous differentiability, it’s important to be conscious of the domains or codomains that we care about. For example, taking the inverse function \(\log(y)\) from above, its derivative is \(1/y\), which is not continuous at \(y = 0\). However, we specified that the bijection \(y = \exp(x)\) maps values of \(x \in (-\infty, \infty)\) to \(y \in (0, \infty)\), so the point \(y = 0\) is not within the domain of the inverse function.
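A small numerical illustration of this point (not a proof): central finite differences reproduce the exact derivatives \(\exp(x)\) and \(1/y\). The derivative of \(\log\) is finite at every \(y > 0\) but grows without bound as \(y\) approaches \(0\), which is harmless precisely because \(0\) is not in \((0, \infty)\). The helper `central_diff` is a hypothetical name introduced only for this sketch.

```julia
# Hypothetical helper: central finite difference, used only as a sanity check.
central_diff(g, t; h=1e-6) = (g(t + h) - g(t - h)) / (2h)

# d/dx exp(x) = exp(x): finite and continuous everywhere on (-∞, ∞).
for x in (-5.0, 0.0, 3.0)
    println("x = $x: numerical ", central_diff(exp, x), ", exact ", exp(x))
end

# d/dy log(y) = 1/y: finite for every y > 0, but unbounded as y -> 0⁺.
for y in (1.0, 0.1, 0.01)
    println("y = $y: numerical ", central_diff(log, y), ", exact ", 1 / y)
end
```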

In particular, one of the primary purposes of Bijectors.jl is to construct bijections which map constrained distributions to unconstrained ones. For example, the log-normal distribution which we saw on the previous page is constrained: its support, i.e. the set of values over which \(p(x) > 0\), is \((0, \infty)\). However, we can transform it to an unconstrained distribution (the normal distribution) using the transformation \(y = \log(x)\).

Note

Bijectors.jl, as well as DynamicPPL (which we’ll come to later), can work with a much broader class of bijective transformations of variables, not just ones that go to the entire real line. But for the purposes of MCMC, unconstraining is the most common transformation, so we’ll stick with that terminology.

The bijector function, when applied to a distribution, returns a bijection \(f\) that can be used to map the constrained distribution to an unconstrained one. Unsurprisingly, for the log-normal distribution, the bijection is (a broadcasted version of) the \(\log\) function.

import Bijectors as B

f = B.bijector(LogNormal())
(::Base.Fix1{typeof(broadcast), typeof(log)}) (generic function with 1 method)

We can apply this transformation to samples from the original distribution, for example:

samples_lognormal = rand(LogNormal(), 5)

samples_normal = f(samples_lognormal)
5-element Vector{Float64}:
  0.07200886749732066
 -0.07404375655951738
  0.6327762377562545
 -0.9799776018729268
  1.6115229499167665

We can also obtain the inverse of a bijection, \(f^{-1}\):

f_inv = B.inverse(f)

f_inv(samples_normal) == samples_lognormal
true
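The exact `==` comparison happened to succeed here, but a round trip through floating-point `log` and `exp` can in general differ from the input in the last bit, so `≈` is the safer check. A minimal sketch, reusing the same Bijectors.jl calls as above on a larger batch of draws:

```julia
import Random
using Distributions: LogNormal
import Bijectors as B

Random.seed!(468)
f = B.bijector(LogNormal())
f_inv = B.inverse(f)

xs = rand(LogNormal(), 1_000)
roundtrip = f_inv(f(xs))

# Largest relative deviation introduced by the log/exp round trip.
println(maximum(abs.(roundtrip .- xs) ./ xs))
println(roundtrip ≈ xs)
```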

We know that the transformation \(y = \log(x)\) changes the log-normal distribution to the normal distribution. Bijectors.jl also gives us a way to access that transformed distribution:

transformed_dist = B.transformed(LogNormal(), f)
Bijectors.UnivariateTransformed{Distributions.LogNormal{Float64}, Base.Fix1{typeof(broadcast), typeof(log)}}(
dist: Distributions.LogNormal{Float64}(μ=0.0, σ=1.0)
transform: Base.Fix1{typeof(broadcast), typeof(log)}(broadcast, log)
)

This type doesn’t immediately look like a Normal(), but it behaves in exactly the same way. For example, we can sample from it and plot a histogram:

samples_plot = rand(transformed_dist, 5000)
histogram(samples_plot, bins=50)
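Besides the histogram, a quick numerical check is that the sample mean and variance of many draws from the transformed distribution are close to those of `Normal()`, i.e. \(0\) and \(1\). The sample size of 100,000 is an arbitrary choice for this sketch.

```julia
import Random
using Distributions: LogNormal
using Statistics: mean, var
import Bijectors as B

Random.seed!(468)
d = LogNormal()
td = B.transformed(d, B.bijector(d))   # same construction as above

draws = rand(td, 100_000)
println("mean ≈ ", mean(draws), " (Normal(): 0)")
println("var  ≈ ", var(draws), " (Normal(): 1)")
```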

We can also obtain the logpdf of the transformed distribution and check that it is the same as that of a normal distribution:

println("Sample:   $(samples_plot[1])")
println("Expected: $(logpdf(Normal(), samples_plot[1]))")
println("Actual:   $(logpdf(transformed_dist, samples_plot[1]))")
Sample:   -0.2031149013821452
Expected: -0.9395663647864121
Actual:   -0.9395663647864121

Given the discussion in the previous sections, you might not be surprised to find that the logpdf of the transformed distribution is implemented using the Jacobian of the transformation. In particular, it directly uses the formula

\[\log(q(\mathbf{y})) = \log(p(\mathbf{x})) - \log(|\det(\mathbf{J})|).\]
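For the log-normal example, every term in this formula can be written out by hand, using the standard log-normal density with \(\mu = 0\) and \(\sigma = 1\). This short derivation confirms that the result is the standard normal logpdf:

\[
\begin{aligned}
y &= f(x) = \log(x), \qquad \mathbf{J} = \frac{\mathrm{d}y}{\mathrm{d}x} = \frac{1}{x}, \qquad \log(|\det(\mathbf{J})|) = -\log(x) = -y, \\
\log(p(x)) &= -\log(x) - \tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}(\log(x))^2, \\
\log(q(y)) &= \log(p(x)) - \log(|\det(\mathbf{J})|) = -\tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}y^2.
\end{aligned}
\]

In particular, for this bijector `logabsdetjac(f, x)` should simply return \(-\log(x)\), i.e. \(-y\).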

You can access \(\log(|\det(\mathbf{J})|)\) (evaluated at the point \(\mathbf{x}\)) using the logabsdetjac function:

# Reiterating the setup, just to be clear
original_dist = LogNormal()
x = rand(original_dist)
f = B.bijector(original_dist)
y = f(x)
transformed_dist = B.transformed(LogNormal(), f)

println("log(q(y))     : $(logpdf(transformed_dist, y))")
println("log(p(x))     : $(logpdf(original_dist, x))")
println("log(|det(J)|) : $(B.logabsdetjac(f, x))")
log(q(y))     : -0.9258400203646245
log(p(x))     : -0.8083539602557612
log(|det(J)|) : 0.11748606010886327

from which you can see that the equation above holds. There are more functions available in the Bijectors.jl API; see the documentation for full details. For example, logpdf_with_trans can directly give us \(\log(q(\mathbf{y}))\) from the untransformed value \(\mathbf{x}\), without us having to construct the bijector ourselves:

B.logpdf_with_trans(original_dist, x, true)
-0.9258400203646245
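To see that these routes agree beyond a single random point, the sketch below compares three ways of computing \(\log(q(y))\) at a few hand-picked values of \(x\): `logpdf_with_trans` (which takes \(x\), not \(y\)), the Jacobian formula, and the known answer `logpdf(Normal(), y)`. It also checks that passing `false` as the last argument returns the untransformed logpdf. The test points are arbitrary.

```julia
using Distributions: LogNormal, Normal, logpdf
import Bijectors as B

d = LogNormal()
f = B.bijector(d)

for x in (0.05, 0.7, 1.0, 3.5)
    y = f(x)
    via_trans  = B.logpdf_with_trans(d, x, true)       # takes x, not y
    via_jac    = logpdf(d, x) - B.logabsdetjac(f, x)   # the formula above
    via_normal = logpdf(Normal(), y)                   # known answer
    println((x = x, via_trans = via_trans, via_jac = via_jac, via_normal = via_normal))
end

# With `false`, no transformation (and hence no Jacobian term) is applied.
println(B.logpdf_with_trans(d, 0.7, false) ≈ logpdf(d, 0.7))
```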

The case for bijectors in MCMC

Constraints pose a challenge for many numerical methods such as optimisation, and sampling is no exception. The problem is that for any value \(x\) outside the support of a constrained distribution, \(p(x)\) is zero, and the logpdf is \(-\infty\). Thus, any quantity that involves a ratio of probabilities (or, equivalently, a difference of logpdfs) becomes degenerate: zero, infinite, or undefined, depending on which of the points involved lie outside the support.
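Concretely, with the log-normal distribution from above (the specific points \(1.0\), \(-0.5\) and \(-2.0\) are arbitrary choices for illustration):

```julia
using Distributions: LogNormal, logpdf

logp(x) = logpdf(LogNormal(), x)

x_current  = 1.0    # inside the support (0, ∞)
x_proposed = -0.5   # outside the support

println(logp(x_proposed))                         # -Inf
println(exp(logp(x_proposed) - logp(x_current)))  # density ratio: 0.0
println(logp(x_proposed) - logp(-2.0))            # both outside: NaN
```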

Metropolis with rejection

To see the practical impact of this on sampling, let’s attempt to sample from a log-normal distribution using a random walk Metropolis algorithm.

One way of handling constraints is to simply reject any steps that would take us out of bounds. This is a barebones implementation which does precisely that:

# Take a step where the proposal is a normal distribution centred around
# the current value. Return the new value, plus a flag to indicate whether
# the proposed value was in bounds (if it was not, the chain stays put).
function mh_step(logp, x, in_bounds)
    x_proposed = rand(Normal(x, 1))
    in_bounds(x_proposed) || return (x, false)  # bounds check
    acceptance_logp = logp(x_proposed) - logp(x)
    return if log(rand()) < acceptance_logp
        (x_proposed, true)  # successful step
    else
        (x, true)  # failed step
    end
end

# Run a random walk Metropolis sampler.
# `logp`      : a function that takes `x` and returns the log pdf of the
#               distribution we're trying to sample from (up to a constant
#               additive factor)
# `n_samples` : the number of samples to draw
# `in_bounds` : a function that takes `x` and returns whether `x` is within
#               the support of the distribution
# `x0`        : the initial value
# Returns a vector of samples, plus the number of times we went out of bounds.
function mh(logp, n_samples, in_bounds; x0=1.0)
    samples = [x0]
    x = x0
    n_out_of_bounds = 0
    for _ in 2:n_samples
        x, inb = mh_step(logp, x, in_bounds)
        if !inb
            n_out_of_bounds += 1
        end
        push!(samples, x)
    end
    return (samples, n_out_of_bounds)
end
mh (generic function with 1 method)

Note

In the MH algorithm, we technically do not need to explicitly check the proposal, because for any \(x \leq 0\), we have that \(p(x) = 0\); thus, the acceptance probability will be zero. However, doing so here allows us to track how often this happens, and also illustrates the general principle of handling constraints by rejection.
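To see this note in isolation, here is a hypothetical stripped-down variant of `mh` (the name `mh_nocheck` is introduced only for this sketch) with no bounds check at all. Out-of-support proposals get a logpdf of \(-\infty\), so `log(rand()) < -Inf` is never true and they are never accepted:

```julia
import Random
using Distributions: Normal, LogNormal, logpdf

# Hypothetical variant of `mh` with no bounds check: proposals with
# logp = -Inf are automatically rejected by the acceptance test.
function mh_nocheck(logp, n_samples; x0=1.0)
    samples = [x0]
    x = x0
    for _ in 2:n_samples
        x_proposed = rand(Normal(x, 1))
        if log(rand()) < logp(x_proposed) - logp(x)
            x = x_proposed
        end
        push!(samples, x)
    end
    return samples
end

Random.seed!(468)
samples_nocheck = mh_nocheck(x -> logpdf(LogNormal(), x), 10_000)
println(all(samples_nocheck .> 0))
```

What this version cannot tell us is how many steps were wasted on such proposals, which is why the explicit check is kept in `mh`.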

Now to actually perform the sampling:

logp(x) = logpdf(LogNormal(), x)
samples, n_out_of_bounds = mh(logp, 10000, x -> x > 0)
histogram(samples, bins=0:0.1:5; xlims=(0, 5))

How do we know that this has sampled correctly? For one, we can check that the mean of the samples is what we expect it to be. From Wikipedia, the mean of a log-normal distribution is given by \(\exp[\mu + (\sigma^2/2)]\). For our log-normal distribution, we set \(\mu = 0\) and \(\sigma = 1\), so:

println("expected mean: $(exp(0 + (1^2/2)))")
println("  actual mean: $(mean(samples))")
expected mean: 1.6487212707001282
  actual mean: 1.3347941996487

Metropolis with transformation

The issue with this is that many of the sampling steps are unproductive, in that they bring us to the region of \(x \leq 0\) and get rejected:

println("went out of bounds $n_out_of_bounds/10000 times")
went out of bounds 1870/10000 times

And this could have been even worse if we had chosen a wider proposal distribution in the Metropolis step, or if the support of the distribution were narrower! In general, we probably don’t want to have to re-parameterise our proposal distribution each time we sample from a distribution with different constraints.
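A minimal sketch of the first claim: the hypothetical helper `oob_fraction` below runs the same kind of random-walk Metropolis chain on `LogNormal()` with proposal standard deviation `σ`, and reports the fraction of proposals that landed at \(x \leq 0\). The values of `σ` and the number of steps are arbitrary choices.

```julia
import Random
using Distributions: Normal, LogNormal, logpdf

# Hypothetical helper: random-walk Metropolis on LogNormal() with proposal
# scale σ, returning the fraction of proposals that fell outside (0, ∞).
function oob_fraction(σ; n_steps=10_000, x0=1.0)
    logp(x) = logpdf(LogNormal(), x)
    x = x0
    n_oob = 0
    for _ in 1:n_steps
        x_proposed = rand(Normal(x, σ))
        if x_proposed <= 0
            n_oob += 1      # out of bounds: reject and stay at x
            continue
        end
        if log(rand()) < logp(x_proposed) - logp(x)
            x = x_proposed
        end
    end
    return n_oob / n_steps
end

Random.seed!(468)
for σ in (0.5, 1.0, 3.0)
    println("σ = $σ: out of bounds ", oob_fraction(σ))
end
```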

This is where the transformation functions from Bijectors.jl come in: we can use them to map the distribution to an unconstrained one and sample from that instead. Since the sampler only ever sees an unconstrained distribution, it doesn’t have to worry about checking for bounds.

To make this happen, instead of passing \(\log(p(x))\) to the sampler, we pass \(\log(q(y))\). This can be obtained using the Bijectors.logpdf_with_trans function that was introduced above.

d = LogNormal()
f = B.bijector(d)     # Transformation function
f_inv = B.inverse(f)  # Inverse transformation function
function logq(y)
    x = f_inv(y)
    return B.logpdf_with_trans(d, x, true)
end
samples_transformed, n_oob_transformed = mh(logq, 10000, x -> true);
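Since the transformed log-normal distribution is a standard normal, `logq` should coincide with `logpdf(Normal(), y)`. In other words, the transformed sampler is simply running random-walk Metropolis on a standard normal. A quick check at a few arbitrary points, repeating the definitions above so that the snippet runs on its own:

```julia
using Distributions: LogNormal, Normal, logpdf
import Bijectors as B

d = LogNormal()
f_inv = B.inverse(B.bijector(d))

# Same definition as `logq` above.
logq(y) = B.logpdf_with_trans(d, f_inv(y), true)

for y in (-2.0, -0.3, 0.0, 1.5)
    println((y = y, logq = logq(y), normal = logpdf(Normal(), y)))
end
```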

Now, this process gives us samples that have been transformed, so we need to un-transform them to get the samples from the original distribution:

samples_untransformed = f_inv(samples_transformed)
histogram(samples_untransformed, bins=0:0.1:5; xlims=(0, 5))

We can check the mean of the samples too, to see that it is what we expect:

println("expected mean: $(exp(0 + (1^2/2)))")
println("  actual mean: $(mean(samples_untransformed))")
expected mean: 1.6487212707001282
  actual mean: 1.7184757306010636

On top of that, we can also verify that we don’t ever go out of bounds:

println("went out of bounds $n_oob_transformed/10000 times")
went out of bounds 0/10000 times

Which one is better?

In the subsections above, we’ve seen two different methods of sampling from a constrained distribution:

  1. Sample directly from the distribution and reject any samples outside of its support.
  2. Transform the distribution to an unconstrained one and sample from that instead.

(Note that both of these methods are applicable to other samplers as well, such as Hamiltonian Monte Carlo.)

Of course, a natural question to then ask is which one of these is better!

One option might be to look at the sample means above to see which one is ‘closer’ to the expected mean. However, that’s not a very robust method, because the sample mean is itself random, and with a different random seed we might well reach a different conclusion.

Another possibility is to look at how often proposals went out of bounds and were rejected. Does a lower rejection rate (as in the transformed case) imply that the method is better? This might seem intuitive, but it is not necessarily the case: for example, sampling in unconstrained space could be much less efficient, such that even though we’re not rejecting samples, the ones that we do get are highly correlated and thus not representative of the distribution.

A robust comparison would involve performing both methods many times and seeing how reliable the sample mean is.

function get_sample_mean(; transform)
    if transform
        # Sample from transformed distribution
        samples = f_inv(first(mh(logq, 10000, x -> true)))
    else
        # Sample from original distribution and reject if out of bounds
        samples = first(mh(logp, 10000, x -> x > 0))
    end
    return mean(samples)
end
get_sample_mean (generic function with 1 method)
means_with_rejection = [get_sample_mean(; transform=false) for _ in 1:1000]
mean(means_with_rejection), var(means_with_rejection)
(1.652032684314151, 0.30454613712270745)
means_with_transformation = [get_sample_mean(; transform=true) for _ in 1:1000]
mean(means_with_transformation), var(means_with_transformation)
(1.6489347143276902, 0.003945513418875533)

We can see from this small study that although both methods give the correct mean on average, the method with the transformation is far more reliable: the variance of its sample means across runs is roughly 77 times lower (about 0.0039 versus 0.30).

Note

Alternatively, we could also try to directly measure how correlated the samples are. One way to do this is to calculate the effective sample size (ESS), which is described in the Stan documentation, and implemented in MCMCChains.jl. A larger ESS implies that the samples are less correlated, and thus more representative of the underlying distribution:

using MCMCChains: Chains, ess

rejection = first(mh(logp, 10000, x -> x > 0))
transformation = f_inv(first(mh(logq, 10000, x -> true)))
chn = Chains(hcat(rejection, transformation), [:rejection, :transformation])
ess(chn)
ESS
      parameters         ess   ess_per_sec
          Symbol     Float64       Missing

       rejection    503.4349       missing
  transformation   1106.6909       missing
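To give a feel for what ESS measures, here is a deliberately crude sketch, not the estimator used by MCMCChains.jl or Stan: it estimates ESS as \(N / (1 + 2\sum_k \rho_k)\), where \(\rho_k\) is the lag-\(k\) autocorrelation, truncating the sum at the first non-positive \(\rho_k\). The helper names `autocor_lag` and `crude_ess` are hypothetical, and the test chain is an AR(1) process \(x_t = \varphi x_{t-1} + \varepsilon_t\), whose ESS is approximately \(N(1 - \varphi)/(1 + \varphi)\).

```julia
import Random
using Statistics: cor

# Hypothetical helper: lag-k sample autocorrelation of a chain.
autocor_lag(x, k) = cor(view(x, 1:length(x)-k), view(x, 1+k:length(x)))

# Hypothetical helper: crude ESS estimate N / (1 + 2 Σₖ ρₖ), stopping the
# sum at the first non-positive autocorrelation.
function crude_ess(x; max_lag=length(x) ÷ 2)
    s = 0.0
    for k in 1:max_lag
        ρ = autocor_lag(x, k)
        ρ <= 0 && break
        s += ρ
    end
    return length(x) / (1 + 2s)
end

# AR(1) test chain with lag-1 correlation φ.
Random.seed!(468)
φ, N = 0.5, 100_000
chain = zeros(N)
for t in 2:N
    chain[t] = φ * chain[t-1] + randn()
end

println("crude ESS: ", crude_ess(chain), ", theory ≈ ", N * (1 - φ) / (1 + φ))
```

More strongly autocorrelated chains (larger \(\rho_k\)) get a smaller ESS, which is the sense in which the rejection sampler above produced fewer effective samples.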

What happens without the Jacobian?

In the transformation method above, we used Bijectors.logpdf_with_trans to calculate the log probability density of the transformed distribution. This function includes the Jacobian term of the transformation, and that term is what ensures that, when we un-transform the samples, we get the correct distribution.

The next code block shows what happens if we don’t include the Jacobian term. In logq_wrong, we un-transform y to x and calculate the logpdf with respect to its original distribution, with no correction. This is exactly the same mistake that we made at the start of this article with naive_logpdf.

function logq_wrong(y)
    x = f_inv(y)
    return logpdf(d, x)  # no Jacobian term!
end
samples_questionable, _ = mh(logq_wrong, 100000, x -> true)
samples_questionable_untransformed = f_inv(samples_questionable)

println("mean: $(mean(samples_questionable_untransformed))")
mean: 0.5919166187308191

You can see that even though we used ten times as many samples, the mean is far from the expected value of \(\exp(1/2) \approx 1.65\), which implies that our samples are not being drawn from the correct distribution.
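The wrong mean can actually be predicted. Without the Jacobian term, logq_wrong is just the log-normal logpdf evaluated at \(x = e^{y}\). Writing it out with the same standard log-normal density as before shows which distribution the sampler is really targeting in \(y\)-space:

\[
\begin{aligned}
\log(\tilde{q}(y)) = \log(p(e^{y})) &= -y - \tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}y^2 \\
&= -\tfrac{1}{2}(y + 1)^2 + \text{const}.
\end{aligned}
\]

So the sampler draws \(y \sim \mathcal{N}(-1, 1)\), which makes \(x = e^{y}\) log-normal with \(\mu = -1\), \(\sigma = 1\), and mean \(\exp(-1 + 1/2) = e^{-1/2} \approx 0.607\), in line with the sample mean printed above. The missing term \(-\log(|\det(\mathbf{J})|) = +y\) is exactly what would cancel the extra \(-y\).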

In the next page, we’ll see how to use these transformations in the context of a probabilistic programming language, paying particular attention to their handling in DynamicPPL.
