General usage
train_flow is the main function to train a normalizing flow. Users mostly need to specify a normalizing flow (flow), the variational objective (vo), and its corresponding arguments (args...).
NormalizingFlows.train_flow (Function)
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
Train the given normalizing flow (flow) by calling optimize.
Arguments
- rng::AbstractRNG: random number generator (default: Random.default_rng())
- vo: variational objective with signature vo(rng, flow, args...). We implement elbo, elbo_batch, and loglikelihood.
- flow: the normalizing flow, a Bijectors.TransformedDistribution (recommended)
- args...: additional arguments passed to vo
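To illustrate the vo(rng, flow, args...) signature, the following plain-Julia sketch writes a Monte Carlo ELBO estimate by hand. It is not the package's elbo: the flow is a hypothetical NamedTuple (a = scale, b = shift) standing in for a shift-and-scale flow y = a .* x .+ b, the target logp is a hypothetical diagonal Gaussian, args... is assumed to be (logp, n_samples), and the constant entropy of the base distribution q₀ = N(0, I) is left out, which does not change the maximizer.

```julia
using Random, Statistics

# Hypothetical target p: a diagonal Gaussian with mean μp and standard deviations σp
const μp = [1.0, -1.0]
const σp = [2.0, 0.5]
logp(y) = sum(@. -0.5 * ((y - μp) / σp)^2 - log(σp) - 0.5 * log(2π))

# Hypothetical objective with the signature vo(rng, flow, args...), where
# flow = (a = scale, b = shift) stands in for a shift-and-scale flow and
# args... = (logp, n_samples). It averages log p(T(x)) + log|det J_T| over
# n_samples draws x ~ q₀ = N(0, I), leaving out the constant entropy of q₀.
function shift_scale_elbo(rng, flow, logp, n_samples)
    logabsdetjac = sum(log.(abs.(flow.a)))          # constant for an elementwise affine map
    xs = randn(rng, length(flow.a), n_samples)      # one draw from q₀ per column
    return mean(logp(flow.a .* x .+ flow.b) + logabsdetjac for x in eachcol(xs))
end

# Same seed for both calls, so the two estimates use the same q₀ draws
elbo_init = shift_scale_elbo(MersenneTwister(1), (a = ones(2), b = zeros(2)), logp, 10_000)
elbo_match = shift_scale_elbo(MersenneTwister(1), (a = σp, b = μp), logp, 10_000)
```

When the flow matches the target, every per-sample term equals log q₀(x), so the estimate sits near -(1 + log 2π) ≈ -2.84; the identity flow scores lower by about 3.25 in expectation, the KL divergence from N(0, I) to this target.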
Keyword Arguments
- max_iters::Int=1000: maximum number of iterations
- optimiser::Optimisers.AbstractRule=Optimisers.ADAM(): optimiser to compute the steps
- ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote(): automatic differentiation backend; currently supports ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoMooncake(), and ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const). If you use AutoEnzyme, make sure to include set_runtime_activity and function_annotation as shown above.
- kwargs...: additional keyword arguments for optimize (see optimize for details)
Returns
- flow_trained: trained normalizing flow
- opt_stats: statistics of the optimiser during the training process (see optimize for details)
- st: optimiser state for potential continuation of training
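To make max_iters, the optimiser, and the number of samples per step concrete, here is a hand-rolled sketch of the kind of loop that training delegates to optimize, assuming it takes max_iters stochastic gradient steps on vo: gradient ascent on a Monte Carlo ELBO with reparameterized samples. It is an illustration under assumptions, not optimize itself: plain SGD with a fixed step η replaces Optimisers.Adam, gradients are derived by hand instead of through an ADbackend, and the shift-and-scale flow y = a .* x .+ b and the Gaussian target are hypothetical.

```julia
using Random

# Hypothetical target p: a diagonal Gaussian with mean μp and standard deviations σp
const μp = [1.0, -1.0]
const σp = [2.0, 0.5]

# Hand-rolled stochastic gradient ascent on a Monte Carlo ELBO for the flow
# y = a .* x .+ b with x ~ q₀ = N(0, I). Gradients use the reparameterization
# trick and are derived by hand for this target; plain SGD replaces Adam.
function fit_shift_scale(rng; max_iters=5_000, n_samples=10, η=0.01)
    a, b = ones(2), zeros(2)                    # start from the identity flow
    for _ in 1:max_iters
        ga, gb = zeros(2), zeros(2)
        for _ in 1:n_samples
            x = randn(rng, 2)                   # draw from q₀
            y = a .* x .+ b                     # push the draw through the flow
            dlogp = (μp .- y) ./ σp .^ 2        # ∇_y log p(y)
            gb .+= dlogp                        # ∂y/∂b = 1
            ga .+= dlogp .* x .+ 1 ./ a         # ∂y/∂a = x, plus ∇_a log|det J| = 1 ./ a
        end
        a .+= η .* ga ./ n_samples              # ascent step on the averaged gradient
        b .+= η .* gb ./ n_samples
    end
    return a, b
end

a, b = fit_shift_scale(MersenneTwister(0))
```

With 5,000 iterations and 10 samples per step, matching the max_iters and sample_per_iter used in the training example of this section, the fitted a and b end up close to the target standard deviations σp and mean μp.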
The flow object can be constructed with the transformed function from Bijectors.jl. For example, for mean-field Gaussian VI, the flow family can be constructed as follows:
using Distributions, Bijectors
using Functors: @leaf           # @leaf comes from Functors.jl
using LinearAlgebra: Diagonal
T = Float32
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(T, 2), Diagonal(ones(T, 2)))  # standard Gaussian base distribution
# the flow family is defined by a shift and a scale (Scale is applied first, then Shift)
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)))
To fit this Gaussian VI family to a target distribution p via ELBO maximization, run:
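using NormalizingFlows, Optimisers
using ADTypes, Mooncake
# logp: log-density of the target p; this Gaussian target is a hypothetical placeholder
target = MvNormal(T[1, -1], T[4 0; 0 0.25])
logp(x) = logpdf(target, x)
sample_per_iter = 10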
flow_trained, stats, _ = train_flow(
    elbo,
    flow,
    logp,
    sample_per_iter;
    max_iters=5_000,
    optimiser=Optimisers.Adam(one(T)/100),
    ADbackend=ADTypes.AutoMooncake(; config=Mooncake.Config()),
    show_progress=true,
)
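For the mean-field family above, Scale is applied before Shift, so a draw x ~ q₀ = N(0, I) becomes y = a .* x .+ b: the flow represents the Gaussian N(b, Diagonal(a .^ 2)), and the trained shift and scale vectors can be read as the fitted mean and standard deviations. The plain-Julia sketch below checks this empirically; shift_scale, shift_scale_logabsdetjac, and the parameter values are hypothetical stand-ins, not the Bijectors API.

```julia
using Random, Statistics

# Hypothetical plain-Julia stand-in for Shift(b) ∘ Scale(a): Scale acts first, then Shift
shift_scale(x, a, b) = a .* x .+ b

# log|det J| of x ↦ a .* x .+ b; it does not depend on x
shift_scale_logabsdetjac(a) = sum(log.(abs.(a)))

a = Float32[2, 0.5]    # hypothetical scale parameters
b = Float32[1, -1]     # hypothetical shift parameters
xs = randn(MersenneTwister(1), Float32, 2, 10_000)  # draws from q₀ = N(0, I), one per column
ys = shift_scale(xs, a, b)                          # draws from the transformed distribution

empirical_mean = vec(mean(ys; dims=2))   # close to b
empirical_std = vec(std(ys; dims=2))     # close to abs.(a)
```

Because the Jacobian of this map is the constant Diagonal(a), its log-determinant term in the ELBO depends only on the scale parameters, not on the sampled x.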