General usage
train_flow is the main function for training a normalizing flow. Users mostly need to specify the normalizing flow (flow), the variational objective (vo), and the additional arguments that the objective takes (args...).
NormalizingFlows.train_flow (Function)
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
Train the given normalizing flow flow by calling optimize.
Arguments
- rng::AbstractRNG: random number generator (default: Random.default_rng())
- vo: variational objective with signature vo(rng, flow, args...). We implement elbo, elbo_batch, and loglikelihood.
- flow: the normalizing flow, a Bijectors.TransformedDistribution (recommended)
- args...: additional arguments passed to vo
Keyword Arguments
- max_iters::Int=1000: maximum number of iterations
- optimiser::Optimisers.AbstractRule=Optimisers.Adam(): optimiser used to compute the steps
- ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote(): automatic differentiation backend. Currently supported: ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoMooncake(), and ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const). To use AutoEnzyme, make sure to pass set_runtime_activity and function_annotation as written here.
- kwargs...: additional keyword arguments for optimize (see optimize for details)
Returns
- flow_trained: trained normalizing flow
- opt_stats: statistics of the optimiser during the training process (see optimize for details)
- st: optimiser state for potential continuation of training
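For reference, when vo is elbo, the quantity being maximized can be written as follows. This is the textbook evidence lower bound for a flow q obtained by pushing a reference distribution q_0 through a bijection T, not a transcription of the package source; the symbols T, J_T, and n are notation introduced here, with n presumably corresponding to the sample count passed through args....

```latex
\begin{aligned}
\mathrm{ELBO}(q)
  &= \mathbb{E}_{x \sim q}\bigl[\log p(x) - \log q(x)\bigr] \\
  &\approx \frac{1}{n} \sum_{i=1}^{n}
     \Bigl[\log p\bigl(T(z_i)\bigr) - \log q_0(z_i)
           + \log\bigl|\det J_T(z_i)\bigr|\Bigr],
  \qquad z_i \sim q_0 .
\end{aligned}
```

The second line uses the change-of-variables identity \log q(T(z)) = \log q_0(z) - \log|\det J_T(z)|, which is why the objective only needs samples from q_0 and the log-density of the target p.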
The flow object can be constructed with the transformed function from Bijectors.jl. For example, for mean-field Gaussian VI, we can construct the flow family as follows:
using Distributions, Bijectors
using Functors: @leaf # assumption: the @leaf macro used below comes from Functors.jl
T = Float32
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
# the flow family is defined by a shift and a scale
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)))
To train the Gaussian VI approximation of a target distribution p via ELBO maximization, where logp is a function that returns the log-density of p, run:
using NormalizingFlows, Optimisers
using ADTypes, Mooncake
sample_per_iter = 10
flow_trained, stats, _ = train_flow(
    elbo,
    flow,
    logp,
    sample_per_iter;
    max_iters=5_000,
    optimiser=Optimisers.Adam(one(T)/100),
    ADbackend=ADTypes.AutoMooncake(; config=Mooncake.Config()),
    show_progress=true,
)
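The call above assumes that logp has already been defined. As a minimal sketch, one hypothetical target is a two-dimensional Gaussian with independent coordinates, written without any extra packages. The names μ_target and σ_target and their values are illustrative assumptions and are not part of NormalizingFlows.jl.

```julia
# Hypothetical target for illustration: a 2-D Gaussian with independent coordinates.
# μ_target and σ_target are made-up values, not part of NormalizingFlows.jl.
const μ_target = Float32[1, -1]    # mean of each coordinate
const σ_target = Float32[0.5, 2]   # standard deviation of each coordinate

# Log-density of the target at a point x (a length-2 vector).
function logp(x)
    # Quadratic term: half the sum of squared standardized residuals.
    quad = sum(abs2, (x .- μ_target) ./ σ_target) / 2
    # Log of the normalizing constant of a diagonal Gaussian.
    lognorm = sum(log, σ_target) + length(x) * log(2f0 * Float32(π)) / 2
    return -(quad + lognorm)
end
```

Define this before calling train_flow. Because the target is itself a diagonal Gaussian, it lies inside the mean-field family built above, so a well-trained flow should end up with a shift close to μ_target and a scale close to σ_target in magnitude, which makes it a convenient sanity check.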