Turing is powerful when applied to complex hierarchical models, but it can also be put to work on common statistical procedures like linear regression. This tutorial covers how to implement a linear regression model in Turing.
Set Up
We begin by importing all the necessary libraries.
```julia
# Import Turing.
using Turing

# Package for loading the data set.
using RDatasets

# Package for visualization.
using StatsPlots

# Functionality for splitting the data.
using MLUtils: splitobs

# Functionality for constructing arrays with identical elements efficiently.
using FillArrays

# Functionality for normalizing the data and evaluating the model predictions.
using StatsBase

# Functionality for working with scaled identity matrices.
using LinearAlgebra

# Set a seed for reproducibility.
using Random
Random.seed!(0);
```
To keep the output tidy, we also disable Turing’s progress bars:

```julia
setprogress!(false)
```
We will use the mtcars dataset from the RDatasets package. mtcars contains a variety of statistics on different car models, such as their miles per gallon, number of cylinders, and horsepower.
We want to know if we can construct a Bayesian linear regression model to predict the miles per gallon of a car, given the other statistics it has. Let us take a look at the data we have.
```julia
# Load the dataset.
data = RDatasets.dataset("datasets", "mtcars")

# Show the first six rows of the dataset.
first(data, 6)
```
```
6×12 DataFrame
 Row │ Model              MPG      Cyl    Disp     HP     DRat     WT       QSec     VS     AM     Gear   Carb
     │ String31           Float64  Int64  Float64  Int64  Float64  Float64  Float64  Int64  Int64  Int64  Int64
─────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ Mazda RX4             21.0      6    160.0    110     3.9     2.62     16.46      0      1      4      4
   2 │ Mazda RX4 Wag         21.0      6    160.0    110     3.9     2.875    17.02      0      1      4      4
   3 │ Datsun 710            22.8      4    108.0     93     3.85    2.32     18.61      1      1      4      1
   4 │ Hornet 4 Drive        21.4      6    258.0    110     3.08    3.215    19.44      1      0      3      1
   5 │ Hornet Sportabout     18.7      8    360.0    175     3.15    3.44     17.02      0      0      3      2
   6 │ Valiant               18.1      6    225.0    105     2.76    3.46     20.22      1      0      3      1
```
```julia
size(data)
```

```
(32, 12)
```
The next step is to get our data ready for testing. We’ll split the mtcars dataset into two subsets, one for training our model and one for evaluating it. Then, we separate out the target we want to learn (MPG, in this case) and standardize both features and targets by subtracting each column’s mean and dividing by that column’s standard deviation. The resulting data is not very familiar looking, but this standardization helps the sampler converge much more easily.
```julia
# Remove the model column.
select!(data, Not(:Model))

# Split our dataset 70%/30% into training/test sets.
trainset, testset = map(DataFrame, splitobs(data; at=0.7, shuffle=true))

# Turing requires data in matrix form.
target = :MPG
train = Matrix(select(trainset, Not(target)))
test = Matrix(select(testset, Not(target)))
train_target = trainset[:, target]
test_target = testset[:, target]

# Standardize the features.
dt_features = fit(ZScoreTransform, train; dims=1)
StatsBase.transform!(dt_features, train)
StatsBase.transform!(dt_features, test)

# Standardize the targets.
dt_targets = fit(ZScoreTransform, train_target)
StatsBase.transform!(dt_targets, train_target)
StatsBase.transform!(dt_targets, test_target);
```
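As a quick sanity check (this snippet is our own addition, not part of the standard workflow), the standardized training features should now have column means of approximately zero and standard deviations of approximately one:

```julia
# After z-scoring, each feature column is (approximately) zero-mean, unit-variance.
println(round.(vec(mean(train; dims=1)); digits=4))
println(round.(vec(std(train; dims=1)); digits=4))
```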
Model Specification
In a traditional frequentist model using OLS, our model might look like:

\[
\mathrm{MPG}_i = \alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i}
\]

where \(\boldsymbol{\beta}\) is a vector of coefficients and \(\boldsymbol{X_i}\) is the vector of inputs for observation \(i\). The Bayesian model we are more concerned with is the following:

\[
\mathrm{MPG}_i \sim \mathcal{N}(\alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i}, \sigma^2)
\]

where \(\alpha\) is an intercept term common to all observations, \(\boldsymbol{\beta}\) is a coefficient vector, \(\boldsymbol{X_i}\) is the observed data for car \(i\), and \(\sigma^2\) is a common variance term.
For \(\sigma^2\), we assign a prior of truncated(Normal(0, 100); lower=0). This is consistent with Andrew Gelman’s recommendations on noninformative priors for variance. The intercept term (\(\alpha\)) is assumed to be normally distributed with a mean of zero and a variance of three. This represents our assumption that miles per gallon can be explained mostly by our assorted variables, while the large variance captures our uncertainty about that. Each coefficient is assumed to be normally distributed with a mean of zero and a variance of 10. We do not know whether our coefficients differ from zero, nor which ones are likely to be the most important, so the variance term is quite high. Lastly, each observation \(y_i\) is distributed according to the calculated mean term given by \(\alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i}\).
```julia
# Bayesian linear regression.
@model function linear_regression(x, y)
    # Set variance prior.
    σ² ~ truncated(Normal(0, 100); lower=0)

    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))

    # Set the priors on our coefficients.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(Zeros(nfeatures), 10.0 * I)

    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    return y ~ MvNormal(mu, σ² * I)
end
```
```
linear_regression (generic function with 2 methods)
```
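Before conditioning on the data, it can be worth checking what the priors alone imply. As a quick sketch (an optional aside, not part of the original workflow), Turing’s Prior sampler draws from the model while ignoring the observed data:

```julia
# Draw from the priors only; the likelihood is not evaluated against the data.
prior_chain = sample(linear_regression(train, train_target), Prior(), 1_000)

# The prior mean of the intercept should be near zero by construction.
mean(prior_chain[:intercept])
```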
With our model specified, we can call the sampler. We will use the No-U-Turn Sampler (NUTS) here.
```julia
model = linear_regression(train, train_target)
chain = sample(model, NUTS(), 5_000)
```
We can also check the densities and traces of the parameters visually using the plot functionality.
```julia
plot(chain)
```
It looks like all parameters have converged.
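Trace plots are a subjective check. For a numeric one, the summary statistics provided by MCMCChains (loaded with Turing) include the \(\hat{R}\) diagnostic and effective sample sizes; \(\hat{R}\) values close to 1 suggest the chains have mixed well:

```julia
# rhat values near 1 and reasonably large ess values support convergence.
summarystats(chain)
```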
Comparing to OLS
A satisfactory test of our model is to evaluate how well it predicts. Importantly, we want to compare our model to existing tools like OLS. The code below uses the GLM.jl package to generate a traditional OLS multiple regression model on the same data as our probabilistic model.
```julia
# Import the GLM package.
using GLM

# Perform multiple regression OLS.
train_with_intercept = hcat(ones(size(train, 1)), train)
ols = lm(train_with_intercept, train_target)

# Compute predictions on the training data set and unstandardize them.
train_prediction_ols = GLM.predict(ols)
StatsBase.reconstruct!(dt_targets, train_prediction_ols)

# Compute predictions on the test data set and unstandardize them.
test_with_intercept = hcat(ones(size(test, 1)), test)
test_prediction_ols = GLM.predict(ols, test_with_intercept)
StatsBase.reconstruct!(dt_targets, test_prediction_ols);
```
The function below accepts a chain and an input matrix and calculates predictions. We use only the samples of the model parameters from sample 200 onwards, treating the earlier draws as warm-up.
```julia
# Make a prediction given an input matrix.
function prediction(chain, x)
    p = get_params(chain[200:end, :, :])
    targets = p.intercept' .+ x * reduce(hcat, p.coefficients)'
    return vec(mean(targets; dims=2))
end
```
```
prediction (generic function with 1 method)
```
When we make predictions, we unstandardize them so they are back on the original MPG scale and easier to interpret.
```julia
# Calculate the predictions for the training and testing sets and unstandardize them.
train_prediction_bayes = prediction(chain, train)
StatsBase.reconstruct!(dt_targets, train_prediction_bayes)
test_prediction_bayes = prediction(chain, test)
StatsBase.reconstruct!(dt_targets, test_prediction_bayes)

# Show the predictions on the test data set.
DataFrame(; MPG=testset[!, target], Bayes=test_prediction_bayes, OLS=test_prediction_ols)
```
```
10×3 DataFrame
 Row │ MPG      Bayes    OLS
     │ Float64  Float64  Float64
─────┼───────────────────────────
   1 │    33.9  26.7675  26.804
   2 │    21.0  22.3414  22.4669
   3 │    21.4  20.4986  20.5666
   4 │    26.0  28.8645  28.9169
   5 │    15.0  11.6032  11.584
   6 │    10.4  13.6112  13.7006
   7 │    30.4  27.3355  27.4661
   8 │    10.4  14.3843  14.5346
   9 │    18.7  17.211   17.2897
  10 │    17.3  14.6529  14.6084
```
Now let’s evaluate the loss for each method and each prediction set. We will use the mean squared error (MSE) to evaluate loss, given by

\[
\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2,
\]

where \(y_i\) is the actual value (true MPG) and \(\hat{y}_i\) is the predicted value using either OLS or Bayesian linear regression. A lower MSE indicates a closer fit to the data.
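Below is one way to compute these losses; note that the mse helper is our own small definition, not a package function:

```julia
# Mean squared error between predictions and the true (unstandardized) targets.
mse(ŷ, y) = mean(abs2, ŷ .- y)

println("Training set:")
println("  Bayes loss: ", mse(train_prediction_bayes, trainset[!, target]))
println("  OLS loss:   ", mse(train_prediction_ols, trainset[!, target]))
println("Test set:")
println("  Bayes loss: ", mse(test_prediction_bayes, testset[!, target]))
println("  OLS loss:   ", mse(test_prediction_ols, testset[!, target]))
```

If the Bayesian test loss is comparable to the OLS loss, our probabilistic model predicts about as well as the frequentist baseline while also providing posterior uncertainty.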