API
NormalizingFlows.elbo
NormalizingFlows.loglikelihood
NormalizingFlows.optimize
NormalizingFlows.train_flow
Main Function
NormalizingFlows.train_flow (function)
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
Train the given normalizing flow flow by calling optimize.
Arguments
rng::AbstractRNG: random number generator
vo: variational objective
flow: normalizing flow to be trained; we recommend defining flow as a <:Bijectors.TransformedDistribution
args...: additional arguments for vo
Keyword Arguments
max_iters::Int=1000: maximum number of iterations
optimiser::Optimisers.AbstractRule=Optimisers.ADAM(): optimiser used to compute the steps
ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote(): automatic differentiation backend. Currently supported backends are ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoMooncake(), and ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const). If you use AutoEnzyme, make sure to include set_runtime_activity and function_annotation as shown here.
kwargs...: additional keyword arguments for optimize (see optimize for details)
Returns
flow_trained: trained normalizing flow
opt_stats: statistics of the optimiser during the training process (see optimize for details)
st: optimiser state for potential continuation of training
The flow object can be constructed with the transformed function from the Bijectors.jl package. For example, for Gaussian VI, we can construct the flow as follows:
using Distributions, Bijectors, Functors # Functors provides @leaf
T = Float32
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)))
To train the Gaussian VI flow targeting a distribution $p$ via ELBO maximization, we can run
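using NormalizingFlows, Optimisers
# logp must be defined beforehand: the (possibly unnormalized) log-density of the target p
sample_per_iter = 10
flow_trained, stats, _ = train_flow(
    elbo,
    flow,
    logp,
    sample_per_iter;
    max_iters=2_000,
    optimiser=Optimisers.ADAM(0.01 * one(T)),
)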
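The call above assumes that logp is already defined. A minimal sketch of one possible target follows, assuming a two-dimensional Gaussian chosen purely for illustration and assuming that the objective evaluates logp on one sample (a vector) at a time; the names target and logp are hypothetical, not part of the package:

```julia
using Distributions, LinearAlgebra

T = Float32
# Hypothetical target: a 2D Gaussian with diagonal covariance (illustration only).
target = MvNormal(T[1, -1], Diagonal(T[0.5, 2.0]))
# Log-density of the target at a single point z (a length-2 vector).
logp(z) = logpdf(target, z)
```

In a Bayesian setting, logp would typically return the unnormalized log-posterior instead.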
Variational Objectives
We have implemented two variational objectives: the ELBO and the log-likelihood objective. Users can also define their own objective function and pass it to train_flow, which optimizes the flow parameters by maximizing vo. The objective function should take the following general form:
vo(rng, flow, args...)
where rng is the random number generator, flow is the flow object, and args... are additional arguments that users can pass to the objective function.
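As an illustration of this form, here is a minimal sketch of a hand-written objective. The name my_elbo, the choice args... = (logp, n_samples), and the toy target are assumptions made for the example; whether the function is differentiable depends on the AD backend in use.

```julia
using Random, Statistics, LinearAlgebra, Distributions, Bijectors

# Hypothetical custom objective with the vo(rng, flow, args...) form.
# It returns a Monte Carlo estimate of the ELBO from n_samples reference draws.
function my_elbo(rng::AbstractRNG, flow, logp, n_samples::Int)
    zs = rand(rng, flow.dist, n_samples)  # d × n_samples draws from the reference q₀
    return mean(eachcol(zs)) do z0
        # Push one draw through the flow and collect the log-Jacobian term.
        z, logjac = Bijectors.with_logabsdet_jacobian(flow.transform, z0)
        logp(z) - logpdf(flow.dist, z0) + logjac
    end
end

# Toy check: identity-like flow on N(0, I), target N(1, I) (illustration only).
q₀ = MvNormal(zeros(2), I)
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(2)) ∘ Bijectors.Scale(ones(2)))
logp(z) = logpdf(MvNormal(ones(2), I), z)
estimate = my_elbo(Random.default_rng(), flow, logp, 100)
```

Such a function could then be passed in place of elbo, for example train_flow(my_elbo, flow, logp, 10; max_iters=2_000).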
Evidence Lower Bound (ELBO)
Maximizing the ELBO is equivalent to minimizing the reverse KL divergence between $q_\theta$ and $p$, i.e.,
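\[\begin{aligned} &\arg\min_{\theta} \mathbb{E}_{q_{\theta}}\left[\log q_{\theta}(Z)-\log p(Z)\right] \quad \text{(Reverse KL)}\\ & = \arg\max_{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(Z_0)+\sum_{n=1}^N \log J_n\left(T_{n-1} \circ \cdots \circ T_1(Z_0)\right)\right] \quad \text{(ELBO)} \end{aligned}\]
Here $Z_0 \sim q_0$, $q_\theta$ is the distribution of $T_N \circ \cdots \circ T_1(Z_0)$, and $J_n = |\det \nabla T_n|$ is the absolute Jacobian determinant of $T_n$, evaluated at the input of the $n$-th layer (for $n = 1$ this input is $Z_0$ itself).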
Reverse KL minimization is typically used for Bayesian computation, where one only has access to the unnormalized log-density of the target distribution $p$ (e.g., a Bayesian posterior distribution) and hopes to generate approximate samples from it.
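The reason an unnormalized log-density suffices can be spelled out in one step. As a sketch, write $p(z) = \tilde p(z)/C$, where $\tilde p$ is the unnormalized density and $C$ is the unknown normalizing constant (both symbols are introduced here for illustration only):

```latex
\mathbb{E}_{q_\theta}\left[\log q_\theta(Z) - \log p(Z)\right]
  = \mathbb{E}_{q_\theta}\left[\log q_\theta(Z) - \log \tilde p(Z)\right] + \log C
```

Since $\log C$ does not depend on $\theta$, it shifts the objective without changing the optimal parameters, so supplying $\log \tilde p$ as logp leads to the same optimum.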
NormalizingFlows.elbo (function)
elbo(flow, logp, xs)
elbo([rng, ]flow, logp, n_samples)
Compute the ELBO for a batch of samples xs from the reference distribution flow.dist.
Arguments
rng: random number generator
flow: variational distribution to be trained. In particular, flow = transformed(q₀, T::Bijectors.Bijector), where q₀ is a reference distribution that is easy to sample from and whose logpdf is easy to compute
logp: log-pdf of the target distribution (not necessarily normalized)
xs: samples from the reference distribution q₀
n_samples: number of samples to draw from the reference distribution q₀
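A minimal sketch of both call forms follows. The toy target and the one-sample-per-column batch layout are assumptions made for the example.

```julia
using Random, LinearAlgebra, Distributions, Bijectors, NormalizingFlows

q₀ = MvNormal(zeros(2), I)
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(2)) ∘ Bijectors.Scale(ones(2)))
logp(z) = logpdf(MvNormal(ones(2), I), z)  # hypothetical target, illustration only

rng = Random.default_rng()
xs = rand(rng, flow.dist, 100)          # 2 × 100 matrix, assumed one sample per column
est_fixed = elbo(flow, logp, xs)        # ELBO on a batch supplied by the caller
est_fresh = elbo(rng, flow, logp, 100)  # ELBO on 100 freshly drawn reference samples
```

The first form is convenient when the same reference samples should be reused across evaluations; the second draws new samples on every call.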
Log-likelihood
Maximizing the log-likelihood is equivalent to minimizing the forward KL divergence between $q_\theta$ and $p$, i.e.,
\[\begin{aligned} & \arg\min_{\theta} \mathbb{E}_{p}\left[\log p(Z)-\log q_{\theta}(Z)\right] \quad \text{(Forward KL)} \\ & = \arg\max_{\theta} \mathbb{E}_{p}\left[\log q_{\theta}(Z)\right] \quad \text{(Expected log-likelihood)} \end{aligned}\]
Forward KL minimization is typically used for generative modeling, where one is given a set of samples from the target distribution $p$ (e.g., images) and aims to learn the density or a generative process that outputs high-quality samples.
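With a finite dataset, the expectation under $p$ is replaced by a sample average. As a sketch, assuming $M$ i.i.d. samples $x_1, \dots, x_M$ from $p$ (notation introduced here for illustration):

```latex
\mathbb{E}_{p}\left[\log q_{\theta}(Z)\right]
  \approx \frac{1}{M} \sum_{m=1}^{M} \log q_{\theta}(x_m)
```

The term $\mathbb{E}_{p}[\log p(Z)]$ of the forward KL divergence is dropped because it does not depend on $\theta$.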
NormalizingFlows.loglikelihood (function)
loglikelihood(rng, flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)
Compute the log-likelihood for variational distribution flow at a batch of samples xs from the target distribution p.
Arguments
rng: random number generator (unused; only needed to keep the same signature as other variational objectives)
flow: variational distribution to be trained. In particular, flow = transformed(q₀, T::Bijectors.Bijector), where q₀ is a reference distribution that is easy to sample from and whose logpdf is easy to compute
xs: samples from the target distribution p
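A minimal usage sketch follows, assuming synthetic data stored one sample per column. The call is written with the module prefix because Distributions also exports a function named loglikelihood.

```julia
using Random, LinearAlgebra, Distributions, Bijectors, NormalizingFlows

rng = Random.default_rng()
xs = randn(rng, 2, 500) .+ 1.0  # stand-in for observed samples from p (illustration only)

q₀ = MvNormal(zeros(2), I)
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(2)) ∘ Bijectors.Scale(ones(2)))

# Log-likelihood objective of the current flow on the data batch.
ll = NormalizingFlows.loglikelihood(rng, flow, xs)
```

Following the train_flow signature, the same objective could presumably be used for training as train_flow(NormalizingFlows.loglikelihood, flow, xs; max_iters=1_000).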
Training Loop
NormalizingFlows.optimize (function)
optimize(
    ad::ADTypes.AbstractADType,
    loss,
    θ₀::AbstractVector{T},
    re,
    args...;
    kwargs...
)
Iteratively update the parameters θ of the normalizing flow re(θ) by calling grad! and using the given optimiser to compute the steps.
Arguments
ad::ADTypes.AbstractADType: automatic differentiation backend
loss: a general loss function θ -> loss(θ, args...) that returns a scalar loss value to be minimised
θ₀::AbstractVector{T}: initial parameters for the loss function (in the context of normalizing flows, the flattened flow parameters)
re: reconstruction function that maps the flattened parameters to the normalizing flow
args...: additional arguments for loss (passed as DI.Constant, so they are not differentiated)
Keyword Arguments
max_iters::Int=10000: maximum number of iterations
optimiser::Optimisers.AbstractRule=Optimisers.ADAM(): optimiser used to compute the steps
show_progress::Bool=true: whether to show the progress bar. By default, the progress bar displays the iteration number, the loss value, and the gradient norm.
callback=nothing: callback function with signature cb(iter, opt_state, re, θ) that returns a dictionary-like object of statistics to be displayed in the progress bar. re and θ allow the callback to reconstruct the normalizing flow in case the user wants to examine its status further.
hasconverged=(iter, opt_stats, re, θ, st) -> false: function that checks whether training has converged. The default always returns false.
prog=ProgressMeter.Progress(max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress): progress bar configuration
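A minimal sketch of the two hooks follows. Two assumptions are made: that a NamedTuple is accepted as the "dictionary-like object", and that opt_stats exposes a gradient_norm entry. Both function names are hypothetical.

```julia
using LinearAlgebra

# Hypothetical callback: report the norm of the flattened parameters.
# Assumption: a NamedTuple is an acceptable dictionary-like return value.
cb(iter, opt_state, re, θ) = (; param_norm = norm(θ))

# Hypothetical stopping rule: stop once the reported gradient norm is small.
# Assumption: opt_stats has a gradient_norm entry; if it does not, `get`
# falls back to Inf and the rule never triggers.
converged(iter, opt_stats, re, θ, st) = get(opt_stats, :gradient_norm, Inf) < 1e-3
```

Because train_flow forwards its kwargs... to optimize, both could be supplied as train_flow(elbo, flow, logp, 10; callback=cb, hasconverged=converged).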
Returns
θ: trained parameters of the normalizing flow
opt_stats: statistics of the optimiser
st: optimiser state for potential continuation of training
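The overall shape of such a training loop can be sketched as follows. This is not the package's implementation: sketch_optimize is a hypothetical name, ForwardDiff is used directly in place of an ADTypes backend, and the progress bar, callback, convergence check, and statistics are all omitted.

```julia
using Optimisers, ForwardDiff

# Simplified loop: differentiate the loss in θ, then let the optimiser take a step.
function sketch_optimize(loss, θ₀, args...; max_iters=200, optimiser=Optimisers.Descent(0.1))
    θ = copy(θ₀)
    st = Optimisers.setup(optimiser, θ)       # initial optimiser state
    for _ in 1:max_iters
        # Gradient with respect to θ only; args... are held constant.
        g = ForwardDiff.gradient(t -> loss(t, args...), θ)
        st, θ = Optimisers.update!(st, θ, g)  # one optimiser step
    end
    return θ, st
end

# Toy problem: minimise the squared distance to a fixed target vector.
quadratic(θ, target) = sum(abs2, θ .- target)
θ, st = sketch_optimize(quadratic, zeros(2), [1.0, -2.0])
```

Returning the optimiser state st alongside θ mirrors the return values listed above, so that a later run can resume from the same state.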