In this tutorial, we demonstrate how to implement a Bayesian neural network using a combination of Turing and Lux, a deep learning framework for building neural networks. We will use Lux to specify the network's layers and Turing to perform the probabilistic inference, with the goal of building a classification algorithm.
We will begin with importing the relevant libraries.
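A minimal set of imports sufficient for the code in this tutorial is shown below; the original notebook may load additional packages, such as an automatic-differentiation backend.

# Import the libraries used in this tutorial.
using Turing                # probabilistic modelling and inference
using Lux                   # neural network layers (Dense, Chain, StatefulLuxLayer)
using Functors              # fmap, used to rebuild structured parameters
using LinearAlgebra         # Diagonal
using Random                # reproducible random number generation
using Statistics            # mean
using Plots                 # scatter/contour plots and animations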
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we will be working with.
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points.
function plot_data()
    x1 = map(e -> e[1], xt1s)
    y1 = map(e -> e[2], xt1s)
    x2 = map(e -> e[1], xt0s)
    y2 = map(e -> e[2], xt0s)

    Plots.scatter(x1, y1; color="red", clim=(0, 1))
    return Plots.scatter!(x2, y2; color="blue", clim=(0, 1))
end

plot_data()
Building a Neural Network
The next step is to define a feedforward neural network in which we express our parameters as distributions rather than single point estimates, as in a traditional neural network. For this we use Dense to define linear layers and compose them via Chain, both of which are neural-network primitives from Lux. The network nn_initial we create has two hidden layers with tanh activations and one output layer with a sigmoid (σ) activation, as shown below.
nn_initial acts as a function: it takes data as input and outputs predictions. We will define distributions over the neural network's parameters.
# Construct a neural network using Lux
nn_initial = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, σ))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn_initial)

Lux.parameterlength(nn_initial) # number of parameters in NN
20
The probabilistic model specification below creates a parameters variable whose entries are IID normal random variables. The parameters vector represents all parameters of our neural net (weights and biases).
# Create a regularization term and the corresponding Gaussian prior standard deviation.
alpha = 0.09
sigma = sqrt(1.0 / alpha)
3.3333333333333335
We also define a function to construct a named tuple from a vector of sampled parameters. (We could use ComponentArrays here and broadcast to avoid doing this, but this way avoids introducing an extra dependency.)
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
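As a quick illustration (the values here are arbitrary and purely hypothetical), a flat vector of the correct length can be reshaped back into the layer-wise structure of ps:

# Illustration with arbitrary values: rebuild the layer-wise structure from a flat vector.
θ_demo = randn(rng, Float32, Lux.parameterlength(nn_initial))   # 20 values
ps_demo = vector_to_parameters(θ_demo, ps)
size(ps_demo.layer_1.weight)   # (3, 2), the weight matrix of Dense(2 => 3)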
To interface with external libraries, it is often convenient to wrap the network in a StatefulLuxLayer, which handles the neural network's state automatically.
const nn = StatefulLuxLayer{true}(nn_initial, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts; sigma=sigma, ps=ps, nn=nn)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn_initial)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sigma .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(nn, xs, f32(vector_to_parameters(parameters, ps)))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
Inference can now be performed by calling sample. We use the NUTS Hamiltonian Monte Carlo sampler here.
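A sketch of this step is shown below, with the inputs stacked into a matrix and the chain stored in ch, run for 5,000 iterations to match the sample matrix described next; the exact sampler configuration (for example, the automatic-differentiation backend) is an assumption.

# Perform inference: a sketch; the sampler configuration is an assumption.
n_iters = 5_000
ch = sample(bayes_nn(reduce(hcat, xs), ts), NUTS(), n_iters);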
Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20 where 5000 is the number of iterations and 20 is the number of parameters). We’ll use these primarily to determine how good our model’s classifier is.
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
Prediction Visualization
We can use MAP estimation to classify our population by using the set of weights that provided the highest log posterior.
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = nn(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the max row value from i.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
The contour plot above shows that the MAP method is not too bad at classifying our data.
The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.
# Return the average predicted value across multiple weights.
function nn_predict(x, θ, num)
    num = min(num, size(θ, 1))  # make sure num does not exceed the number of samples
    return mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
end
nn_predict (generic function with 1 method)
Next, we use the nn_predict function to predict values at a grid of points whose x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and, more importantly, we can now see much more easily where the neural network is uncertain about its predictions: the regions between the cluster boundaries.
# Plot the average prediction.
fig = plot_data()

n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
Suppose we are interested in how the predictive power of our Bayesian neural network evolved as sampling progressed. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 500, plotting every fifth sample.
# Number of iterations to plot.
n_end = 500

anim = @gif for i in 1:n_end
    plot_data()
    Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
    contour!(x1_range, x2_range, Z; title="Iteration $i", clim=(0, 1))
end every 5
[ Info: Saved animation to /tmp/jl_iZ4nm9gfA3.gif
This has been an introduction to using Turing and Lux to define and perform inference over Bayesian neural networks.