Linear Regression

Turing is powerful when applied to complex hierarchical models, but it can also be put to work on common statistical procedures, like linear regression. This tutorial covers how to implement a linear regression model in Turing.

Set Up

We begin by importing all the necessary libraries.

# Import Turing.
using Turing

# Package for loading the data set.
using RDatasets

# Package for visualization.
using StatsPlots

# Functionality for splitting the data.
using MLUtils: splitobs

# Functionality for constructing arrays with identical elements efficiently.
using FillArrays

# Functionality for normalizing the data and evaluating the model predictions.
using StatsBase

# Functionality for working with scaled identity matrices.
using LinearAlgebra

# Set a seed for reproducibility.
using Random
Random.seed!(0);

# Hide the progress meter during sampling.
setprogress!(false)

We will use the mtcars dataset from the RDatasets package. mtcars contains a variety of statistics on different car models, including miles per gallon, number of cylinders, and horsepower.

We want to know if we can construct a Bayesian linear regression model to predict the miles per gallon of a car, given the other statistics it has. Let us take a look at the data we have.

# Load the dataset.
data = RDatasets.dataset("datasets", "mtcars")

# Show the first six rows of the dataset.
first(data, 6)
6×12 DataFrame
 Row │ Model              MPG      Cyl    Disp     HP     DRat     WT       QSec     VS     AM     Gear   Carb
     │ String31           Float64  Int64  Float64  Int64  Float64  Float64  Float64  Int64  Int64  Int64  Int64
─────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ Mazda RX4             21.0      6    160.0    110     3.9     2.62     16.46      0      1      4      4
   2 │ Mazda RX4 Wag         21.0      6    160.0    110     3.9     2.875    17.02      0      1      4      4
   3 │ Datsun 710            22.8      4    108.0     93     3.85    2.32     18.61      1      1      4      1
   4 │ Hornet 4 Drive        21.4      6    258.0    110     3.08    3.215    19.44      1      0      3      1
   5 │ Hornet Sportabout     18.7      8    360.0    175     3.15    3.44     17.02      0      0      3      2
   6 │ Valiant               18.1      6    225.0    105     2.76    3.46     20.22      1      0      3      1
size(data)
(32, 12)

The next step is to get our data ready for testing. We’ll split the mtcars dataset into two subsets, one for training our model and one for evaluating it. Then we separate out the target we want to learn (MPG, in this case) and standardize both the features and the targets by subtracting each column’s mean and dividing by that column’s standard deviation. The resulting data are less recognizable at a glance, but this standardization helps the sampler converge far more easily.

# Remove the model column.
select!(data, Not(:Model))

# Split our dataset 70%/30% into training/test sets.
trainset, testset = map(DataFrame, splitobs(data; at=0.7, shuffle=true))

# Turing requires data in matrix form.
target = :MPG
train = Matrix(select(trainset, Not(target)))
test = Matrix(select(testset, Not(target)))
train_target = trainset[:, target]
test_target = testset[:, target]

# Standardize the features.
dt_features = fit(ZScoreTransform, train; dims=1)
StatsBase.transform!(dt_features, train)
StatsBase.transform!(dt_features, test)

# Standardize the targets.
dt_targets = fit(ZScoreTransform, train_target)
StatsBase.transform!(dt_targets, train_target)
StatsBase.transform!(dt_targets, test_target);
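
Because the targets are standardized too, anything the model predicts will be on the standardized scale; keeping the fitted transforms around lets us map such values back to miles per gallon later. A minimal sketch using StatsBase.reconstruct (the preds vector here is just a stand-in for standardized predictions):

# Stand-in for standardized predictions; in practice these come from the model.
preds = train_target[1:3]

# Map standardized values back onto the original MPG scale.
StatsBase.reconstruct(dt_targets, preds)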

Model Specification

In a traditional frequentist model using OLS, our model might look like:

\[ \mathrm{MPG}_i = \alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i} \]

where \(\boldsymbol{\beta}\) is a vector of coefficients and \(\boldsymbol{X_i}\) is the vector of inputs for observation \(i\). The Bayesian model we are more concerned with is the following:

\[ \mathrm{MPG}_i \sim \mathcal{N}(\alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i}, \sigma^2) \]

where \(\alpha\) is an intercept term common to all observations, \(\boldsymbol{\beta}\) is a coefficient vector, \(\boldsymbol{X_i}\) is the observed data for car \(i\), and \(\sigma^2\) is a common variance term.

For \(\sigma^2\), we assign a prior of truncated(Normal(0, 100); lower=0). This is consistent with Andrew Gelman’s recommendations on noninformative priors for variance. The intercept term (\(\alpha\)) is assumed to be normally distributed with a mean of zero and a variance of three. This represents our assumption that miles per gallon can be explained mostly by our assorted variables, but a high variance term indicates our uncertainty about that. Each coefficient is assumed to be normally distributed with a mean of zero and a variance of 10. We do not know that our coefficients are different from zero, and we don’t know which ones are likely to be the most important, so the variance term is quite high. Lastly, each observation \(y_i\) is distributed according to the calculated mu term given by \(\alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i}\), with the common variance \(\sigma^2\).
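
Written out with the same mean/variance convention as the likelihood above (note that Normal(0, 100) in the code below is parameterized by the standard deviation, so its variance is \(100^2\)), the full model is:

\[
\begin{aligned}
\sigma^2 &\sim \mathcal{N}^{+}(0, 100^2), \\
\alpha &\sim \mathcal{N}(0, 3), \\
\beta_j &\sim \mathcal{N}(0, 10), \quad j = 1, \dots, 10, \\
\mathrm{MPG}_i &\sim \mathcal{N}(\alpha + \boldsymbol{\beta}^\mathsf{T}\boldsymbol{X_i}, \sigma^2),
\end{aligned}
\]

where \(\mathcal{N}^{+}\) denotes a normal distribution truncated to the nonnegative reals.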

# Bayesian linear regression.
@model function linear_regression(x, y)
    # Set variance prior.
    σ² ~ truncated(Normal(0, 100); lower=0)

    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))

    # Set the priors on our coefficients.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(Zeros(nfeatures), 10.0 * I)

    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    return y ~ MvNormal(mu, σ² * I)
end
linear_regression (generic function with 2 methods)
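
Before conditioning on the data, it can be worthwhile to check what the priors alone imply; a minimal sketch using Turing's Prior sampler (the chain_prior name is just illustrative):

# Sample the parameters from their priors (the likelihood is not used).
chain_prior = sample(linear_regression(train, train_target), Prior(), 1_000)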

With our model specified, we can call the sampler. We will use the No-U-Turn Sampler (NUTS) here.

model = linear_regression(train, train_target)
chain = sample(model, NUTS(), 5_000)
┌ Info: Found initial step size
└   ϵ = 0.4
Chains MCMC chain (5000×24×1 Array{Float64, 3}):

Iterations        = 1001:1:6000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 10.71 seconds
Compute duration  = 10.71 seconds
parameters        = σ², intercept, coefficients[1], coefficients[2], coefficients[3], coefficients[4], coefficients[5], coefficients[6], coefficients[7], coefficients[8], coefficients[9], coefficients[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
        parameters      mean       std      mcse    ess_bulk    ess_tail       ⋯
            Symbol   Float64   Float64   Float64     Float64     Float64   Flo ⋯

                σ²    0.4654    0.2830    0.0072   1182.3902    800.1722    1. ⋯
         intercept    0.0006    0.1409    0.0020   5297.3398   2635.8792    1. ⋯
   coefficients[1]   -0.5122    0.5753    0.0099   3378.2795   2872.2853    1. ⋯
   coefficients[2]    0.3519    0.7824    0.0171   2121.2881   2464.2577    1. ⋯
   coefficients[3]   -0.4526    0.4888    0.0083   3473.4131   3529.7692    1. ⋯
   coefficients[4]    0.0052    0.2693    0.0041   4341.3733   3281.4084    1. ⋯
   coefficients[5]   -0.2071    0.5516    0.0125   1972.6773   2293.0393    1. ⋯
   coefficients[6]   -0.0493    0.4620    0.0093   2496.7348   2777.7802    1. ⋯
   coefficients[7]   -0.0974    0.4281    0.0089   2333.8838   2930.1320    1. ⋯
   coefficients[8]    0.1086    0.3920    0.0084   2173.6208   2918.4830    1. ⋯
   coefficients[9]    0.2228    0.3676    0.0076   2336.5074   2941.0944    1. ⋯
  coefficients[10]   -0.1387    0.4837    0.0116   1758.8444   2153.1698    1. ⋯
                                                               2 columns omitted

Quantiles
        parameters      2.5%     25.0%     50.0%     75.0%     97.5%
            Symbol   Float64   Float64   Float64   Float64   Float64

                σ²    0.1799    0.2947    0.3945    0.5513    1.1675
         intercept   -0.2771   -0.0881   -0.0001    0.0888    0.2790
   coefficients[1]   -1.6511   -0.8726   -0.5026   -0.1500    0.6035
   coefficients[2]   -1.2743   -0.1287    0.3778    0.8433    1.8649
   coefficients[3]   -1.4243   -0.7659   -0.4650   -0.1265    0.5074
   coefficients[4]   -0.5416   -0.1583    0.0089    0.1727    0.5439
   coefficients[5]   -1.2737   -0.5533   -0.2143    0.1251    0.9322
   coefficients[6]   -0.9572   -0.3353   -0.0483    0.2399    0.8837
   coefficients[7]   -0.9548   -0.3699   -0.0932    0.1709    0.7449
   coefficients[8]   -0.6537   -0.1366    0.1035    0.3558    0.8818
   coefficients[9]   -0.4944   -0.0143    0.2181    0.4562    0.9522
  coefficients[10]   -1.1282   -0.4366   -0.1254    0.1602    0.7969
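
The diagnostics above come from a single chain; running several chains makes the effective sample size and R̂ columns more informative. A sketch of how this might look with Turing's multithreaded ensemble sampling (assuming Julia was started with multiple threads):

# Sample four chains in parallel using threads.
chains = sample(model, NUTS(), MCMCThreads(), 5_000, 4)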

We can also check the densities and traces of the parameters visually using the plot functionality.

plot(chain)
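
If the combined plot is hard to read with this many parameters, one can also plot a subset; a small sketch using the group function from MCMCChains (add using MCMCChains if it is not already in scope through Turing):

# Plot only the densities and traces of the regression coefficients.
plot(group(chain, :coefficients))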