## Univariate ADVI example
But the real utility of `TransformedDistribution` becomes more apparent when using `transformed(dist, b)` for any bijector `b`. To get the transformed distribution corresponding to `Beta(2, 2)`, we called `transformed(dist)` before. This is simply an alias for `transformed(dist, bijector(dist))`. Remember that `bijector(dist)` returns the constrained-to-unconstrained bijector for that particular `Distribution`. But we can of course construct a `TransformedDistribution` using a different bijector with the same `dist`.
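To make this concrete, here is a minimal sketch (the variable names are ours, purely for illustration):

```julia
dist = Beta(2, 2)

# transformed(dist) is shorthand for supplying the default bijector explicitly,
# so these two construct the same transformed distribution:
td1 = transformed(dist)
td2 = transformed(dist, bijector(dist))
```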
Constructing a `TransformedDistribution` with a different bijector is particularly useful in something called Automatic Differentiation Variational Inference (ADVI).[2] An important part of ADVI is to approximate a constrained distribution, e.g. `Beta`, as follows:
- Sample `x` from a `Normal` with parameters `μ` and `σ`, i.e. `x ~ Normal(μ, σ)`.
- Transform `x` to `y` such that `y ∈ support(Beta)`, with the transform being a differentiable bijection with a differentiable inverse (a "bijector").
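For reference, the density this procedure induces is the standard change-of-variables formula: if `x` has density `p` and `y = f(x)` for a bijector `f`, then `q(y) = p(f⁻¹(y)) · |det J(y)|`, where `J(y)` is the Jacobian of `f⁻¹` at `y`.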
This then defines a probability density with the same support as `Beta`! Of course, it's unlikely that it will be the same density, but it's an approximation. Creating such a distribution becomes trivial with `Bijector` and `TransformedDistribution`:
```julia
julia> using StableRNGs: StableRNG

julia> rng = StableRNG(42);

julia> dist = Beta(2, 2)
Beta{Float64}(α=2.0, β=2.0)

julia> b = bijector(dist) # (0, 1) → ℝ
Bijectors.Logit{Float64, Float64}(0.0, 1.0)

julia> b⁻¹ = inverse(b) # ℝ → (0, 1)
Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0))

julia> td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1), then b⁻¹(x) ∈ (0, 1)
UnivariateTransformed{Normal{Float64}, Inverse{Bijectors.Logit{Float64, Float64}}}(
dist: Normal{Float64}(μ=0.0, σ=1.0)
transform: Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0))
)

julia> x = rand(rng, td) # ∈ (0, 1)
0.3384404850130036
```
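Since `td` behaves like any other `Distribution`, we can also evaluate its log-density. As a quick sanity check (a sketch on our part, not doctest output), it should agree with the change-of-variables identity expressed through Bijectors' `logabsdetjac`:

```julia
# The log-density of td at x ∈ (0, 1) is the base Normal log-density at b(x)
# plus the log-Jacobian term of the constrained-to-unconstrained map b.
logpdf(td, x) ≈ logpdf(Normal(), b(x)) + logabsdetjac(b, x) # should hold up to rounding
```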
It's worth noting that `support(Beta)` is the closed interval `[0, 1]`, while the constrained-to-unconstrained bijection, `Logit` in this case, is only well-defined as a map `(0, 1) → ℝ` for the open interval `(0, 1)`. This is of course not an implementation detail. `ℝ` is itself open, thus no continuous bijection exists from a closed interval to `ℝ`. But since the boundaries of a closed interval have what's known as measure zero, this doesn't end up affecting the resulting density with support on the entire real line. In practice, this means that
```julia
julia> td = transformed(Beta())
UnivariateTransformed{Beta{Float64}, Bijectors.Logit{Float64, Float64}}(
dist: Beta{Float64}(α=1.0, β=1.0)
transform: Bijectors.Logit{Float64, Float64}(0.0, 1.0)
)

julia> inverse(td.transform)(rand(rng, td))
0.8130302707446476
```
will never result in `0` or `1`, though any sample arbitrarily close to either `0` or `1` is possible. Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky.
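To see this boundary behaviour concretely, note that the `Logit` bijector `b` defined earlier (i.e. `x ↦ log(x / (1 - x))`) sends the closed endpoints to infinities:

```julia
b(0.0) # -Inf: the endpoint 0 has no finite image in ℝ
b(1.0) #  Inf: likewise for the endpoint 1
```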
## Multivariate ADVI example
We can also do multivariate ADVI using the `Stacked` bijector. `Stacked` gives us a way to combine univariate and/or multivariate bijectors into a single multivariate bijector. Say you have a vector `x` of length 2 and you want to transform the first entry using `Exp` and the second entry using `Log`. `Stacked` gives you an easy and efficient way of representing such a bijector.
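For instance, a minimal sketch of that two-entry case (assuming `Bijectors.elementwise`, which wraps a plain function for elementwise application, as seen in the `Stacked` printouts below):

```julia
# Apply exp to the first entry and log to the second.
b = Stacked((Bijectors.elementwise(exp), Bijectors.elementwise(log)), (1:1, 2:2))

x = [0.0, 2.0]
y = b(x) # y[1] = exp(0.0) = 1.0, y[2] = log(2.0) ≈ 0.693
```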
```julia
julia> using Bijectors: SimplexBijector

# Original distributions
julia> dists = (Beta(), InverseGamma(), Dirichlet(2, 3));

# Construct the corresponding ranges
julia> ranges = [];

julia> idx = 1;

julia> for i in 1:length(dists)
           d = dists[i]
           push!(ranges, idx:(idx + length(d) - 1))
           global idx
           idx += length(d)
       end;

julia> ranges
3-element Vector{Any}:
 1:1
 2:2
 3:4

julia> num_params = ranges[end][end]
4

# Base distribution; mean-field normal
julia> d = MvNormal(zeros(num_params), ones(num_params));

# Construct the transform
julia> bs = bijector.(dists); # constrained-to-unconstrained bijectors for dists

julia> ibs = inverse.(bs);    # invert, so we get unconstrained-to-constrained

julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector
Stacked(Any[Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0)), Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp), Inverse{Bijectors.SimplexBijector}(Bijectors.SimplexBijector())], Any[1:1, 2:2, 3:4], Any[1:1, 2:2, 3:5])

# Mean-field normal with unconstrained-to-constrained stacked bijector
julia> td = transformed(d, sb);

julia> y = rand(td)
5-element Vector{Float64}:
 0.17841505686495443
 1.5336323424298905
 0.7156704776573194
 0.04085757365673251
 0.24347194868594813

julia> 0.0 ≤ y[1] ≤ 1.0  # Beta
true

julia> 0.0 < y[2]        # InverseGamma
true

julia> sum(y[3:5]) ≈ 1.0 # Dirichlet (its constrained range is 3:5, one entry longer than the unconstrained 3:4)
true
```
## Normalizing flows
A very interesting application is that of normalizing flows.[1] Usually this is done by sampling from a multivariate normal distribution, and then transforming this to a target distribution using invertible neural networks. Currently there are two such transforms available in Bijectors.jl: `PlanarLayer` and `RadialLayer`. Let's create a flow with a single `PlanarLayer`:
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> b = PlanarLayer(2)
PlanarLayer(w = [-1.0408637312049929, 2.235251162004739], u = [-0.24164166198564066, 1.652917908177465], b = [0.12060496907886155])

julia> flow = transformed(d, b)
MultivariateTransformed{DiagNormal, PlanarLayer{Vector{Float64}, Vector{Float64}}}(
dist: DiagNormal(
dim: 2
μ: [0.0, 0.0]
Σ: [1.0 0.0; 0.0 1.0]
)
transform: PlanarLayer(w = [-1.0408637312049929, 2.235251162004739], u = [-0.24164166198564066, 1.652917908177465], b = [0.12060496907886155])
)

julia> flow isa MultivariateDistribution
true
```
That's it. Now we can sample from it using `rand` and compute the `logpdf`, like any other `Distribution`.
```julia
julia> y = rand(rng, flow)
2-element Vector{Float64}:
 -0.7401802806074489
  1.6730062227158746

julia> logpdf(flow, y) # uses inverse of `b`
-2.4224845733786786
```
Similarly to the multivariate ADVI example, we could use `Stacked` to get a bounded flow:
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta())));

julia> sb = Stacked(ibs) # == Stacked(ibs, [i:i for i = 1:length(ibs)])
Stacked((Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp), Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0))), (1:1, 2:2), (1:1, 2:2))

julia> b = sb ∘ PlanarLayer(2)
Stacked((Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp), Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0))), (1:1, 2:2), (1:1, 2:2)) ∘ PlanarLayer(w = [0.2220050576914493, -0.1436477237262901], u = [-1.4152050850988749, 0.17109500087483392], b = [-0.714064037930183])

julia> td = transformed(d, b);

julia> y = rand(rng, td)
2-element Vector{Float64}:
 10.38293034211274
  0.7469793180100731

julia> 0 < y[1]
true

julia> 0 ≤ y[2] ≤ 1
true
```
Want to fit the flow?
```julia
julia> using Zygote

# Construct the flow.
julia> b = PlanarLayer(2)
PlanarLayer(w = [1.7754907696739541, -0.5938276464383798], u = [1.3385218771986869, 2.185666507949883], b = [0.23493273069341675])

# Convenient for extracting parameters and reconstructing the flow.
julia> using Functors

julia> θs, reconstruct = Functors.functor(b);

# Make the objective a `struct` to avoid capturing global variables.
julia> struct NLLObjective{R,D,T}
           reconstruct::R
           basedist::D
           data::T
       end

julia> function (obj::NLLObjective)(θs)
           transformed_dist = transformed(obj.basedist, obj.reconstruct(θs))
           return -sum(Base.Fix1(logpdf, transformed_dist), eachcol(obj.data))
       end

# Some random data to estimate the density of.
julia> xs = randn(2, 1000);

# Construct the objective.
julia> f = NLLObjective(reconstruct, MvNormal(2, 1), xs);

# Initial loss.
julia> @info "Initial loss: $(f(θs))"
[ Info: Initial loss: 4518.38908626241

# Train using gradient descent.
julia> ε = 1e-3;

julia> for i in 1:100
           (∇s,) = Zygote.gradient(f, θs)
           θs = fmap(θs, ∇s) do θ, ∇
               θ - ε .* ∇
           end
       end

# Final loss.
julia> @info "Final loss: $(f(θs))"
[ Info: Final loss: 2862.6025267572877

# Very simple check to see if we learned something useful.
julia> samples = rand(transformed(f.basedist, f.reconstruct(θs)), 1000);

julia> mean(eachcol(samples)) # ≈ [0, 0]
2-element Vector{Float64}:
 -0.08048446136305033
  0.03643599036394901

julia> cov(samples; dims=2) # ≈ I
2×2 Matrix{Float64}:
 0.991613   0.0192021
 0.0192021  1.02701
```
We can easily create more complex flows by simply doing `PlanarLayer(10) ∘ PlanarLayer(10) ∘ RadialLayer(10)` and so on.
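For instance (a sketch; the layer parameters are randomly initialized, as before):

```julia
d = MvNormal(zeros(10), ones(10))
b = PlanarLayer(10) ∘ PlanarLayer(10) ∘ RadialLayer(10)
flow = transformed(d, b) # a deeper 10-dimensional flow
x = rand(flow)           # sampling pushes a draw from d through all three layers
```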