Bayesian Neural Networks
In this tutorial, we demonstrate how to implement a Bayesian neural network using a combination of Turing and Flux, a suite of machine learning tools. We will use Flux to specify the neural network's layers and Turing to implement the probabilistic inference, with the goal of building a classification algorithm.
We will begin by importing the relevant libraries.
using Turing
using FillArrays
using Flux
using Plots
using ReverseDiff
using LinearAlgebra
using Random
# Use reverse-mode AD (ReverseDiff), which scales better than forward mode as the number of parameters grows.
Turing.setadbackend(:reversediff)
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we will be working with.
# Number of points to generate.
N = 80
M = round(Int, N / 4)
Random.seed!(1234)
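# Generate artificial data: M points in each of four 4.5 × 4.5 boxes.
# Class 1: upper-right box (both coordinates in [0.5, 5)).
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
xt1s = Array([[x1s[i] + 0.5; x2s[i] + 0.5] for i in 1:M])
# Class 1: lower-left box (both coordinates in [-5, -0.5)).
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
append!(xt1s, Array([[x1s[i] - 5; x2s[i] - 5] for i in 1:M]))
# Class 0: lower-right box (x1 in [0.5, 5), x2 in [-5, -0.5)).
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
xt0s = Array([[x1s[i] + 0.5; x2s[i] - 5] for i in 1:M])
# Class 0: upper-left box (x1 in [-5, -0.5), x2 in [0.5, 5)).
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
append!(xt0s, Array([[x1s[i] - 5; x2s[i] + 0.5] for i in 1:M]))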
# Store all the data for later.
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
# Plot data points: class 1 (xt1s) in red, class 0 (xt0s) in blue.
function plot_data()
    x1 = map(e -> e[1], xt1s)
    y1 = map(e -> e[2], xt1s)
    x2 = map(e -> e[1], xt0s)
    y2 = map(e -> e[2], xt0s)
    Plots.scatter(x1, y1; color="red", clim=(0, 1))
    return Plots.scatter!(x2, y2; color="blue", clim=(0, 1))
end
plot_data()
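The four boxes place class 1 in the upper-right and lower-left quadrants and class 0 in the other two, so a point's label is 1 exactly when x1 and x2 share a sign. This XOR-like layout cannot be separated by a single straight line, which is one reason the network below needs hidden layers. The stdlib-only sketch below rebuilds the same layout with a hypothetical helper make_boxes (it will not reproduce the tutorial's exact points, since the random draws happen in a different order) and checks the sign rule.

```julia
using Random

# Hypothetical helper (not part of the tutorial): draw M points in each of the
# four boxes used above and return a 2×4M matrix together with the labels.
function make_boxes(M; rng=Random.default_rng())
    class1 = [(0.5, 0.5), (-5.0, -5.0)]  # upper-right and lower-left boxes
    class0 = [(0.5, -5.0), (-5.0, 0.5)]  # lower-right and upper-left boxes
    points = Vector{Vector{Float64}}()
    labels = Float64[]
    for (offsets, label) in ((class1, 1.0), (class0, 0.0))
        for (dx, dy) in offsets, _ in 1:M
            # rand() * 4.5 lies in [0, 4.5), so each coordinate lands in
            # [0.5, 5) or [-5, -0.5) depending on its offset.
            push!(points, [rand(rng) * 4.5 + dx, rand(rng) * 4.5 + dy])
            push!(labels, label)
        end
    end
    return reduce(hcat, points), labels
end

X, t = make_boxes(20; rng=MersenneTwister(1234))
# Label 1 exactly when both coordinates share a sign (an XOR-like pattern).
same_sign = [X[1, j] * X[2, j] > 0 for j in 1:size(X, 2)]
println(size(X))                         # (2, 80)
println(all(same_sign .== (t .== 1.0)))  # true
```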
Building a Neural Network
The next step is to define a feedforward neural network where we express our parameters as distributions, and not single points as with traditional neural networks.
For this we will use Dense to define linear layers and compose them via Chain; both are neural network primitives from Flux.
The network nn_initial created below has two hidden layers with tanh activations and one output layer with a sigmoid (σ) activation.
nn_initial is a callable object: it takes data as input and returns predictions.
We will define distributions on the neural network parameters and use destructure from Flux to extract the parameters as parameters_initial.
The function destructure also returns a function reconstruct that takes (new) parameters and returns a neural network with the same architecture as nn_initial but with the updated parameters.
# Construct a neural network using Flux
nn_initial = Chain(Dense(2, 3, tanh), Dense(3, 2, tanh), Dense(2, 1, σ))
# Extract weights and a helper function to reconstruct NN from weights
parameters_initial, reconstruct = Flux.destructure(nn_initial)
length(parameters_initial) # number of parameters in NN
20
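Each Dense(n_in, n_out) layer holds an n_out × n_in weight matrix and n_out biases, so the count above is 2·3 + 3 = 9, 3·2 + 2 = 8 and 2·1 + 1 = 3, i.e. 20 in total. Conceptually, destructure concatenates these arrays into one flat vector and reconstruct slices that vector back into layers. The sketch below is a hypothetical plain-Julia analogue: flatten_layers and unflatten_layers are illustrative names, and the weights-then-biases ordering is an assumption. Flux's actual implementation walks arbitrary model structures and is more general.

```julia
# (inputs, outputs) of the three Dense layers in nn_initial.
layer_sizes = [(2, 3), (3, 2), (2, 1)]

# A Dense(n_in, n_out) layer has an n_out×n_in weight matrix plus n_out biases.
count_params(n_in, n_out) = n_in * n_out + n_out
total = sum(count_params(n_in, n_out) for (n_in, n_out) in layer_sizes)
println(total)  # 20  (= 9 + 8 + 3)

# Hypothetical flatten/unflatten pair that mimics destructure/reconstruct.
# Each layer is stored as a (W, b) tuple; weights come before biases.
flatten_layers(layers) = reduce(vcat, [vcat(vec(W), b) for (W, b) in layers])

function unflatten_layers(theta, layer_sizes)
    layers = Tuple{Matrix{Float64},Vector{Float64}}[]
    offset = 0
    for (n_in, n_out) in layer_sizes
        nW = n_in * n_out
        # vec and reshape are both column-major, so this undoes flatten_layers.
        W = reshape(theta[offset+1:offset+nW], n_out, n_in)
        b = theta[offset+nW+1:offset+nW+n_out]
        push!(layers, (W, b))
        offset += nW + n_out
    end
    return layers
end

layers = [(randn(n_out, n_in), randn(n_out)) for (n_in, n_out) in layer_sizes]
theta = flatten_layers(layers)
println(length(theta))                                   # 20
println(unflatten_layers(theta, layer_sizes) == layers)  # true
```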
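The probabilistic model specification below creates a parameters variable: a vector of IID normal random variables with mean 0 and variance 1/alpha (with the default alpha = 0.09, a standard deviation of about 3.3). The parameters vector represents all parameters of our neural net (weights and biases).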
@model function bayes_nn(xs, ts, nparameters, reconstruct; alpha=0.09)
    # Create the weight and bias vector.
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)
    # Construct NN from parameters
    nn = reconstruct(parameters)
    # Forward NN to make predictions (a 1×N matrix of probabilities)
    preds = nn(xs)
    # Observe each prediction.
    for i in 1:length(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end;
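Written out by hand, bayes_nn defines an unnormalized log posterior equal to the log density of MvNormal(0, I/alpha) at the parameters plus the Bernoulli log-likelihood of the labels; alpha therefore acts as a prior precision. The sketch below computes that quantity without Turing or Flux, assuming a hypothetical flat layout (W1, b1, W2, b2, W3, b3) for the 20 parameters. forward, log_prior, log_likelihood and log_joint are illustrative names, not part of either library.

```julia
using LinearAlgebra, Random

logistic(z) = 1 / (1 + exp(-z))

# Hypothetical forward pass for the 2-3-2-1 tanh/tanh/sigmoid network, reading
# weights and biases from a flat 20-vector in (W1, b1, W2, b2, W3, b3) order.
function forward(theta, x)
    W1 = reshape(theta[1:6], 3, 2);   b1 = theta[7:9]
    W2 = reshape(theta[10:15], 2, 3); b2 = theta[16:17]
    W3 = reshape(theta[18:19], 1, 2); b3 = theta[20:20]
    h1 = tanh.(W1 * x .+ b1)
    h2 = tanh.(W2 * h1 .+ b2)
    return logistic.(W3 * h2 .+ b3)  # 1×N matrix of probabilities
end

# log N(theta; 0, I/alpha): every entry has variance 1/alpha.
function log_prior(theta; alpha=0.09)
    n = length(theta)
    return -n / 2 * log(2π) + n / 2 * log(alpha) - alpha / 2 * dot(theta, theta)
end

# Sum of Bernoulli log-probabilities of the observed labels.
function log_likelihood(theta, xs, ts)
    p = vec(forward(theta, xs))
    return sum(ts .* log.(p) .+ (1 .- ts) .* log.(1 .- p))
end

# Unnormalized log posterior: log prior + log likelihood.
log_joint(theta, xs, ts; alpha=0.09) =
    log_prior(theta; alpha=alpha) + log_likelihood(theta, xs, ts)

rng = MersenneTwister(1)
xs = randn(rng, 2, 10)              # 10 points as columns, like hcat(xs...)
ts = Float64.(rand(rng, Bool, 10))
theta0 = zeros(20)                  # all weights zero => every output is 0.5
println(log_likelihood(theta0, xs, ts) ≈ 10 * log(0.5))  # true
println(log_joint(theta0, xs, ts))
```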
Inference can now be performed by calling sample. We use the NUTS Hamiltonian Monte Carlo sampler here.
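# Perform inference.
N = 5000  # number of posterior samples to draw
# hcat(xs...) arranges the data as a 2×80 matrix, one column per point.
ch = sample(bayes_nn(hcat(xs...), ts, length(parameters_initial), reconstruct), NUTS(), N);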
Now we extract the parameter samples from the sampled chain as theta. This is an array of size 5000 × 20 × 1: 5000 iterations, 20 parameters, and one chain.
We'll use these samples primarily to determine how good our model's classifier is.
# Extract all weight and bias parameters.
theta = MCMCChains.group(ch, :parameters).value;
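Because the chain dimension has length 1, theta[i, :] in the prediction code that follows still returns a 20-element parameter vector: Julia allows omitting trailing indices whose dimensions have length 1. A small stdlib check with random stand-in data, assuming a single chain as in this tutorial (with several chains you would index the chain explicitly, e.g. theta[i, :, c]):

```julia
# Stand-in for the iterations × parameters × chains array returned by `.value`
# (random numbers here; assumes a single chain, as in this tutorial).
theta = rand(5000, 20, 1)

row = theta[42, :]               # omitted trailing index: chain dimension has length 1
println(size(row))               # (20,)
println(row == theta[42, :, 1])  # true
```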
Prediction Visualization
We can approximate MAP estimation by classifying our population with the set of weights that gave the highest log posterior density among the posterior samples.
# A helper to create NN from weights `theta` and run it through data `x`
nn_forward(x, theta) = reconstruct(theta)(x)
# Plot the data we have.
plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])
# Extract the max row value from i.
i = i.I[1]
# Contour plot of the MAP network's predicted probability of class 1
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
# Rows of Z index x2 (y-axis) and columns index x1 (x-axis), as contour! expects.
Z = [nn_forward([x1, x2], theta[i, :])[1] for x2 in x2_range, x1 in x1_range]
contour!(x1_range, x2_range, Z)
The contour plot above shows that the MAP estimate classifies our data reasonably well.
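findmax on the iterations × chains matrix returned by ch[:lp] yields a CartesianIndex rather than an integer, which is why the code extracts the row with i.I[1]. A stdlib sketch with made-up log densities standing in for ch[:lp] (the 5×1 shape is an assumption that mirrors a single chain):

```julia
# Stand-in for ch[:lp]: a 5×1 (iterations × chains) matrix of log densities.
lp = reshape([-12.0, -9.5, -10.1, -8.7, -11.3], 5, 1)

best, idx = findmax(lp)  # on a matrix, idx is a CartesianIndex
row = idx.I[1]           # same as Tuple(idx)[1]; idx.I[2] would be the chain
println(best)            # -8.7
println(row)             # 4
```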
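Now we can visualize the posterior predictive distribution, which averages the network's output over posterior samples of the parameters instead of relying on a single set of weights. Writing $f_{\theta}$ for the network with parameters $\theta$, $X$ for the training data, and $\tilde{t}$ for the label of a new point $\tilde{x}$:
$$ p(\tilde{t} = 1 \mid \tilde{x}, X, \alpha) = \int_{\theta} f_{\theta}(\tilde{x})\, p(\theta \mid X, \alpha)\, d\theta \approx \frac{1}{S} \sum_{s=1}^{S} f_{\theta^{(s)}}(\tilde{x}), \qquad \theta^{(s)} \sim p(\theta \mid X, \alpha) $$
where $S$ is the number of posterior samples used in the average.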
The nn_predict function averages the predicted values of networks parameterized by weights drawn from the MCMC chain, using every 10th sample among the first num samples.
# Return the average predicted value across
# multiple weights.
function nn_predict(x, theta, num)
    return mean([nn_forward(x, theta[i, :])[1] for i in 1:10:num])
end;
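nn_predict is a Monte Carlo estimate of the integral above. The sketch below generalizes it into a hypothetical predictive_mean function and applies it to a toy logistic-regression model (an assumption, not the tutorial's network) with two made-up posterior samples that disagree: each sample alone is confident, but their average sits near 0.5, which is how averaging exposes uncertainty. It also confirms that 1:10:1500 selects 150 samples.

```julia
using Statistics

# Hypothetical generic version of nn_predict: average the model output f over
# every `step`-th row (posterior sample) of `theta`, up to row `num`.
function predictive_mean(f, x, theta::AbstractMatrix, num; step=10)
    return mean(f(view(theta, s, :), x) for s in 1:step:num)
end

# Toy model (assumption, not the tutorial's network): logistic regression on a
# 2-D input with parameters th = [w1, w2, b].
toy_f(th, x) = 1 / (1 + exp(-(th[1] * x[1] + th[2] * x[2] + th[3])))

# Two fake posterior samples that disagree about the sign of w1.
theta = [2.0 0.0 0.0; -2.0 0.0 0.0]
x = [1.0, 0.0]

println(toy_f(theta[1, :], x))                        # ≈ 0.881: confident class 1
println(predictive_mean(toy_f, x, theta, 2; step=1))  # ≈ 0.5: uncertain
println(length(1:10:1500))                            # 150 samples used by nn_predict
```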
Next, we use the nn_predict function to predict the value on a grid of points whose x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and, more importantly, we can see much more easily where the neural network is uncertain about its predictions: the regions between cluster boundaries.
# Plot the average prediction.
plot_data()
n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
# Rows of Z index x2 (y-axis) and columns index x1 (x-axis), as contour! expects.
Z = [nn_predict([x1, x2], theta, n_end) for x2 in x2_range, x1 in x1_range]
contour!(x1_range, x2_range, Z)
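Plots expects a matrix z passed to contour(x, y, z) to have size (length(y), length(x)), with z[j, i] giving the value at (x[i], y[j]). The stdlib-only sketch below shows which comprehension order produces that layout, using a deliberately asymmetric stand-in f (an assumption for illustration); because this dataset is nearly symmetric under swapping x1 and x2, a transposed grid would be hard to spot by eye.

```julia
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
f(x1, x2) = x1 - 2 * x2  # deliberately not symmetric in x1 and x2

# x2 in the first (row) position, x1 in the second (column) position.
Z = [f(x1, x2) for x2 in x2_range, x1 in x1_range]

i, j = 7, 3
println(Z[j, i] == f(x1_range[i], x2_range[j]))  # true: row ↔ x2, column ↔ x1
println(Z[i, j] == f(x1_range[i], x2_range[j]))  # false for this f
```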
Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 500, keeping every fifth frame.
# Number of iterations to plot.
n_end = 500
anim = @gif for i in 1:n_end
    plot_data()
    # Rows of Z index x2 (y-axis) and columns index x1 (x-axis).
    Z = [nn_forward([x1, x2], theta[i, :])[1] for x2 in x2_range, x1 in x1_range]
    contour!(x1_range, x2_range, Z; title="Iteration $i", clim=(0, 1))
end every 5
This has been an introduction to the applications of Turing and Flux in defining Bayesian neural networks.
Appendix
These tutorials are a part of the TuringTutorials repository, found at: https://github.com/TuringLang/TuringTutorials.
To run this tutorial locally, run the following commands:
using TuringTutorials
TuringTutorials.weave("03-bayesian-neural-network", "03_bayesian-neural-network.jmd")
Computer Information:
Julia Version 1.6.7
Commit 3b76b25b64 (2022-07-19 15:11 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-11.0.1 (ORCJIT, znver2)
Environment:
JULIA_CPU_THREADS = 16
BUILDKITE_PLUGIN_JULIA_CACHE_DIR = /cache/julia-buildkite-plugin
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/7aa0085e-79a4-45f3-a5bd-9743c91cf3da
Package Information:
Status `/cache/build/default-amdci4-2/julialang/turingtutorials/tutorials/03-bayesian-neural-network/Project.toml`
[b5ca4192] AdvancedVI v0.1.6
[76274a88] Bijectors v0.10.8
[1a297f60] FillArrays v0.13.10
[587475ba] Flux v0.13.14
[91a5bcdd] Plots v1.38.8
[37e2e3b7] ReverseDiff v1.14.4
[fce5fe82] Turing v0.22.0
[37e2e46d] LinearAlgebra
[9a3f8284] Random
And the full manifest:
Status `/cache/build/default-amdci4-2/julialang/turingtutorials/tutorials/03-bayesian-neural-network/Manifest.toml`
[621f4979] AbstractFFTs v1.3.1
[80f14c24] AbstractMCMC v4.2.0
[7a57a42e] AbstractPPL v0.5.4
[1520ce14] AbstractTrees v0.4.4
[7d9f7c33] Accessors v0.1.28
[79e6a3ab] Adapt v3.6.1
[0bf59076] AdvancedHMC v0.3.6
[5b7e9947] AdvancedMH v0.6.8
[576499cb] AdvancedPS v0.3.8
[b5ca4192] AdvancedVI v0.1.6
[dce04be8] ArgCheck v2.3.0
[4fba245c] ArrayInterface v7.3.1
[a9b6321e] Atomix v0.1.0
[13072b0f] AxisAlgorithms v1.0.1
[39de3d68] AxisArrays v0.4.6
[ab4f0b2a] BFloat16s v0.4.2
[198e06fe] BangBang v0.3.37
[9718e550] Baselet v0.1.1
[76274a88] Bijectors v0.10.8
[fa961155] CEnum v0.4.2
[052768ef] CUDA v4.1.0
[1af6417a] CUDA_Runtime_Discovery v0.1.1
[49dc2e85] Calculus v0.5.1
[082447d4] ChainRules v1.48.0
[d360d2e6] ChainRulesCore v1.15.7
[9e997f8a] ChangesOfVariables v0.1.6
[35d6a980] ColorSchemes v3.20.0
[3da002f7] ColorTypes v0.11.4
[c3611d14] ColorVectorSpace v0.9.10
[5ae59095] Colors v0.12.10
[861a8166] Combinatorics v1.0.2
[38540f10] CommonSolve v0.2.3
[bbf7d656] CommonSubexpressions v0.3.0
[34da2185] Compat v4.6.1
[a33af91c] CompositionsBase v0.1.1
[88cd18e8] ConsoleProgressMonitor v0.1.2
[187b0558] ConstructionBase v1.5.1
[6add18c4] ContextVariablesX v0.1.3
[d38c429a] Contour v0.6.2
[a8cc5b0e] Crayons v4.1.1
[9a962f9c] DataAPI v1.14.0
[864edb3b] DataStructures v0.18.13
[e2d170a0] DataValueInterfaces v1.0.0
[244e2a9f] DefineSingletons v0.1.2
[b429d917] DensityInterface v0.4.0
[163ba53b] DiffResults v1.1.0
[b552c78f] DiffRules v1.13.0
[31c24e10] Distributions v0.25.86
[ced4e74d] DistributionsAD v0.6.43
[ffbed154] DocStringExtensions v0.9.3
[fa6b7ba4] DualNumbers v0.6.8
[366bfd00] DynamicPPL v0.21.4
[cad2338a] EllipticalSliceSampling v1.1.0
[4e289a0a] EnumX v1.0.4
[e2ba6199] ExprTools v0.1.9
[c87230d0] FFMPEG v0.4.1
[7a1cc6ca] FFTW v1.6.0
[cc61a311] FLoops v0.2.1
[b9860ae5] FLoopsBase v0.1.1
[1a297f60] FillArrays v0.13.10
[53c48c17] FixedPointNumbers v0.8.4
[587475ba] Flux v0.13.14
[9c68100b] FoldsThreads v0.1.1
[59287772] Formatting v0.4.2
[f6369f11] ForwardDiff v0.10.35
[069b7b12] FunctionWrappers v1.1.3
[77dc65aa] FunctionWrappersWrappers v0.1.3
[d9f16b24] Functors v0.3.0
[0c68f7d7] GPUArrays v8.6.5
[46192b85] GPUArraysCore v0.1.4
[61eb1bfa] GPUCompiler v0.18.0
[28b8d3ca] GR v0.71.8
[42e2da0e] Grisu v1.0.2
[cd3eb016] HTTP v0.8.19
[34004b35] HypergeometricFunctions v0.3.11
[7869d1d1] IRTools v0.4.9
[83e8ac13] IniFile v0.5.1
[22cec73e] InitialValues v0.3.1
[505f98c9] InplaceOps v0.3.0
[a98d9a8b] Interpolations v0.14.7
[8197267c] IntervalSets v0.7.4
[3587e190] InverseFunctions v0.1.8
[41ab1584] InvertedIndices v1.3.0
[92d709cd] IrrationalConstants v0.2.2
[c8e1da08] IterTools v1.4.0
[82899510] IteratorInterfaceExtensions v1.0.0
[1019f520] JLFzf v0.1.5
[692b3bcd] JLLWrappers v1.4.1
[682c06a0] JSON v0.21.3
[b14d175d] JuliaVariables v0.2.4
[63c18a36] KernelAbstractions v0.9.0
[5ab0869b] KernelDensity v0.6.5
[929cbde3] LLVM v4.17.1
[8ac3fa9e] LRUCache v1.4.0
[b964fa9f] LaTeXStrings v1.3.0
[23fbe1c1] Latexify v0.15.18
[50d2b5c4] Lazy v0.15.1
[1d6d02ad] LeftChildRightSiblingTrees v0.2.0
[6f1fad26] Libtask v0.7.0
[6fdf6af0] LogDensityProblems v1.0.3
[2ab3a3ac] LogExpFunctions v0.3.23
[e6f89c97] LoggingExtras v0.4.9
[c7f686f2] MCMCChains v5.7.1
[be115224] MCMCDiagnosticTools v0.2.6
[e80e1ace] MLJModelInterface v1.8.0
[d8e11817] MLStyle v0.4.17
[f1d291b0] MLUtils v0.4.1
[1914dd2f] MacroTools v0.5.10
[dbb5928d] MappedArrays v0.4.1
[739be429] MbedTLS v1.1.7
[442fdcdd] Measures v0.3.2
[128add7d] MicroCollections v0.1.4
[e1d29d7a] Missings v1.1.0
[872c559c] NNlib v0.8.19
[a00861dc] NNlibCUDA v0.2.7
[77ba4419] NaNMath v1.0.2
[71a1bf82] NameResolution v0.1.5
[86f7a689] NamedArrays v0.9.7
[c020b1a1] NaturalSort v1.0.0
[6fe1bfb0] OffsetArrays v1.12.9
[0b1bfda6] OneHotArrays v0.2.3
[3bd65402] Optimisers v0.2.15
[bac558e1] OrderedCollections v1.4.1
[90014a1f] PDMats v0.11.17
[69de0a69] Parsers v2.5.8
[b98c9c47] Pipe v1.3.0
[ccf2f8ad] PlotThemes v3.1.0
[995b91a9] PlotUtils v1.3.4
[91a5bcdd] Plots v1.38.8
[21216c6a] Preferences v1.3.0
[8162dcfd] PrettyPrint v0.2.0
[08abe8d2] PrettyTables v2.2.3
[33c8b6b6] ProgressLogging v0.1.4
[92933f4c] ProgressMeter v1.7.2
[1fd47b50] QuadGK v2.8.2
[74087812] Random123 v1.6.0
[e6cf234a] RandomNumbers v1.5.3
[b3c3ace0] RangeArrays v0.3.2
[c84ed2f1] Ratios v0.4.3
[c1ae055f] RealDot v0.1.0
[3cdcf5f2] RecipesBase v1.3.3
[01d81517] RecipesPipeline v0.6.11
[731186ca] RecursiveArrayTools v2.38.0
[189a3867] Reexport v1.2.2
[05181044] RelocatableFolders v1.0.0
[ae029012] Requires v1.3.0
[37e2e3b7] ReverseDiff v1.14.4
[79098fc4] Rmath v0.7.1
[f2b01f46] Roots v2.0.10
[7e49a35a] RuntimeGeneratedFunctions v0.5.6
[0bca4576] SciMLBase v1.91.3
[c0aeaf25] SciMLOperators v0.2.0
[30f210dd] ScientificTypesBase v3.0.0
[6c6a2e73] Scratch v1.2.0
[efcf1570] Setfield v1.1.1
[605ecd9f] ShowCases v0.1.0
[992d4aef] Showoff v1.0.3
[699a6c99] SimpleTraits v0.9.4
[66db9d55] SnoopPrecompile v1.0.3
[a2af1166] SortingAlgorithms v1.1.0
[276daf66] SpecialFunctions v2.2.0
[171d559e] SplittablesBase v0.1.15
[90137ffa] StaticArrays v1.5.19
[1e83bf80] StaticArraysCore v1.4.0
[64bff920] StatisticalTraits v3.2.0
[82ae8749] StatsAPI v1.5.0
[2913bbd2] StatsBase v0.33.21
[4c63d2b9] StatsFuns v1.3.0
[892a3eda] StringManipulation v0.3.0
[09ab397b] StructArrays v0.6.15
[2efcf032] SymbolicIndexingInterface v0.2.2
[3783bdb8] TableTraits v1.0.1
[bd369af6] Tables v1.10.1
[62fd8b95] TensorCore v0.1.1
[5d786b92] TerminalLoggers v0.1.6
[a759f4b9] TimerOutputs v0.5.22
[9f7883ad] Tracker v0.2.23
[28d57a85] Transducers v0.4.75
[410a4b4d] Tricks v0.1.6
[781d530d] TruncatedStacktraces v1.3.0
[fce5fe82] Turing v0.22.0
[3a884ed6] UnPack v1.0.2
[1cfade01] UnicodeFun v0.4.1
[013be700] UnsafeAtomics v0.2.1
[d80eeb9a] UnsafeAtomicsLLVM v0.1.0
[41fe7b60] Unzip v0.1.2
[efce3f68] WoodburyMatrices v0.5.5
[e88e6eb3] Zygote v0.6.59
[700de1a5] ZygoteRules v0.2.3
[02a925ec] cuDNN v1.0.2
[6e34b625] Bzip2_jll v1.0.8+0
[4ee394cb] CUDA_Driver_jll v0.4.0+2
[76a88914] CUDA_Runtime_jll v0.4.0+2
[62b44479] CUDNN_jll v8.8.1+0
[83423d85] Cairo_jll v1.16.1+1
[2e619515] Expat_jll v2.4.8+0
[b22a6f82] FFMPEG_jll v4.4.2+2
[f5851436] FFTW_jll v3.3.10+0
[a3f928ae] Fontconfig_jll v2.13.93+0
[d7e528f0] FreeType2_jll v2.10.4+0
[559328eb] FriBidi_jll v1.0.10+0
[0656b61e] GLFW_jll v3.3.8+0
[d2c73de3] GR_jll v0.71.8+0
[78b55507] Gettext_jll v0.21.0+0
[7746bdde] Glib_jll v2.74.0+2
[3b182d85] Graphite2_jll v1.3.14+0
[2e76f6c2] HarfBuzz_jll v2.8.1+1
[1d5cc7b8] IntelOpenMP_jll v2018.0.3+2
[aacddb02] JpegTurbo_jll v2.1.91+0
[c1c5ebd0] LAME_jll v3.100.1+0
[88015f11] LERC_jll v3.0.0+1
[dad2f222] LLVMExtra_jll v0.0.18+0
[dd4b983a] LZO_jll v2.10.1+0
[e9f186c6] Libffi_jll v3.2.2+1
[d4300ac3] Libgcrypt_jll v1.8.7+0
[7e76a0d4] Libglvnd_jll v1.6.0+0
[7add5ba3] Libgpg_error_jll v1.42.0+0
[94ce4f54] Libiconv_jll v1.16.1+2
[4b2f31a3] Libmount_jll v2.35.0+0
[89763e89] Libtiff_jll v4.4.0+0
[38a345b3] Libuuid_jll v2.36.0+0
[856f044c] MKL_jll v2022.2.0+0
[e7412a2a] Ogg_jll v1.3.5+1
[458c3c95] OpenSSL_jll v1.1.20+0
[efe28fd5] OpenSpecFun_jll v0.5.5+0
[91d4177d] Opus_jll v1.3.2+0
[30392449] Pixman_jll v0.40.1+0
[ea2cea3b] Qt5Base_jll v5.15.3+2
[f50d1b31] Rmath_jll v0.4.0+0
[a2964d1f] Wayland_jll v1.21.0+0
[2381bf8a] Wayland_protocols_jll v1.25.0+0
[02c8fc9c] XML2_jll v2.10.3+0
[aed1982a] XSLT_jll v1.1.34+0
[4f6342f7] Xorg_libX11_jll v1.6.9+4
[0c0b7dd1] Xorg_libXau_jll v1.0.9+4
[935fb764] Xorg_libXcursor_jll v1.2.0+4
[a3789734] Xorg_libXdmcp_jll v1.1.3+4
[1082639a] Xorg_libXext_jll v1.3.4+4
[d091e8ba] Xorg_libXfixes_jll v5.0.3+4
[a51aa0fd] Xorg_libXi_jll v1.7.10+4
[d1454406] Xorg_libXinerama_jll v1.1.4+4
[ec84b674] Xorg_libXrandr_jll v1.5.2+4
[ea2f1a96] Xorg_libXrender_jll v0.9.10+4
[14d82f49] Xorg_libpthread_stubs_jll v0.1.0+3
[c7cfdc94] Xorg_libxcb_jll v1.13.0+3
[cc61e674] Xorg_libxkbfile_jll v1.1.0+4
[12413925] Xorg_xcb_util_image_jll v0.4.0+1
[2def613f] Xorg_xcb_util_jll v0.4.0+1
[975044d2] Xorg_xcb_util_keysyms_jll v0.4.0+1
[0d47668e] Xorg_xcb_util_renderutil_jll v0.3.9+1
[c22f9ab0] Xorg_xcb_util_wm_jll v0.4.1+1
[35661453] Xorg_xkbcomp_jll v1.4.2+4
[33bec58e] Xorg_xkeyboard_config_jll v2.27.0+4
[c5fb5394] Xorg_xtrans_jll v1.4.0+3
[3161d3a3] Zstd_jll v1.5.4+0
[214eeab7] fzf_jll v0.29.0+0
[a4ae2306] libaom_jll v3.4.0+0
[0ac62f75] libass_jll v0.15.1+0
[f638f0a6] libfdk_aac_jll v2.0.2+0
[b53b4c65] libpng_jll v1.6.38+0
[f27f6e37] libvorbis_jll v1.3.7+1
[1270edf5] x264_jll v2021.5.5+0
[dfaa095f] x265_jll v3.5.0+0
[d8fb68d0] xkbcommon_jll v1.4.1+0
[0dad84c5] ArgTools
[56f22d72] Artifacts
[2a0f44e3] Base64
[ade2ca70] Dates
[8bb1440f] DelimitedFiles
[8ba89e20] Distributed
[f43a241f] Downloads
[9fa8497b] Future
[b77e0a4c] InteractiveUtils
[4af54fe1] LazyArtifacts
[b27032c2] LibCURL
[76f85450] LibGit2
[8f399da3] Libdl
[37e2e46d] LinearAlgebra
[56ddb016] Logging
[d6f4376e] Markdown
[a63ad114] Mmap
[ca575930] NetworkOptions
[44cfe95a] Pkg
[de0858da] Printf
[3fa0cd96] REPL
[9a3f8284] Random
[ea8e919c] SHA
[9e88b42a] Serialization
[1a1011a3] SharedArrays
[6462fe0b] Sockets
[2f01184e] SparseArrays
[10745b16] Statistics
[4607b0f0] SuiteSparse
[fa267f1f] TOML
[a4e569a6] Tar
[8dfed614] Test
[cf7118a7] UUIDs
[4ec0a83e] Unicode
[e66e0078] CompilerSupportLibraries_jll
[deac9b47] LibCURL_jll
[29816b5a] LibSSH2_jll
[c8ffd9c3] MbedTLS_jll
[14a3606d] MozillaCACerts_jll
[05823500] OpenLibm_jll
[efcefdf7] PCRE2_jll
[83775a58] Zlib_jll
[8e850ede] nghttp2_jll
[3f19e933] p7zip_jll