General usage

train_flow is the main function for training a normalizing flow. Users mainly need to specify a normalizing flow flow, a variational objective vo, and its corresponding arguments args....

NormalizingFlows.train_flow - Function
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)

Train the given normalizing flow flow by calling optimize.

Arguments

  • rng::AbstractRNG: random number generator (default: Random.default_rng())
  • vo: variational objective with signature vo(rng, flow, args...). We implement elbo, elbo_batch, and loglikelihood.
  • flow: the normalizing flow, a Bijectors.TransformedDistribution (recommended)
  • args...: additional arguments passed to vo

Keyword Arguments

  • max_iters::Int=1000: maximum number of iterations
  • optimiser::Optimisers.AbstractRule=Optimisers.Adam(): optimiser used to compute the update steps
  • ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote(): automatic differentiation backend. Currently supported backends are ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoMooncake(), and ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const). When using AutoEnzyme, make sure to pass the mode=Enzyme.set_runtime_activity(Enzyme.Reverse) and function_annotation=Enzyme.Const options as shown above.
  • kwargs...: additional keyword arguments for optimize (See optimize for details)

Returns

  • flow_trained: trained normalizing flow
  • opt_stats: statistics of the optimiser during the training process (See optimize for details)
  • st: optimiser state for potential continuation of training
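For reference, a custom variational objective only needs to match the vo(rng, flow, args...) signature expected by train_flow. The following is a minimal, illustrative sketch of an ELBO-style objective (the name my_elbo and its extra arguments are hypothetical; the package's built-in elbo should be preferred in practice):

```julia
using Random, Distributions, Bijectors
using Statistics: mean

# Hypothetical objective matching the vo(rng, flow, args...) signature.
# `logp` is the target log-density; `n_samples` is the Monte Carlo batch size.
function my_elbo(rng::Random.AbstractRNG, flow, logp, n_samples)
    xs = rand(rng, flow, n_samples)   # samples from the flow, one per column
    # Monte Carlo estimate of E_q[log p(x) - log q(x)]
    return mean(logp(xs[:, i]) - logpdf(flow, xs[:, i]) for i in 1:n_samples)
end
```

Such an objective could then be passed to train_flow in place of elbo, with logp and n_samples supplied via args....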

The flow object can be constructed with the transformed function from Bijectors.jl. For example, for mean-field Gaussian VI, the flow family can be constructed as follows:

using Distributions, Bijectors
using Functors: @leaf  # @leaf is provided by Functors.jl

T = Float32
@leaf MvNormal # prevent the parameters of q₀ from being optimised
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
# the flow family is defined by a shift and a scale
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)))
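Since flow is a Bijectors.TransformedDistribution, it supports the standard Distributions.jl interface, which is handy for sanity checks before training (a self-contained sketch repeating the construction above):

```julia
using Distributions, Bijectors
using Functors: @leaf

T = Float32
@leaf MvNormal
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)))

xs = rand(flow, 100)    # draw 100 samples (a 2×100 matrix)
logpdf(flow, xs[:, 1])  # log-density of the flow at one sample
```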

To fit the Gaussian variational approximation to a target distribution p via ELBO maximization, run:

using NormalizingFlows, Optimisers
using ADTypes, Mooncake

# logp(x) should return the log-density of the target distribution p at x,
# e.g. logp(x) = logpdf(p, x) for some Distributions.jl object p
sample_per_iter = 10
flow_trained, stats, _ = train_flow(
    elbo,
    flow,
    logp,
    sample_per_iter;
    max_iters=5_000,
    optimiser=Optimisers.Adam(one(T)/100),
    ADbackend=ADTypes.AutoMooncake(; config=Mooncake.Config()),
    show_progress=true,
)
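After training, flow_trained is again a Bijectors.TransformedDistribution, so (continuing from the run above) samples and log-densities of the fitted approximation are obtained through the standard Distributions.jl interface:

```julia
# continuing from the training run above
samples = rand(flow_trained, 1_000)        # draw 1000 samples from the trained flow
logq = logpdf(flow_trained, samples[:, 1]) # log-density under the trained flow
```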