General usage
train_flow is the main function for training a normalizing flow. Users mostly need to specify a normalizing flow (flow), the variational objective (vo), and its corresponding arguments (args...).
NormalizingFlows.train_flow — Function
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
Train the given normalizing flow, flow, by calling optimize.
Arguments
rng::AbstractRNG: random number generator (default: Random.default_rng())
vo: variational objective with signature vo(rng, flow, args...). We implement elbo, elbo_batch, and loglikelihood.
flow: the normalizing flow, a Bijectors.TransformedDistribution (recommended)
args...: additional arguments passed to vo
Keyword Arguments
max_iters::Int=1000: maximum number of iterations
optimiser::Optimisers.AbstractRule=Optimisers.ADAM(): optimiser used to compute the steps
ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote(): automatic differentiation backend. Currently supported: ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoMooncake(), and ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const). If you use AutoEnzyme, make sure to include set_runtime_activity and function_annotation as shown above.
kwargs...: additional keyword arguments for optimize (see optimize for details)
Returns
flow_trained: trained normalizing flow
opt_stats: statistics of the optimiser during the training process (see optimize for details)
st: optimiser state for potential continuation of training
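To make the vo(rng, flow, args...) signature concrete, here is a minimal sketch of a hand-written Monte Carlo ELBO estimator. The name my_elbo is hypothetical and this is not the package's own elbo implementation; it assumes that train_flow maximizes the objective and that logp accepts a single sample vector.

```julia
using Random, Statistics
using Distributions, Bijectors

# Hypothetical objective following the vo(rng, flow, args...) signature.
# Here args... = (logp, n_samples).
function my_elbo(rng::AbstractRNG, flow, logp, n_samples::Int)
    # Draw samples from the flow; for a multivariate flow this is a d × n_samples matrix.
    xs = rand(rng, flow, n_samples)
    # Average log p(x) - log q(x) over the samples (one sample per column).
    return mean(logp(x) - logpdf(flow, x) for x in eachcol(xs))
end
```

Under these assumptions it would be passed in the same position as the built-in objectives, e.g. train_flow(my_elbo, flow, logp, n_samples; kwargs...).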
The flow object can be constructed with the transformed function from Bijectors.jl. For example, for mean-field Gaussian VI, we can construct the flow family as follows:
using Distributions, Bijectors
using Functors # provides @leaf
T = Float32
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
# the flow family is defined by a shift and a scale
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)))
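As a quick sanity check of what this family represents, the sketch below builds the same kind of flow with non-trivial parameters and compares it against the closed-form Gaussian it should equal. The shift and scale values are illustrative assumptions, not values from the package.

```julia
using Random, Statistics
using Distributions, Bijectors

T = Float32
q₀ = MvNormal(zeros(T, 2), ones(T, 2))

# Illustrative (assumed) parameters: the flow should then be N(shift, diag(scale .^ 2)).
shift = T[1, -1]
scale = T[0.5, 2]
flow = Bijectors.transformed(q₀, Bijectors.Shift(shift) ∘ Bijectors.Scale(scale))

rng = Random.MersenneTwister(0)
x = rand(rng, flow, 10_000)        # one sample per column

m = vec(mean(x; dims=2))           # empirical mean, should be close to shift
s = vec(std(x; dims=2))            # empirical std, should be close to scale

# Log-density of the flow versus the closed-form diagonal Gaussian at one sample.
x1 = x[:, 1]
lp_flow = logpdf(flow, x1)
lp_exact = sum(logpdf.(Normal.(shift, scale), x1))
```

Training with train_flow adjusts exactly these shift and scale parameters, while the base distribution q₀ stays fixed.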
To train the Gaussian VI targeting a distribution p (with log-density function logp) via ELBO maximization, run:
using NormalizingFlows, Optimisers
using ADTypes, Mooncake
sample_per_iter = 10
flow_trained, stats, _ = train_flow(
    elbo,
    flow,
    logp,
    sample_per_iter;
    max_iters=5_000,
    optimiser=Optimisers.Adam(one(T)/100),
    ADbackend=ADTypes.AutoMooncake(; config=Mooncake.Config()),
    show_progress=true,
)
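The call above assumes that a function logp is already in scope, which this page does not define. A minimal sketch, assuming a hypothetical diagonal Gaussian target and assuming logp receives a single sample vector; the names μ_target and σ_target and their values are illustrative only. Because the ELBO only changes by a constant when the target is unnormalized, the normalizing constant is omitted here.

```julia
# Hypothetical target p = N(μ_target, diag(σ_target .^ 2)), for illustration only.
const μ_target = Float32[1, -1]
const σ_target = Float32[0.5, 2]

# Unnormalized log-density of p at a single point z (a length-2 vector).
logp(z) = -sum(abs2, (z .- μ_target) ./ σ_target) / 2
```

Define logp like this before calling train_flow; for such a target the trained shift and scale would be expected to approach μ_target and σ_target.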