Univariate ADVI example
But the real utility of `TransformedDistribution` becomes more apparent when using `transformed(dist, b)` for any bijector `b`. To get the transformed distribution corresponding to the `Beta(2, 2)`, we called `transformed(dist)` before. This is simply an alias for `transformed(dist, bijector(dist))`. Remember that `bijector(dist)` returns the constrained-to-unconstrained bijector for that particular `Distribution`. But we can of course construct a `TransformedDistribution` using a different bijector with the same `dist`. This is particularly useful in something called Automatic Differentiation Variational Inference (ADVI).[2] An important part of ADVI is to approximate a constrained distribution, e.g. `Beta`, as follows:
- Sample `x` from a `Normal` with parameters `μ` and `σ`, i.e. `x ~ Normal(μ, σ)`.
- Transform `x` to `y` such that `y ∈ support(Beta)`, with the transform being a differentiable bijection with a differentiable inverse (a "bijector"); see the sketch below.
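Concretely, the two steps might look as follows (a minimal sketch using only `bijector` and `inverse` from Bijectors.jl; the `Beta(2, 2)` target is just an example):

```julia
using Bijectors, Distributions

x = rand(Normal(0.0, 1.0))           # step 1: sample on the unconstrained space ℝ
b⁻¹ = inverse(bijector(Beta(2, 2)))  # unconstrained-to-constrained map ℝ → (0, 1)
y = b⁻¹(x)                           # step 2: y ∈ support(Beta)
```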
This then defines a probability density with the same support as `Beta`! Of course, it's unlikely that it will be the same density, but it's an approximation. Creating such a distribution becomes trivial with `Bijector` and `TransformedDistribution`:
```julia
julia> using StableRNGs: StableRNG

julia> rng = StableRNG(42);

julia> dist = Beta(2, 2)
Beta{Float64}(α=2.0, β=2.0)

julia> b = bijector(dist)  # (0, 1) → ℝ
Bijectors.Logit{Float64, Float64}(0.0, 1.0)

julia> b⁻¹ = inverse(b)  # ℝ → (0, 1)
Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0))

julia> td = transformed(Normal(), b⁻¹)  # x ∼ 𝓝(0, 1) then b⁻¹(x) ∈ (0, 1)
UnivariateTransformed{Normal{Float64}, Inverse{Bijectors.Logit{Float64, Float64}}}(
dist: Normal{Float64}(μ=0.0, σ=1.0)
transform: Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0))
)

julia> x = rand(rng, td)  # ∈ (0, 1)
0.3384404850130036
```
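Since `td` is just a `Distribution`, we can also evaluate its density, which is the change-of-variables formula in action. As a quick sanity check (a sketch, not a recorded session), the following identity should hold up to floating-point error, using `logabsdetjac`:

```julia
julia> logpdf(td, x) ≈ logpdf(Normal(), b(x)) + logabsdetjac(b, x)
true
```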
It's worth noting that `support(Beta)` is the closed interval [0, 1], while the constrained-to-unconstrained bijection, `Logit` in this case, is only well-defined on the open interval (0, 1), as a map (0, 1) → ℝ. This is not merely an implementation detail: ℝ is itself open, so no continuous bijection exists from a closed interval to ℝ. But since the boundary of a closed interval has measure zero, this doesn't end up affecting the resulting density with support on the entire real line. In practice, this means that
```julia
julia> td = transformed(Beta())
UnivariateTransformed{Beta{Float64}, Bijectors.Logit{Float64, Float64}}(
dist: Beta{Float64}(α=1.0, β=1.0)
transform: Bijectors.Logit{Float64, Float64}(0.0, 1.0)
)

julia> inverse(td.transform)(rand(rng, td))
0.8130302707446476
```
will never result in `0` or `1`, though any sample arbitrarily close to either `0` or `1` is possible. Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky.
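For example, pushing a moderately large unconstrained value through the inverse logit already saturates in `Float64` (a small illustration; the exact cutoff depends on the floating-point type):

```julia
julia> inverse(bijector(Beta()))(40.0)  # 1 / (1 + exp(-40)) rounds to 1.0
1.0
```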
Multivariate ADVI example
We can also do multivariate ADVI using the `Stacked` bijector. `Stacked` gives us a way to combine univariate and/or multivariate bijectors into a single multivariate bijector. Say you have a vector `x` of length 2 and you want to transform the first entry using `Exp` and the second entry using `Log`. `Stacked` gives you an easy and efficient way of representing such a bijector.
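For instance, that particular bijector could be constructed as follows (a minimal sketch; we lift the scalar functions with `elementwise`, which is how such transforms appear in the `Stacked` output further down):

```julia
julia> using Bijectors: elementwise

julia> sb = Stacked((elementwise(exp), elementwise(log)), (1:1, 2:2));

julia> sb([0.0, 1.0])  # exp applied to the first entry, log to the second
2-element Vector{Float64}:
 1.0
 0.0
```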
```julia
julia> using Bijectors: SimplexBijector

julia> # Original distributions
       dists = (Beta(), InverseGamma(), Dirichlet(2, 3));

julia> # Construct the corresponding ranges
       ranges = [];

julia> idx = 1;

julia> for i in 1:length(dists)
           d = dists[i]
           push!(ranges, idx:(idx + length(d) - 1))
           global idx
           idx += length(d)
       end;

julia> ranges
3-element Vector{Any}:
 1:1
 2:2
 3:4

julia> num_params = ranges[end][end]
4

julia> # Base distribution; mean-field normal
       d = MvNormal(zeros(num_params), ones(num_params));

julia> # Construct the transform
       bs = bijector.(dists);  # constrained-to-unconstrained bijectors for dists

julia> ibs = inverse.(bs);  # invert, so we get unconstrained-to-constrained

julia> sb = Stacked(ibs, ranges)  # => Stacked <: Bijector
Stacked(Any[Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0)), Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp), Inverse{Bijectors.SimplexBijector}(Bijectors.SimplexBijector())], Any[1:1, 2:2, 3:4], Any[1:1, 2:2, 3:5])

julia> # Mean-field normal with unconstrained-to-constrained stacked bijector
       td = transformed(d, sb);

julia> y = rand(td)
5-element Vector{Float64}:
 0.3215382466472137
 0.9473799956569682
 0.48861355889563257
 0.1680913323837635
 0.343295108720604

julia> 0.0 ≤ y[1] ≤ 1.0  # Beta
true

julia> 0.0 < y[2]  # InverseGamma
true

julia> sum(y[3:5]) ≈ 1.0  # Dirichlet: note the output range 3:5 reported by `Stacked` above
true
```
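These constraints hold for every draw, not just this one; a quick sanity check over many samples (a sketch, not a recorded session):

```julia
julia> ys = rand(td, 1_000);  # one sample per column

julia> all(0 .≤ ys[1, :] .≤ 1) && all(ys[2, :] .> 0) && all(sum(ys[3:5, :]; dims=1) .≈ 1)
true
```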
Normalizing flows
A very interesting application is that of normalizing flows.[1] Usually this is done by sampling from a multivariate normal distribution and then transforming it to a target distribution using invertible neural networks. Currently, two such transforms are available in Bijectors.jl: `PlanarLayer` and `RadialLayer`. Let's create a flow with a single `PlanarLayer`:
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> b = PlanarLayer(2)
PlanarLayer(w = [-0.4151791683548836, -0.33568469654156496], u = [-1.1340279772763475, -0.11545936462823485], b = [-1.9025651161453447])

julia> flow = transformed(d, b)
MultivariateTransformed{DiagNormal, PlanarLayer{Vector{Float64}, Vector{Float64}}}(
dist: DiagNormal(
dim: 2
μ: [0.0, 0.0]
Σ: [1.0 0.0; 0.0 1.0]
)
transform: PlanarLayer(w = [-0.4151791683548836, -0.33568469654156496], u = [-1.1340279772763475, -0.11545936462823485], b = [-1.9025651161453447])
)

julia> flow isa MultivariateDistribution
true
```
That's it. Now we can sample from it using `rand` and compute the `logpdf`, like any other `Distribution`.
```julia
julia> y = rand(rng, flow)
2-element Vector{Float64}:
 -0.3277538604681132
 -0.03258437137299325

julia> logpdf(flow, y)  # uses inverse of `b`
-2.1602841038553984
```
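That `logpdf` call is the change-of-variables formula at work. As a sketch, the same density can also be computed in the forward direction using `with_logabsdet_jacobian` (the ChangesOfVariables.jl interface that Bijectors.jl implements), which avoids inverting the layer:

```julia
julia> xbase = rand(rng, d);  # sample in the base space

julia> ynew, logjac = with_logabsdet_jacobian(b, xbase);  # forward pass through the layer

julia> logpdf(d, xbase) - logjac ≈ logpdf(flow, ynew)  # change of variables, up to numerics
true
```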
Similarly to the multivariate ADVI example, we can use `Stacked` to get a bounded flow:
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta())));

julia> sb = Stacked(ibs)  # == Stacked(ibs, [i:i for i = 1:length(ibs)])
Stacked((Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp), Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0))), (1:1, 2:2), (1:1, 2:2))

julia> b = sb ∘ PlanarLayer(2)
Stacked((Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp), Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0))), (1:1, 2:2), (1:1, 2:2)) ∘ PlanarLayer(w = [1.1082421550008306, -0.4686743031035719], u = [0.508401208171964, 0.07196473212856534], b = [-0.2512548097012267])

julia> td = transformed(d, b);

julia> y = rand(rng, td)
2-element Vector{Float64}:
 4.1776623545381115
 0.8101927687536546

julia> 0 < y[1]  # InverseGamma part
true

julia> 0 ≤ y[2] ≤ 1  # Beta part
true
```
Want to fit the flow?
```julia
julia> using Zygote

julia> # Construct the flow.
       b = PlanarLayer(2)
PlanarLayer(w = [-0.4654746854019659, -0.13432731094945471], u = [0.6174684944807828, 2.084730608200266], b = [1.3256168425690136])

julia> # Convenient for extracting parameters and reconstructing the flow.
       using Functors

julia> θs, reconstruct = Functors.functor(b);

julia> # Make the objective a `struct` to avoid capturing global variables.
       struct NLLObjective{R,D,T}
           reconstruct::R
           basedist::D
           data::T
       end

julia> function (obj::NLLObjective)(θs)
           transformed_dist = transformed(obj.basedist, obj.reconstruct(θs))
           return -sum(Base.Fix1(logpdf, transformed_dist), eachcol(obj.data))
       end

julia> # Some random data to estimate the density of.
       xs = randn(2, 1000);

julia> # Construct the objective.
       f = NLLObjective(reconstruct, MvNormal(2, 1), xs);

julia> # Initial loss.
       @info "Initial loss: $(f(θs))"
[ Info: Initial loss: 4704.6462559375805

julia> # Train using gradient descent.
       ε = 1e-3;

julia> for i in 1:100
           (∇s,) = Zygote.gradient(f, θs)
           θs = fmap(θs, ∇s) do θ, ∇
               θ - ε .* ∇
           end
       end

julia> # Final loss.
       @info "Final loss: $(f(θs))"
[ Info: Final loss: 2840.21556214402

julia> # Very simple check to see if we learned something useful.
       samples = rand(transformed(f.basedist, f.reconstruct(θs)), 1000);

julia> mean(eachcol(samples))  # ≈ [0, 0]
2-element Vector{Float64}:
 -0.02048464439206124
 -0.049043655869899616

julia> cov(samples; dims=2)  # ≈ I
2×2 Matrix{Float64}:
 1.00564    0.0514742
 0.0514742  1.07841
```
We can easily create more complex flows by simply doing `PlanarLayer(10) ∘ PlanarLayer(10) ∘ RadialLayer(10)` and so on.
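For instance, a deeper flow is just a composition of layers (a minimal sketch; the layer sizes and count are arbitrary):

```julia
julia> deep = PlanarLayer(2) ∘ PlanarLayer(2) ∘ RadialLayer(2);

julia> deep_flow = transformed(MvNormal(zeros(2), ones(2)), deep);

julia> rand(deep_flow) isa Vector{Float64}  # sampling works exactly as before
true
```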