using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Mooncake
using NormalizingFlows
# Element type and random number generator used throughout the script.
# Neither name is defined anywhere else in this file, so provide defaults here;
# the guards keep any value already set by a script that `include`s this one.
@isdefined(T) || (T = Float32)
@isdefined(rng) || (rng = Random.default_rng())

# NOTE(review): `Banana` is not defined in this file — it has to be in scope
# before this line (e.g. via an `include` of the target definitions); confirm.
target = Banana(2, one(T), 100one(T))
# `logp(x)` evaluates `logpdf(target, x)`.
logp = Base.Fix1(logpdf, target)
######################################
# set up the RealNVP
######################################
# Dimension of the reference distribution and the settings handed to `realnvp`.
d = 2
hdims = [16, 16]
nlayers = 3

# Register `MvNormal` as a Functors leaf, i.e. Functors will not recurse into it.
# NOTE(review): presumably this keeps the reference distribution's fields out of
# the trainable parameters — confirm against NormalizingFlows' usage.
@leaf MvNormal

# Reference distribution: zero mean of length `d`, identity covariance.
q0 = MvNormal(zeros(T, d), I)

# Build the RealNVP flow with the `NormalizingFlows.realnvp` constructor.
flow = realnvp(q0, hdims, nlayers; paramtype=T)

# Independent copy of the flow taken before any training happens.
flow_untrained = deepcopy(flow)
######################################
# start training
######################################
# Passed to `train_flow` as its last positional argument.
sample_per_iter = 16

# AD backend handed to `train_flow`: Mooncake with a default-constructed config.
adtype = ADTypes.AutoMooncake(; config=Mooncake.Config())

"""
    cb(iter, opt_stats, re, θ)

Callback used to log training progress. All four arguments are ignored; the
result is always the named tuple `(sample_per_iter=..., ad=...)` built from the
globals `sample_per_iter` and `adtype`.
"""
function cb(iter, opt_stats, re, θ)
    return (; sample_per_iter, ad=adtype)
end
"""
    checkconv(iter, stat, re, θ, st)

Convergence check passed to `train_flow` as `hasconverged`. Return `true` when
`stat.gradient_norm` is strictly below `one(T) / 1000`; the other arguments are
ignored.
"""
function checkconv(iter, stat, re, θ, st)
    tol = one(T) / 1000
    return stat.gradient_norm < tol
end
# Run the optimisation. The third returned value is not needed here.
flow_trained, stats, _ = train_flow(
    rng,
    # Per the original author's note, swapping in `elbo_batch` for `elbo`
    # gives a 4-5x speedup.
    elbo,
    flow,
    logp,
    sample_per_iter;
    # Kept small for a quick run; raise it (e.g. to 50_000) for better results.
    max_iters=10,
    optimiser=Optimisers.Adam(5e-4),
    ADbackend=adtype,
    callback=cb,
    hasconverged=checkconv,
    show_progress=true,
)

# Flatten the trained flow into a parameter vector `θ` plus a function `re`
# that rebuilds the flow from such a vector.
θ, re = Optimisers.destructure(flow_trained)

# Collect the `loss` field of every entry in `stats`.
losses = map(stat -> stat.loss, stats)