Bayesian Multinomial Logistic Regression

Multinomial logistic regression is an extension of logistic regression. Logistic regression models problems with exactly two possible discrete outcomes, whereas multinomial logistic regression models problems with two or more possible discrete outcomes.
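To see why the multinomial model generalizes the binary one, here is a small plain-Julia sketch. The helper names softmax_sketch and logistic are our own (the tutorial itself uses NNlib's softmax). When one class is fixed as a baseline whose linear predictor is 0, a two-class softmax gives exactly the logistic (sigmoid) probability used in binary logistic regression.

```julia
# Numerically stable softmax: subtracting the maximum does not change the result.
function softmax_sketch(η::AbstractVector{<:Real})
    e = exp.(η .- maximum(η))
    return e ./ sum(e)
end

# Logistic (sigmoid) function used by binary logistic regression.
logistic(t::Real) = 1 / (1 + exp(-t))

η = 0.7                                   # arbitrary linear predictor of the non-baseline class
p_two_class = softmax_sketch([0.0, η])    # baseline class has linear predictor 0
println(p_two_class[2] ≈ logistic(η))     # true: two-class softmax equals the logistic function

# With three classes the probabilities still sum to one.
p_three_class = softmax_sketch([0.0, 0.7, -1.2])
println(sum(p_three_class) ≈ 1)           # true
```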

In our example, we’ll be using the iris dataset. The iris multiclass problem aims to predict the species of a flower given measurements (in centimeters) of sepal length and width and petal length and width. There are three possible species: Iris setosa, Iris versicolor, and Iris virginica.

To start, let’s import all the libraries we’ll need.

# Load Turing.
using Turing

# Load RDatasets.
using RDatasets

# Load StatsPlots for visualizations and diagnostics.
using StatsPlots

# Functionality for splitting and normalizing the data.
using MLDataUtils: shuffleobs, splitobs, rescale!

# We need a softmax function which is provided by NNlib.
using NNlib: softmax

# Functionality for constructing arrays with identical elements efficiently.
using FillArrays

# Functionality for working with scaled identity matrices.
using LinearAlgebra

# Set a seed for reproducibility.
using Random
Random.seed!(0);

Data Cleaning & Setup

Now we’re going to import our dataset. Twenty randomly chosen rows of the dataset are shown below so you can get a feel for what kind of data we have.

# Import the "iris" dataset.
data = RDatasets.dataset("datasets", "iris");

# Show twenty random rows.
data[rand(1:size(data, 1), 20), :]
20×5 DataFrame
Row SepalLength SepalWidth PetalLength PetalWidth Species
Float64 Float64 Float64 Float64 Cat…
1 6.7 3.0 5.0 1.7 versicolor
2 7.2 3.0 5.8 1.6 virginica
3 5.7 3.8 1.7 0.3 setosa
4 5.1 3.5 1.4 0.2 setosa
5 6.4 3.2 4.5 1.5 versicolor
6 5.8 2.7 4.1 1.0 versicolor
7 7.7 3.0 6.1 2.3 virginica
8 4.9 3.0 1.4 0.2 setosa
9 5.8 2.6 4.0 1.2 versicolor
10 5.9 3.0 4.2 1.5 versicolor
11 5.2 2.7 3.9 1.4 versicolor
12 5.4 3.9 1.7 0.4 setosa
13 5.3 3.7 1.5 0.2 setosa
14 5.0 3.0 1.6 0.2 setosa
15 4.4 3.2 1.3 0.2 setosa
16 4.8 3.0 1.4 0.3 setosa
17 5.7 2.5 5.0 2.0 virginica
18 5.5 2.4 3.7 1.0 versicolor
19 4.9 2.4 3.3 1.0 versicolor
20 5.5 2.6 4.4 1.2 versicolor

In this dataset, the outcome Species is stored as a categorical variable with string labels. We convert it to a numerical value by using the indices 1, 2, and 3 to indicate the species setosa, versicolor, and virginica, respectively.

# Recode the `Species` column.
species = ["setosa", "versicolor", "virginica"]
data[!, :Species_index] = indexin(data[!, :Species], species)

# Show twenty random rows of the new species columns
data[rand(1:size(data, 1), 20), [:Species, :Species_index]]
20×2 DataFrame
Row Species Species_index
Cat… Union…
1 versicolor 2
2 versicolor 2
3 virginica 3
4 versicolor 2
5 versicolor 2
6 setosa 1
7 virginica 3
8 virginica 3
9 virginica 3
10 setosa 1
11 virginica 3
12 setosa 1
13 versicolor 2
14 versicolor 2
15 versicolor 2
16 versicolor 2
17 setosa 1
18 versicolor 2
19 versicolor 2
20 versicolor 2
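The Union… column type in the output above comes from indexin, which returns nothing for any value that is not found in the reference list, so its result has element type Union{Nothing, Int}. The stdlib-only sketch below uses a plain String vector as a stand-in for the categorical Species column (an assumption for illustration). It shows this behavior and one way to convert the result to a plain integer vector after checking that every label was matched.

```julia
species = ["setosa", "versicolor", "virginica"]
labels = ["virginica", "setosa", "versicolor", "setosa"]  # stand-in for data[!, :Species]

idx = indexin(labels, species)   # element type is Union{Nothing, Int}
println(eltype(idx))

# indexin yields `nothing` for labels missing from `species`, so check before converting.
@assert all(!isnothing, idx) "found a species label that is not listed in `species`"
idx_int = Int.(idx)              # plain Vector{Int}
println(idx_int)
```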

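After we’ve done that tidying, it’s time to split our dataset into training and test sets and to separate the features and target from the data. Additionally, we must standardize our feature variables so that each is centered at zero with unit standard deviation: from each column we subtract its mean and then divide by its standard deviation. These statistics are computed on the training set only and then reused for the test set, so no information from the test data influences the model. Without this step, Turing’s sampler can have a hard time finding a good region in which to start searching for parameter estimates. Standardization also puts all coefficients on a comparable scale, so a single prior standard deviation σ is reasonable for all of them.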

# Split our dataset 50%/50% into training/test sets.
trainset, testset = splitobs(shuffleobs(data), 0.5)

# Define features and target.
features = [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]
target = :Species_index

# The model below expects a feature matrix and a target vector.
train_features = Matrix(trainset[!, features])
test_features = Matrix(testset[!, features])
train_target = trainset[!, target]
test_target = testset[!, target]

# Standardize the features.
μ, σ = rescale!(train_features; obsdim=1)
rescale!(test_features, μ, σ; obsdim=1);
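For readers who want to see roughly what shuffleobs, splitobs, and rescale! do here, the following is a plain-Julia approximation that uses only the standard library. It is a sketch, not MLDataUtils' implementation: the helper name split_and_standardize and the random toy matrix are hypothetical, and it assumes that observations are rows (matching obsdim=1).

```julia
using Random, Statistics

# Hypothetical helper: shuffle rows, split them, and standardize with training statistics only.
function split_and_standardize(X::AbstractMatrix{<:Real}, y::AbstractVector, frac::Real;
                               rng::AbstractRNG=Random.default_rng())
    n = size(X, 1)
    perm = randperm(rng, n)                  # shuffle the observation (row) order
    ntrain = floor(Int, frac * n)
    train, test = perm[1:ntrain], perm[ntrain+1:end]

    Xtrain, Xtest = float(X[train, :]), float(X[test, :])
    μ = mean(Xtrain; dims=1)                 # 1×p column means of the training rows
    σ = std(Xtrain; dims=1)                  # 1×p column standard deviations
    Xtrain .= (Xtrain .- μ) ./ σ
    Xtest .= (Xtest .- μ) ./ σ               # reuse the training μ and σ for the test set
    return Xtrain, Xtest, y[train], y[test], μ, σ
end

# Toy data: 150 observations of 4 features (random numbers, not the iris data).
rng = Xoshiro(1)
X = 5 .+ randn(rng, 150, 4)
y = rand(rng, 1:3, 150)
Xtr, Xte, ytr, yte, μ, σ = split_and_standardize(X, y, 0.5; rng=rng)
```

Note that the standardized test columns will have means and standard deviations close to, but not exactly, 0 and 1, because they were scaled with training statistics.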

Model Declaration

Finally, we can define our model, logistic_regression. It is a function that takes three arguments, where

  • x is our matrix of independent variables (one row per observation);
  • y is the vector of species indices we want to predict;
  • σ is the standard deviation we assume for our priors.

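We select the setosa species as the baseline class. The choice is largely a matter of convention and mainly changes how the coefficients are interpreted. A baseline is needed because the softmax function returns the same probabilities when the same constant is added to every class's linear predictor, so without one the likelihood could not distinguish parameter values that differ only by such a shift. Then we create the intercepts and vectors of coefficients for the other classes against that baseline. More concretely, we create scalar intercepts intercept_versicolor and intercept_virginica and coefficient vectors coefficients_versicolor and coefficients_virginica with four coefficients each, one for each of the features SepalLength, SepalWidth, PetalLength, and PetalWidth. As the prior for each scalar parameter, including each individual coefficient, we assume a normal distribution with mean zero and standard deviation σ. We want to find the posterior distribution of these ten parameters in total, so that we can predict the species for any given set of features.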
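Written out as equations, the model described above is the following. Here α_k and β_k are shorthand we introduce for the intercept and coefficient vector of class k (intercept_versicolor, coefficients_versicolor, and so on in the code), and x_i is the vector of four standardized features of observation i. The last line shows why the coefficients are read relative to the baseline: each non-baseline linear predictor equals the log-odds of that class versus setosa.

```latex
\begin{aligned}
\alpha_k &\sim \mathcal{N}(0, \sigma^2), \qquad \beta_k \sim \mathcal{N}(0, \sigma^2 I_4), \qquad k \in \{\text{versicolor}, \text{virginica}\},\\
\eta_{i,\text{setosa}} &= 0, \qquad \eta_{i,k} = \alpha_k + x_i^\top \beta_k,\\
P(y_i = k \mid x_i) &= \frac{\exp(\eta_{i,k})}{\sum_{j} \exp(\eta_{i,j})}, \qquad j, k \in \{\text{setosa}, \text{versicolor}, \text{virginica}\},\\
\log \frac{P(y_i = k \mid x_i)}{P(y_i = \text{setosa} \mid x_i)} &= \eta_{i,k} - \eta_{i,\text{setosa}} = \alpha_k + x_i^\top \beta_k, \qquad k \in \{\text{versicolor}, \text{virginica}\}.
\end{aligned}
```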

# Bayesian multinomial logistic regression
@model function logistic_regression(x, y, σ)
    n = size(x, 1)
    length(y) == n ||
        throw(DimensionMismatch("number of observations in `x` and `y` is not equal"))

    # Priors of intercepts and coefficients.
    intercept_versicolor ~ Normal(0, σ)
    intercept_virginica ~ Normal(0, σ)
    coefficients_versicolor ~ MvNormal(Zeros(4), σ^2 * I)
    coefficients_virginica ~ MvNormal(Zeros(4), σ^2 * I)

    # Compute the likelihood of the observations.
    values_versicolor = intercept_versicolor .+ x * coefficients_versicolor
    values_virginica = intercept_virginica .+ x * coefficients_virginica
    for i in 1:n
        # the 0 corresponds to the base category `setosa`
        v = softmax([0, values_versicolor[i], values_virginica[i]])
        y[i] ~ Categorical(v)
    end
end;
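To make the likelihood computed inside the loop concrete, the sketch below evaluates the same log-likelihood without Turing, for fixed, made-up parameter values and a tiny hand-written feature matrix. The function name loglik_sketch is hypothetical. It builds an n×3 matrix of linear predictors with a zero column for the setosa baseline and uses a numerically stable log-sum-exp, which gives the same value as log(softmax(...)[y[i]]).

```julia
# Hypothetical helper: log-likelihood of the multinomial logistic model for fixed parameters.
function loglik_sketch(x::AbstractMatrix, y::AbstractVector,
                       α_vc::Real, α_vi::Real, β_vc::AbstractVector, β_vi::AbstractVector)
    n = size(x, 1)
    # Linear predictors: column 1 is the setosa baseline, fixed at 0.
    η = hcat(zeros(n), α_vc .+ x * β_vc, α_vi .+ x * β_vi)
    total = 0.0
    for i in 1:n
        m = maximum(η[i, :])
        logZ = m + log(sum(exp.(η[i, :] .- m)))   # log-sum-exp for numerical stability
        total += η[i, y[i]] - logZ                # log P(y[i] | x[i, :])
    end
    return total
end

# Made-up standardized features for three observations and their class indices.
x = [-1.0  0.8 -1.2 -1.1;
      0.2 -0.5  0.3  0.1;
      0.9 -0.1  1.0  1.2]
y = [1, 2, 3]
β_vc = [1.0, -1.4, 1.0, 0.3]   # arbitrary coefficients, not fitted values
β_vi = [1.1, -0.7, 2.0, 2.6]
ll = loglik_sketch(x, y, 0.6, -0.6, β_vc, β_vi)
println(ll)
```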

Sampling

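Now we can run our sampler. This time we’ll use NUTS (the No-U-Turn Sampler, an adaptive variant of Hamiltonian Monte Carlo) to sample from our posterior. Below, the prior standard deviation σ is set to 1, and sample() draws 1,500 samples from each of three chains, which MCMCThreads() runs on separate threads.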

setprogress!(false)
m = logistic_regression(train_features, train_target, 1)
chain = sample(m, NUTS(), MCMCThreads(), 1_500, 3)
Chains MCMC chain (1500×22×3 Array{Float64, 3}):

Iterations        = 751:1:2250
Number of chains  = 3
Samples per chain = 1500
Wall duration     = 19.7 seconds
Compute duration  = 16.74 seconds
parameters        = intercept_versicolor, intercept_virginica, coefficients_versicolor[1], coefficients_versicolor[2], coefficients_versicolor[3], coefficients_versicolor[4], coefficients_virginica[1], coefficients_virginica[2], coefficients_virginica[3], coefficients_virginica[4]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk    ess_ ⋯
                      Symbol   Float64   Float64   Float64     Float64     Flo ⋯

        intercept_versicolor    0.5970    0.5311    0.0080   4433.8339   2971. ⋯
         intercept_virginica   -0.6539    0.6899    0.0103   4517.7187   3245. ⋯
  coefficients_versicolor[1]    1.0183    0.6966    0.0103   4554.8119   3389. ⋯
  coefficients_versicolor[2]   -1.4152    0.5460    0.0088   3900.7382   2384. ⋯
  coefficients_versicolor[3]    1.0572    0.7892    0.0124   4140.7304   3289. ⋯
  coefficients_versicolor[4]    0.3482    0.7436    0.0118   3975.1399   2914. ⋯
   coefficients_virginica[1]    1.1327    0.7304    0.0106   4766.2613   3395. ⋯
   coefficients_virginica[2]   -0.7003    0.6054    0.0092   4370.2821   3079. ⋯
   coefficients_virginica[3]    1.9840    0.8676    0.0130   4431.8384   2686. ⋯
   coefficients_virginica[4]    2.6489    0.8073    0.0128   3961.1005   2825. ⋯
                                                               3 columns omitted

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5% ⋯
                      Symbol   Float64   Float64   Float64   Float64   Float64 ⋯

        intercept_versicolor   -0.4281    0.2379    0.5917    0.9445    1.6546 ⋯
         intercept_virginica   -1.9785   -1.1372   -0.6545   -0.1947    0.6821 ⋯
  coefficients_versicolor[1]   -0.3293    0.5339    1.0152    1.4899    2.3830 ⋯
  coefficients_versicolor[2]   -2.5382   -1.7673   -1.4018   -1.0433   -0.3850 ⋯
  coefficients_versicolor[3]   -0.4601    0.5210    1.0424    1.5852    2.5953 ⋯
  coefficients_versicolor[4]   -1.1502   -0.1385    0.3565    0.8366    1.8186 ⋯
   coefficients_virginica[1]   -0.2805    0.6362    1.1222    1.6157    2.5715 ⋯
   coefficients_virginica[2]   -1.8708   -1.1108   -0.6966   -0.2914    0.4661 ⋯
   coefficients_virginica[3]    0.2758    1.4102    1.9992    2.5514    3.6480 ⋯
   coefficients_virginica[4]    1.0492    2.1095    2.6495    3.1902    4.2115 ⋯
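Because setosa is the baseline, each coefficient is a change in the log-odds of that class versus setosa per one standard deviation of the corresponding feature (in the training data), holding the other features fixed. As a rough illustration using numbers copied from the tables above, the posterior mean of coefficients_virginica[4] (PetalWidth) is about 2.65, and its 2.5% and 97.5% quantiles are about 1.05 and 4.21. Exponentiating turns these into odds ratios. The quantiles carry over exactly because exp is monotone, whereas exp of the posterior mean is only a plug-in value, not the posterior mean of the odds ratio.

```julia
# Values copied from the summary and quantile tables above (coefficients_virginica[4]).
β_mean = 2.6489
β_q025, β_q975 = 1.0492, 4.2115

# Odds ratio (virginica vs setosa) for a one-standard-deviation increase in PetalWidth.
or_plugin = exp(β_mean)                  # plug-in at the posterior mean, not E[exp(β)]
or_interval = exp.((β_q025, β_q975))     # 95% interval maps through exp exactly

println("plug-in odds ratio ≈ ", round(or_plugin; digits=1))
println("95% interval ≈ ", round.(or_interval; digits=1))
```

This gives an odds ratio of about 14, with a 95% interval from roughly 2.9 to 67, reflecting considerable posterior uncertainty about the size of the effect even though the interval excludes 1.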

The sample() call above assumes that your Julia session has at least as many threads available as there are chains (here, three). If it does not, the chains will not all run in parallel (with a single thread they run sequentially), and you may see a warning. For more information, see the Turing documentation on sampling multiple chains.
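A quick way to check this before sampling is to compare Threads.nthreads() with the number of chains. The snippet below only inspects the thread count and does not call Turing. Julia's thread count is fixed at startup, for example with `julia --threads 3` or the JULIA_NUM_THREADS environment variable.

```julia
nchains = 3  # number of chains requested in the sample() call above

nthreads_available = Threads.nthreads()
if nthreads_available >= nchains
    println("$nthreads_available threads available: all $nchains chains can run in parallel.")
elseif nthreads_available == 1
    println("Only 1 thread available: the chains will be sampled sequentially. ",
            "Restart Julia with `julia --threads $nchains` to sample them in parallel.")
else
    println("$nthreads_available threads available: at most $nthreads_available of the ",
            "$nchains chains run at once.")
end
```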

Since we ran multiple chains, we can do a quick visual spot check that all the chains converge to the same region of parameter space.

plot(chain)
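Beyond the visual check, a common numerical diagnostic is the potential scale reduction factor (R-hat), which compares between-chain and within-chain variance. Values close to 1 suggest that the chains agree. The sketch below implements the classic (non-split) Gelman-Rubin version in plain Julia on synthetic draws. It is an illustration only: the helper name rhat_classic is hypothetical, and MCMCChains typically reports a more robust rank-normalized split R-hat (an rhat column) in summaries like the one above, where it is hidden by the truncated output.

```julia
using Random, Statistics

# Hypothetical helper: classic Gelman-Rubin R-hat for a single parameter.
# `draws` is an n×m matrix: n draws (rows) from each of m chains (columns).
function rhat_classic(draws::AbstractMatrix{<:Real})
    n, m = size(draws)
    chain_means = vec(mean(draws; dims=1))
    W = mean(vec(var(draws; dims=1)))         # mean within-chain variance
    B = n * var(chain_means)                  # between-chain variance
    var_plus = (n - 1) / n * W + B / n        # pooled estimate of the posterior variance
    return sqrt(var_plus / W)
end

rng = Xoshiro(42)
# Synthetic draws, not output from the model above.
agreeing = randn(rng, 1_500, 3)                                   # three chains, same distribution
disagreeing = hcat(randn(rng, 1_500, 2), 3 .+ randn(rng, 1_500))  # third chain shifted by 3

println("R-hat (agreeing chains)    ≈ ", round(rhat_classic(agreeing); digits=3))
println("R-hat (disagreeing chains) ≈ ", round(rhat_classic(disagreeing); digits=3))
```

The agreeing chains give an R-hat very close to 1, while the shifted chain pushes it well above the commonly used thresholds of about 1.01 to 1.1.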