Multinomial Logistic Regression

Multinomial logistic regression is an extension of logistic regression. Logistic regression is used to model problems in which there are exactly two possible discrete outcomes. Multinomial logistic regression is used to model problems in which there are two or more possible discrete outcomes.
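To make the "extension" concrete, here is a small plain-Julia sketch. The helpers `softmax_vec` and `logistic` are our own illustrative definitions, not package functions. With two outcomes, and the first outcome used as a baseline whose score is fixed at 0, the softmax function used later in this tutorial reduces to the logistic function of binary logistic regression, because e^t / (1 + e^t) = 1 / (1 + e^(-t)).

```julia
# Hypothetical helper: softmax of a vector of scores. Subtracting the maximum
# first avoids overflow in exp and does not change the result.
softmax_vec(z::AbstractVector{<:Real}) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# Hypothetical helper: the logistic (sigmoid) function of binary logistic regression.
logistic(t::Real) = 1 / (1 + exp(-t))

# Two classes, baseline score fixed at 0: softmax gives the second class
# the same probability as logistic(t).
t = 0.7
p = softmax_vec([0.0, t])
println(p[2], " ≈ ", logistic(t))
```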

In our example, we’ll be using the iris dataset. The iris multiclass problem aims to predict the species of a flower given measurements (in centimetres) of sepal length and width and petal length and width. There are three possible species: Iris setosa, Iris versicolor, and Iris virginica.

To start, let’s import all the libraries we’ll need.

using Turing
using MCMCChains: MCMCChains, Chains
using StatsPlots
# Functionality for splitting and normalising the data.
using MLUtils: shuffleobs, splitobs, load_iris
using StatsBase: fit, ZScoreTransform, transform!
# We need a softmax function, which is provided by LogExpFunctions.
using LogExpFunctions: softmax
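# Functionality for constructing arrays with identical elements efficiently.
using FillArrays
# Functionality for working with scaled identity matrices (the `I` used in the model).
using LinearAlgebra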

# Set a seed for reproducibility.
using Random
Random.seed!(0);

Data Cleaning and Set Up

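Now we load our dataset. load_iris returns the measurements as a matrix X with one column per flower, together with a vector Y of species labels. Ten randomly chosen labels are shown below so you can get a feel for what kind of data we have.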

# Import the "iris" dataset.
X, Y = load_iris()
nobs = size(X, 2)

# Show 10 randomly chosen outcomes.
Y[rand(1:nobs, 10)]
10-element Vector{String}:
 "versicolor"
 "setosa"
 "virginica"
 "setosa"
 "versicolor"
 "setosa"
 "setosa"
 "virginica"
 "setosa"
 "setosa"

In this data set, the outcome Species is currently coded as a string. We convert it to a numerical value by using indices 1, 2, and 3 to indicate species setosa, versicolor, and virginica, respectively.

species = ["setosa", "versicolor", "virginica"]
Y = Vector{Int64}(indexin(Y, species))
Y[rand(1:nobs, 10)]
10-element Vector{Int64}:
 2
 3
 1
 2
 2
 3
 3
 1
 2
 2
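The conversion relies on indexin returning, for each label, its position in `species`, or `nothing` when the label is not listed. The sketch below, on a few hypothetical labels, shows that behaviour and adds an explicit check for unknown labels before the conversion to Vector{Int64}; the check is our addition, not part of the tutorial's code.

```julia
# Hypothetical labels standing in for the raw iris outcomes.
labels = ["setosa", "virginica", "versicolor", "setosa"]
species = ["setosa", "versicolor", "virginica"]

# indexin gives each label's position in `species`, or `nothing` if it is absent.
idx = indexin(labels, species)

# Fail early with a clear message on unknown labels; the conversion to
# Vector{Int64} below would otherwise throw a less readable conversion error.
any(isnothing, idx) && error("found a label that is not listed in `species`")
codes = Vector{Int64}(idx)
println(codes)  # [1, 3, 2, 1]

# Indexing `species` with the codes recovers the original labels.
decoded = species[codes]
```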

After that tidying, it’s time to split our dataset 50%/50% into training and test sets. We must also rescale each feature so that it is centred around zero, by subtracting the feature's mean and dividing by its standard deviation. Because each feature is a row of the feature matrix, these statistics are computed along dims=2; they are estimated on the training set only and then applied to both sets. This standardisation improves sampler efficiency by putting all features on comparable scales.

# Split our dataset 50%/50% into training/test sets.
(train_features, train_target), (test_features, test_target) = splitobs(shuffleobs((X, Y)); at=0.5)

# Standardise the features.
dt = fit(ZScoreTransform, train_features; dims=2)
transform!(dt, train_features)
transform!(dt, test_features)
4×75 view(::Matrix{Float64}, :, [18, 12, 118, 111, 91, 53, 69, 49, 48, 136  …  36, 140, 110, 146, 59, 17, 100, 142, 82, 107]) with eltype Float64:
 -0.960233  -1.35377   2.45043  0.876278  …  1.401      -0.435515   -1.22259
  1.02261    0.787344  1.72839  0.31682      0.0815576  -1.56528    -1.33002
 -1.35857   -1.24192   1.73263  0.799434     0.799434   -0.0171085   0.449487
 -1.23407   -1.37059   1.35966  1.08664      1.49618    -0.278486    0.677102
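For readers who want to see what the split and the standardisation do without MLUtils or StatsBase, here is a plain-Julia sketch using only the Random and Statistics standard libraries on synthetic data. The names `Xdemo`, `train_demo`, `test_demo` and so on are hypothetical. We expect the z-scoring to correspond to fit(ZScoreTransform, ...; dims=2) followed by transform!, but treat this as an illustration rather than a drop-in replacement.

```julia
using Random, Statistics

# Local RNG so the tutorial's global seed is left untouched.
rng = MersenneTwister(1)

# Hypothetical 4×150 feature matrix (one column per observation),
# a synthetic stand-in for the iris measurements.
Xdemo = randn(rng, 4, 150) .* [0.8, 0.4, 1.8, 0.8] .+ [5.8, 3.0, 3.8, 1.2]

# Shuffle the observation (column) indices, then split them 50%/50%.
perm = randperm(rng, size(Xdemo, 2))
ntrain = size(Xdemo, 2) ÷ 2
train_demo = Xdemo[:, perm[1:ntrain]]
test_demo = Xdemo[:, perm[(ntrain + 1):end]]

# Per-feature (per-row) mean and standard deviation from the training set only...
μ = mean(train_demo; dims=2)
s = std(train_demo; dims=2)

# ...applied to both sets, so no test-set information leaks into the scaling.
train_z = (train_demo .- μ) ./ s
test_z = (test_demo .- μ) ./ s
```

After this, each row of train_z has mean 0 and standard deviation 1, while test_z is only approximately standardised, which is expected.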

Model Declaration

Finally, we can define our model logistic_regression. It is a function that takes three arguments where

  • x is the matrix of features (independent variables), with one column per observation;
  • y is the vector of species indices we want to predict;
  • σ is the standard deviation we want to assume for our priors.

We select the setosa species as the baseline class (the choice does not matter). We then create an intercept and a vector of coefficients for each of the other two classes, relative to that baseline. More concretely, we create scalar intercepts intercept_versicolor and intercept_virginica and coefficient vectors coefficients_versicolor and coefficients_virginica, each with four coefficients, one per feature. This gives a total of 10 parameters to estimate. As the prior for each scalar parameter we assume a normal distribution with mean zero and standard deviation σ. We want to find the posterior distribution of these parameters so that we can predict the species for any given set of features.
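Why can one class be pinned as a baseline? The softmax is unchanged when the same constant is added to every score, so only differences between scores matter, and fixing setosa's score at 0 removes that redundancy. The sketch below computes the three class probabilities for one flower from hypothetical parameter values (not fitted ones), checks the shift invariance, and counts the parameters. `softmax_vec` is our own helper, standing in for LogExpFunctions' softmax.

```julia
# Hypothetical helper: max-shifted (numerically stable) softmax.
softmax_vec(z::AbstractVector{<:Real}) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# Hypothetical standardised features of a single flower.
x = [0.3, -1.2, 0.8, 0.5]

# Hypothetical parameter values (not taken from the fitted posterior).
intercept_versicolor, intercept_virginica = 0.5, -0.2
coefficients_versicolor = [0.1, -0.4, 0.9, 0.3]
coefficients_virginica = [-0.2, 0.1, 1.5, 1.1]

# Scores for (setosa, versicolor, virginica); setosa is fixed at 0.
η = [
    0.0,
    intercept_versicolor + coefficients_versicolor' * x,
    intercept_virginica + coefficients_virginica' * x,
]
p = softmax_vec(η)

# Shifting every score by the same constant leaves the probabilities unchanged,
# so pinning one class's score to 0 loses no flexibility.
p_shifted = softmax_vec(η .+ 3.0)

# 2 intercepts + 2 coefficient vectors of length 4 = 10 parameters.
nparams = 2 + 2 * length(x)
```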

# Bayesian multinomial logistic regression
@model function logistic_regression(x, y, σ)
    n = size(x, 2)
    length(y) == n ||
        throw(DimensionMismatch("number of observations in `x` and `y` is not equal"))

    # Priors of intercepts and coefficients.
    intercept_versicolor ~ Normal(0, σ)
    intercept_virginica ~ Normal(0, σ)
    coefficients_versicolor ~ MvNormal(Zeros(4), σ^2 * I)
    coefficients_virginica ~ MvNormal(Zeros(4), σ^2 * I)

    # Compute the likelihood of the observations.
    values_versicolor = intercept_versicolor .+ (coefficients_versicolor' * x)
    values_virginica = intercept_virginica .+ (coefficients_virginica' * x)
    for i in 1:n
        # the 0 corresponds to the base category `setosa`
        v = softmax([0, values_versicolor[i], values_virginica[i]])
        y[i] ~ Categorical(v)
    end
end;
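Each `y[i] ~ Categorical(v)` statement contributes log v[y[i]] to the model's log joint density. As a sketch of what that likelihood term amounts to, the hypothetical helper `multinomial_loglik` below computes the same sum by hand in plain Julia. Instead of taking the log of a softmax, it uses the algebraically equivalent score minus log-sum-exp, which is numerically safer. This is an illustration, not how Turing evaluates the model internally.

```julia
# Hypothetical helper: numerically stable log-sum-exp of a vector.
logsumexp_vec(z::AbstractVector{<:Real}) = (m = maximum(z); m + log(sum(exp.(z .- m))))

"""
    multinomial_loglik(x, y, intercepts, coefficients)

Hypothetical helper: log-likelihood of labels `y` (values 1, 2 or 3) given a
feature matrix `x` with one column per observation, using class 1 (setosa) as
the baseline. `intercepts` holds the versicolor and virginica intercepts and
`coefficients` is a 4×2 matrix whose columns are their coefficient vectors.
"""
function multinomial_loglik(x::AbstractMatrix, y::AbstractVector{<:Integer},
                            intercepts::AbstractVector, coefficients::AbstractMatrix)
    # Scores of the two non-baseline classes for every observation: a 2×n matrix.
    scores = intercepts .+ coefficients' * x
    ll = 0.0
    for i in eachindex(y)
        # Prepend the baseline score 0, mirroring softmax([0, ...]) in the model.
        η = [0.0; scores[:, i]]
        # log(softmax(η)[y[i]]), the Categorical log-probability of y[i].
        ll += η[y[i]] - logsumexp_vec(η)
    end
    return ll
end

# Usage on three hypothetical observations (4 features × 3 flowers).
xdemo = [0.1 -0.5 1.2; 0.3 0.0 -0.7; -1.0 0.4 0.2; 0.6 -0.2 0.9]
ydemo = [1, 3, 2]
ll_zero = multinomial_loglik(xdemo, ydemo, zeros(2), zeros(4, 2))
```

With all parameters at zero every class has probability 1/3, so ll_zero equals 3 log(1/3).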

Sampling

Now we can run our sampler. Here we’ll use NUTS to sample from our posterior.

setprogress!(false)
m = logistic_regression(train_features, train_target, 1)
chain = sample(m, NUTS(), MCMCThreads(), 1_500, 3)
Warning: Only a single thread available: MCMC chains are not sampled in parallel
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/oqm6Y/src/sample.jl:544
Info: Found initial step size
  ϵ = 0.8
Info: Found initial step size
  ϵ = 0.8
Info: Found initial step size
  ϵ = 0.8
Chains MCMC chain (1500×24×3 Array{Float64, 3}):

Iterations        = 751:1:2250
Number of chains  = 3
Samples per chain = 1500
Wall duration     = 14.46 seconds
Compute duration  = 9.84 seconds
parameters        = intercept_versicolor, intercept_virginica, coefficients_versicolor[1], coefficients_versicolor[2], coefficients_versicolor[3], coefficients_versicolor[4], coefficients_virginica[1], coefficients_virginica[2], coefficients_virginica[3], coefficients_virginica[4]
internals         = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, logprior, loglikelihood, logjoint

Use `describe(chains)` for summary statistics and quantiles.

Sampling With Multiple Threads

The sample() call above assumes that your Julia instance has at least as many threads available as the number of chains requested (three here). If it does not, the chains will run sequentially, and you will see a warning like the one printed above. For more information, see the Turing documentation on sampling multiple chains.
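A quick way to check this before sampling is to compare Threads.nthreads() with the number of chains. The thread count is fixed when Julia starts, for example with `julia --threads 3` or via the JULIA_NUM_THREADS environment variable. The variable name `n_chains` below is ours.

```julia
# Number of chains we intend to sample with MCMCThreads().
n_chains = 3

# Threads available to this Julia session (set at startup).
nthreads_available = Threads.nthreads()

if nthreads_available >= n_chains
    println("Enough threads: the ", n_chains, " chains can run in parallel.")
else
    println("Only ", nthreads_available, " thread(s): chains will run sequentially.")
end
```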

Since we ran multiple chains, we may as well do a spot check to make sure each chain converges around similar points.

plot(chain)

Looks good!
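plot(chain) gives a visual check. A common numerical counterpart is the potential scale reduction factor (R-hat), which compares the variance between chains with the variance within them. Values close to 1 suggest that the chains agree. The sketch below implements the classic Gelman-Rubin formula on synthetic draws; `classic_rhat` is a hypothetical helper. MCMCChains also reports an R-hat column in its summary statistics, but, as far as we know, it uses a more refined variant, so its numbers will not match this sketch exactly.

```julia
using Random, Statistics

"""
    classic_rhat(draws)

Hypothetical helper: classic Gelman-Rubin potential scale reduction factor
for a matrix of draws with one row per iteration and one column per chain.
"""
function classic_rhat(draws::AbstractMatrix{<:Real})
    n = size(draws, 1)                        # draws per chain
    chain_means = vec(mean(draws; dims=1))
    W = mean(vec(var(draws; dims=1)))         # mean within-chain variance
    B = n * var(chain_means)                  # between-chain variance
    var_hat = (n - 1) / n * W + B / n         # pooled variance estimate
    return sqrt(var_hat / W)
end

rng = MersenneTwister(42)
# Three chains exploring the same distribution...
agree = randn(rng, 1_000, 3)
# ...and three chains where one is stuck around a different value.
disagree = hcat(randn(rng, 1_000, 2), randn(rng, 1_000) .+ 5)

println(classic_rhat(agree))     # close to 1
println(classic_rhat(disagree))  # well above 1
```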

We can also use the corner function from StatsPlots to show the distributions of the various parameters of our multinomial logistic regression. corner(chain) would show the distributions of all parameters, but to avoid cluttering the plot we only show the four coefficients for versicolor.

corner(chain, MCMCChains.namesingroup(chain, :coefficients_versicolor))

Fortunately, the corner plot shows a unimodal distribution for each of these parameters, so summarising each parameter by the mean of its sampled values is a reasonable point estimate to use for predictions.

Making Predictions

How well does the model actually predict which of the three classes an iris flower belongs to? To find out, we need a prediction function that takes the test features and combines them with the posterior mean of each parameter, computed from the samples.

The prediction function below takes a matrix and a Chains object. It computes the mean of each sampled parameter and returns, for each observation, the index of the species with the highest probability. Note that we do not need to evaluate the softmax function: softmax preserves the ordering of its inputs, so the species with the largest linear predictor is also the species with the highest probability.

function prediction(x::AbstractMatrix{<:Real}, chain::MCMCChains.Chains)
    # Pull the means from each parameter's sampled values in the chain.
    intercept_versicolor = mean(chain, :intercept_versicolor)
    intercept_virginica = mean(chain, :intercept_virginica)
    coefficients_versicolor = [
        mean(chain, k) for k in MCMCChains.namesingroup(chain, :coefficients_versicolor)
    ]
    coefficients_virginica = [
        mean(chain, k) for k in MCMCChains.namesingroup(chain, :coefficients_virginica)
    ]

    # Compute the index of the species with the highest probability for each observation.
    values_versicolor = intercept_versicolor .+ (coefficients_versicolor' * x)
    values_virginica = intercept_virginica .+ (coefficients_virginica' * x)
    # The leading 0 is the score of the baseline class (setosa).
    species_indices = [
        argmax((0, v1, v2)) for (v1, v2) in zip(values_versicolor, values_virginica)
    ]

    return species_indices
end;
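A quick check of the claim that softmax can be skipped: exp is strictly increasing and every score is divided by the same positive normalising sum, so the largest score always maps to the largest probability. The sketch below uses a hypothetical `softmax_vec` helper and random scores rather than fitted values, and compares argmax over the raw scores with argmax over the softmax probabilities.

```julia
using Random

# Hypothetical helper: max-shifted (numerically stable) softmax.
softmax_vec(z::AbstractVector{<:Real}) = (e = exp.(z .- maximum(z)); e ./ sum(e))

rng = MersenneTwister(7)
# Hypothetical scores for 1000 observations: baseline 0 plus two random scores.
scores = [[0.0, 3randn(rng), 3randn(rng)] for _ in 1:1_000]

# Predicted class from raw scores vs. from softmax probabilities.
from_scores = argmax.(scores)
from_probs = [argmax(softmax_vec(s)) for s in scores]

println(from_scores == from_probs)  # true
```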

Let’s see how we did! We run the test matrix through the prediction function, and compute the accuracy for our prediction.

# Make the predictions.
predictions = prediction(test_features, chain)

# Calculate accuracy for our test set.
mean(predictions .== test_target)
0.96

Perhaps more important is to see the accuracy per class.

for s in 1:3
    rows = test_target .== s
    println("Number of `", species[s], "`: ", count(rows))
    println(
        "Fraction of `",
        species[s],
        "` predicted correctly: ",
        mean(predictions[rows] .== test_target[rows])
    )
end
Number of `setosa`: 26
Fraction of `setosa` predicted correctly: 1.0
Number of `versicolor`: 21
Fraction of `versicolor` predicted correctly: 1.0
Number of `virginica`: 28
Fraction of `virginica` predicted correctly: 0.8928571428571429
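Per-class accuracies can also be read off a confusion matrix. Its off-diagonal entries additionally show which class a misclassified flower was assigned to. The sketch below uses a hypothetical `confusion_matrix` helper and made-up labels. With the tutorial's variables, one should be able to call `confusion_matrix(test_target, predictions, 3)`.

```julia
"""
    confusion_matrix(truth, predicted, k)

Hypothetical helper: k×k matrix whose entry (i, j) counts observations of
true class i that were predicted as class j.
"""
function confusion_matrix(truth::AbstractVector{<:Integer},
                          predicted::AbstractVector{<:Integer}, k::Integer)
    C = zeros(Int, k, k)
    for (t, p) in zip(truth, predicted)
        C[t, p] += 1
    end
    return C
end

# Made-up labels: one flower of class 3 is misclassified as class 2.
truth = [1, 1, 2, 2, 3, 3, 3]
predicted = [1, 1, 2, 2, 3, 2, 3]
C = confusion_matrix(truth, predicted, 3)

# Overall accuracy: correct predictions (the diagonal) over all observations.
overall = sum(C[i, i] for i in 1:3) / sum(C)
# Per-class accuracy: each diagonal entry over the size of that true class (row sum).
per_class = [C[i, i] / sum(C[i, :]) for i in 1:3]
```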