Hidden Markov Models

This tutorial illustrates training Bayesian hidden Markov models (HMMs) using Turing. The main goals are learning the transition matrix, emission parameter, and hidden states. For a more rigorous academic overview of hidden Markov models, see An Introduction to Hidden Markov Models and Bayesian Networks (Ghahramani, 2001).

In this tutorial, we assume there are \(k\) discrete hidden states; the observations are continuous and normally distributed, centered around the hidden states. This assumption reduces the number of parameters to be estimated in the emission matrix.

Let’s load the libraries we’ll need, and set a random seed for reproducibility.

# Load libraries.
using Turing, StatsPlots, Random, Bijectors

# Set a random seed
Random.seed!(12345678);

Simple State Detection

In this example, we'll use a simple synthetic dataset in which the states and emission parameters are straightforward.

# Define the observation sequence.
y = [fill(1.0, 6)..., fill(2.0, 6)..., fill(3.0, 7)...,
  fill(2.0, 4)..., fill(1.0, 7)...]
N = length(y);
K = 3;

# Plot the data we just made.
plot(y; xlim=(0, 30), ylim=(-1, 5), size=(500, 250), legend = false)
scatter!(y, color = :blue; xlim=(0, 30), ylim=(-1, 5), size=(500, 250), legend = false)

We can see that we have three states, one for each height of the plot (1, 2, 3). This height is also our emission parameter, so state one produces a value of one, state two produces a value of two, and so on.

Ultimately, we would like to understand three major parameters:

  1. The transition matrix. This is a matrix that assigns a probability of switching from one state to any other state, including the state that we are already in.
  2. The emission parameters, which describe the typical value emitted by each state. In the plot above, the emission parameter for state one is simply one.
  3. The state sequence is our understanding of what state we were actually in when we observed some data. This is very important in more sophisticated HMMs, where the emission value does not equal our state. (A short simulation sketch after this list shows how these three pieces fit together.)
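
To make the generative story concrete, here is a minimal, illustrative sketch of how a transition matrix, emission means, and a state sequence jointly produce data like y. The toy matrix T_toy, the means m_toy, and the helper simulate_hmm are our own names for this sketch, not part of the model below.

using Distributions, Random

# Toy transition matrix: rows are the current state, columns the next state.
T_toy = [0.9  0.05 0.05;
         0.05 0.9  0.05;
         0.05 0.05 0.9]

# Toy emission means, one per state.
m_toy = [1.0, 2.0, 3.0]

function simulate_hmm(T, m, N; σ=0.1)
    K = length(m)
    s = zeros(Int, N)
    y = zeros(N)
    s[1] = rand(Categorical(fill(1 / K, K)))      # uniform initial state
    y[1] = rand(Normal(m[s[1]], σ))
    for t in 2:N
        s[t] = rand(Categorical(T[s[t - 1], :]))  # next state drawn from row s[t-1]
        y[t] = rand(Normal(m[s[t]], σ))           # emission centered on that state's mean
    end
    return s, y
end

s_sim, y_sim = simulate_hmm(T_toy, m_toy, 30)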

With this in mind, let’s set up our model. We are going to use some of our knowledge as modelers to provide additional information about our system. This takes the form of the prior on our emission parameter.

\[ m_i \sim \mathrm{Normal}(i, 0.5) \quad \text{where} \quad i \in \{1, 2, 3\} \]

Simply put, this says that we expect each state to emit normally distributed values, with the mean of each state's emissions equal to that state's value. The standard deviation of 0.5 helps the model converge more quickly; consider instead a standard deviation of 1 or 2. In that case, the likelihood of observing a 2 when we are in state 1 is actually quite high, as it is within a standard deviation of the true emission value. Applying the prior that emissions are likely to be tightly centered around the mean prevents our model from being too confused about which state is generating our observations.

The priors on our transition matrix are noninformative, using T[i] ~ Dirichlet(ones(K)/K). Used this way, the Dirichlet prior assumes that the state is equally likely to switch to any other state. As we'll see, the data will pull the posterior over the transition matrix well away from this prior.
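
If you want to see what these priors imply before fitting anything, a couple of quick draws make them tangible. This is purely illustrative and not part of the model below.

using Distributions, Random

rand(Dirichlet(ones(3) / 3))          # one prior draw of a transition-matrix row
[rand(Normal(i, 0.5)) for i in 1:3]   # one prior draw of the emission means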

# Turing model definition.
@model function BayesHmm(y, K)
    # Get observation length.
    N = length(y)

    # State sequence.
    s = zeros(Int, N)

    # Emission parameters (one mean per state).
    m = Vector(undef, K)

    # Transition matrix, stored as a vector of rows.
    T = Vector{Vector}(undef, K)

    # Assign distributions to each element
    # of the transition matrix and the
    # emission matrix.
    for i in 1:K
        T[i] ~ Dirichlet(ones(K) / K)
        m[i] ~ Normal(i, 0.5)
    end

    # Observe each point of the input.
    s[1] ~ Categorical(K)
    y[1] ~ Normal(m[s[1]], 0.1)

    for i in 2:N
        s[i] ~ Categorical(vec(T[s[i - 1]]))
        y[i] ~ Normal(m[s[i]], 0.1)
    end
end;
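
As an optional sanity check (not part of the original workflow), you can draw from the priors alone using Turing's Prior sampler and confirm that m and T behave as described above:

# Sample 100 draws from the priors only; the likelihood is ignored.
prior_chain = sample(BayesHmm(y, 3), Prior(), 100)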

We will use a combination of two samplers (HMC and Particle Gibbs) by passing them to the Gibbs sampler. The Gibbs sampler allows for compositional inference, where we can utilize different samplers on different parameters. (For API details of these samplers, please see Turing.jl’s API documentation.)

In this case, we use HMC for m and T, the emission parameters and the transition matrix respectively. We use the Particle Gibbs sampler for s, the state sequence. You may wonder why it is that we are not assigning s to the HMC sampler, and why it is that we need compositional Gibbs sampling at all.

The parameter s is not a continuous variable. It is a vector of integers, and thus Hamiltonian methods like HMC and NUTS won’t work correctly. Gibbs allows us to apply the right tools to the best effect. If you are a particularly advanced user interested in higher performance, you may benefit from setting up your Gibbs sampler to use different automatic differentiation backends for each parameter space.
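
For instance, a hedged sketch of what that might look like, assuming a recent Turing version in which Hamiltonian samplers accept an adtype keyword and the AutoForwardDiff type is exported (check your version's documentation):

# Each sub-sampler can, in principle, carry its own AD configuration.
g_ad = Gibbs(
    (:m, :T) => HMC(0.01, 50; adtype=AutoForwardDiff()),  # small continuous block
    :s => PG(120),                                        # particle Gibbs needs no gradients
)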

Time to run our sampler.

g = Gibbs((:m, :T) => HMC(0.01, 50), :s => PG(120))
chn = sample(BayesHmm(y, 3), g, 1000)
Chains MCMC chain (1000×43×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 615.25 seconds
Compute duration  = 615.25 seconds
parameters        = T[1][1], T[1][2], T[1][3], m[1], T[2][1], T[2][2], T[2][3], m[2], T[3][1], T[3][2], T[3][3], m[3], s[1], s[2], s[3], s[4], s[5], s[6], s[7], s[8], s[9], s[10], s[11], s[12], s[13], s[14], s[15], s[16], s[17], s[18], s[19], s[20], s[21], s[22], s[23], s[24], s[25], s[26], s[27], s[28], s[29], s[30]
internals         = lp

Summary Statistics
  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat    ⋯
      Symbol   Float64   Float64   Float64     Float64    Float64   Float64    ⋯

     T[1][1]    0.7435    0.1300    0.0185     52.6831   151.8818    1.0564    ⋯
     T[1][2]    0.1030    0.0798    0.0077     75.9849    54.5749    0.9997    ⋯
     T[1][3]    0.1535    0.1117    0.0260     15.9607    55.0223    1.0833    ⋯
        m[1]    1.9969    0.0631    0.0012   3000.0000    15.7887    1.0788    ⋯
     T[2][1]    0.1182    0.0876    0.0105     54.4099    85.1522    1.0203    ⋯
     T[2][2]    0.8527    0.0965    0.0128     44.2596    94.9147    1.0318    ⋯
     T[2][3]    0.0291    0.0421    0.0060     20.6365    36.6645    0.9994    ⋯
        m[2]    1.0056    0.0366    0.0037     85.0673   245.6233    1.0265    ⋯
     T[3][1]    0.1457    0.1313    0.0155     54.3377   148.5093    1.0130    ⋯
     T[3][2]    0.0509    0.0775    0.0227      4.5655    21.4848    1.3183    ⋯
     T[3][3]    0.8034    0.1467    0.0170     67.4371    95.7819    1.0150    ⋯
        m[3]    2.9989    0.0607    0.0026    198.3079   334.7791    1.0081    ⋯
        s[1]    2.0000    0.0000       NaN         NaN        NaN       NaN    ⋯
        s[2]    2.0000    0.0000       NaN         NaN        NaN       NaN    ⋯
        s[3]    2.0000    0.0000       NaN         NaN        NaN       NaN    ⋯
        s[4]    2.0000    0.0000       NaN         NaN        NaN       NaN    ⋯
        s[5]    2.0000    0.0000       NaN         NaN        NaN       NaN    ⋯
      ⋮           ⋮         ⋮         ⋮          ⋮          ⋮          ⋮       ⋱
                                                    1 column and 25 rows omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

     T[1][1]    0.4497    0.6666    0.7601    0.8412    0.9470
     T[1][2]    0.0056    0.0440    0.0842    0.1408    0.2957
     T[1][3]    0.0077    0.0643    0.1342    0.2129    0.4323
        m[1]    1.8893    1.9737    1.9963    2.0184    2.1069
     T[2][1]    0.0087    0.0492    0.0961    0.1684    0.3341
     T[2][2]    0.6315    0.7964    0.8728    0.9279    0.9795
     T[2][3]    0.0001    0.0035    0.0139    0.0347    0.1523
        m[2]    0.9504    0.9867    1.0037    1.0226    1.0596
     T[3][1]    0.0076    0.0465    0.1057    0.2114    0.4789
     T[3][2]    0.0000    0.0016    0.0185    0.0610    0.2929
     T[3][3]    0.4758    0.7260    0.8454    0.9163    0.9765
        m[3]    2.9305    2.9740    2.9992    3.0253    3.0670
        s[1]    2.0000    2.0000    2.0000    2.0000    2.0000
        s[2]    2.0000    2.0000    2.0000    2.0000    2.0000
        s[3]    2.0000    2.0000    2.0000    2.0000    2.0000
        s[4]    2.0000    2.0000    2.0000    2.0000    2.0000
        s[5]    2.0000    2.0000    2.0000    2.0000    2.0000
      ⋮           ⋮         ⋮         ⋮         ⋮         ⋮
                                                 25 rows omitted

Let’s see how well our chain performed. Ordinarily, using display(chn) would be a good first step, but we have generated a lot of parameters here (s[1], s[2], m[1], and so on). It’s a bit easier to show how our model performed graphically.

The code below generates an animation showing the graph of the data above, and the data our model generates in each sample.

# Extract our m and s parameters from the chain.
m_set = MCMCChains.group(chn, :m).value
s_set = MCMCChains.group(chn, :s).value

# Iterate through the MCMC samples.
Ns = 1:length(chn)

# Make an animation.
animation = @gif for i in Ns
    m = m_set[i, :]
    s = Int.(s_set[i, :])
    emissions = m[s]

    p = plot(
        y;
        color=:red,
        size=(500, 250),
        xlabel="Time",
        ylabel="State",
        legend=:topright,
        label="True data",
        xlim=(0, 30),
        ylim=(-1, 5),
    )
    plot!(emissions; color=:blue, label="Sample $i")
end every 3
[ Info: Saved animation to /tmp/jl_efYV4TSBCP.gif

Looks like our model did a pretty good job, but we should also check to make sure our chain converges. A quick check is to examine whether the diagonal (representing the probability of remaining in the current state) of the transition matrix appears to be stationary. The code below extracts the diagonal and shows a traceplot of each persistence probability.

# Index the chain with the persistence probabilities.
subchain = chn[["T[1][1]", "T[2][2]", "T[3][3]"]]

plot(subchain; seriestype=:traceplot, title="Persistence Probability", legend=false)

A cursory examination of the traceplot above indicates that all three chains converged to something resembling stationarity. We can use the diagnostic functions provided by MCMCChains to perform some more formal tests, like the Heidelberger and Welch diagnostic:

heideldiag(MCMCChains.group(chn, :T))[1]
Heidelberger and Welch diagnostic - Chain 1
  parameters     burnin   stationarity    pvalue      mean   halfwidth     tes ⋯
      Symbol      Int64           Bool   Float64   Float64     Float64     Boo ⋯

     T[1][1]   100.0000         1.0000    0.2667    0.7583      0.0223   1.000 ⋯
     T[1][2]     0.0000         1.0000    0.7908    0.1030      0.0152   0.000 ⋯
     T[1][3]   100.0000         1.0000    0.1364    0.1398      0.0245   0.000 ⋯
     T[2][1]     0.0000         1.0000    0.6433    0.1182      0.0204   0.000 ⋯
     T[2][2]     0.0000         1.0000    0.5017    0.8527      0.0251   1.000 ⋯
     T[2][3]     0.0000         1.0000    0.7079    0.0291      0.0122   0.000 ⋯
     T[3][1]     0.0000         1.0000    0.6959    0.1457      0.0309   0.000 ⋯
     T[3][2]   400.0000         1.0000    0.2112    0.0354      0.0248   0.000 ⋯
     T[3][3]     0.0000         1.0000    0.2364    0.8034      0.0333   1.000 ⋯
                                                                1 column omitted

The p-values on the test suggest that we cannot reject the hypothesis that the observed sequence comes from a stationary distribution, so we can be reasonably confident that our transition matrix has converged to something reasonable.
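
If you want a single point estimate to look at, one way to assemble the posterior-mean transition matrix is the sketch below, which assumes the grouped parameters come back in the row-wise order shown in the summary above.

# Posterior means of T[1][1], T[1][2], ..., T[3][3], reshaped into a 3×3 matrix
# whose rows sum to (approximately) one.
T_mean = permutedims(reshape(mean(MCMCChains.group(chn, :T)).nt.mean, 3, 3))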

Efficient Inference With The Forward Algorithm

While the above method works well for the simple example in this tutorial, some users may desire a more efficient method, especially when their model is more complicated. One simple way to improve inference is to marginalize out the hidden states of the model with an appropriate algorithm, calculating only the posterior over the continuous random variables. Not only does this allow more efficient inference via Rao-Blackwellization, but now we can sample our model with NUTS() alone, which is usually a much more performant MCMC kernel.
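
Before reaching for a library, it can help to see what this marginalization computes. Below is a deliberately small, unoptimized sketch of the forward algorithm in log space; the function name and keyword defaults are ours and purely illustrative.

using Distributions, LogExpFunctions

# Computes log p(y | T, m) with the discrete state sequence summed out.
function forward_logpdf(T, m, y; σ=0.1, init=fill(1 / length(m), length(m)))
    K = length(m)
    # logα[k] holds log p(y[1:t], s[t] = k) at the current time step t.
    logα = [log(init[k]) + logpdf(Normal(m[k], σ), y[1]) for k in 1:K]
    for t in 2:length(y)
        logα = [
            logsumexp(logα .+ log.(T[:, k])) + logpdf(Normal(m[k], σ), y[t])
            for k in 1:K
        ]
    end
    return logsumexp(logα)  # log p(y | T, m)
end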

Thankfully, HiddenMarkovModels.jl provides an extremely efficient implementation of many algorithms related to hidden Markov models. This allows us to rewrite our model as:

using HiddenMarkovModels
using FillArrays
using LinearAlgebra
using LogExpFunctions


@model function BayesHmm2(y, K)
    m ~ Bijectors.ordered(MvNormal([1.0, 2.0, 3.0], 0.5I))
    T ~ filldist(Dirichlet(fill(1/K, K)), K)

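    # Note: filldist stacks the K Dirichlet draws as the columns of T, while HMM
    # expects a transition matrix whose rows sum to one, hence copy(T').
    # softmax(ones(K)) is simply a uniform initial state distribution.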
    hmm = HMM(softmax(ones(K)), copy(T'), [Normal(m[i], 0.1) for i in 1:K])
    Turing.@addlogprob! logdensityof(hmm, y)
end

chn2 = sample(BayesHmm2(y, 3), NUTS(), 1000)
Info: Found initial step size
  ϵ = 0.0125
Chains MCMC chain (1000×24×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 7.4 seconds
Compute duration  = 7.4 seconds
parameters        = m[1], m[2], m[3], T[1, 1], T[2, 1], T[3, 1], T[1, 2], T[2, 2], T[3, 2], T[1, 3], T[2, 3], T[3, 3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat    ⋯
      Symbol   Float64   Float64   Float64     Float64    Float64   Float64    ⋯

        m[1]    0.9987    0.0275    0.0010    683.8634   656.1835    1.0034    ⋯
        m[2]    2.0002    0.0331    0.0011    915.3686   719.6985    1.0002    ⋯
        m[3]    2.9992    0.0387    0.0012    953.8421   838.3376    0.9998    ⋯
     T[1, 1]    0.8674    0.0906    0.0029    924.7353   504.3042    1.0006    ⋯
     T[2, 1]    0.1041    0.0802    0.0025    977.7582   741.3550    1.0009    ⋯
     T[3, 1]    0.0285    0.0491    0.0017    889.3872   531.6543    0.9996    ⋯
     T[1, 2]    0.1208    0.0917    0.0024   1168.5084   659.1198    1.0003    ⋯
     T[2, 2]    0.7575    0.1236    0.0030   1446.9106   788.3338    0.9997    ⋯
     T[3, 2]    0.1217    0.0950    0.0026    988.9190   452.7226    0.9993    ⋯
     T[1, 3]    0.0430    0.0723    0.0024    415.9807   427.5867    0.9995    ⋯
     T[2, 3]    0.1737    0.1371    0.0040    879.9310   327.0175    0.9995    ⋯
     T[3, 3]    0.7833    0.1494    0.0044   1073.8259   611.4824    0.9993    ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        m[1]    0.9450    0.9804    0.9996    1.0174    1.0537
        m[2]    1.9358    1.9785    1.9996    2.0222    2.0669
        m[3]    2.9235    2.9726    2.9987    3.0263    3.0744
     T[1, 1]    0.6537    0.8139    0.8869    0.9393    0.9846
     T[2, 1]    0.0074    0.0405    0.0870    0.1470    0.2990
     T[3, 1]    0.0000    0.0009    0.0082    0.0338    0.1683
     T[1, 2]    0.0087    0.0516    0.1015    0.1646    0.3432
     T[2, 2]    0.4765    0.6792    0.7723    0.8485    0.9499
     T[3, 2]    0.0064    0.0489    0.0995    0.1667    0.3574
     T[1, 3]    0.0000    0.0011    0.0110    0.0532    0.2862
     T[2, 3]    0.0081    0.0701    0.1429    0.2471    0.5190
     T[3, 3]    0.4281    0.6936    0.8161    0.9009    0.9808

We can compare the chains of these two models, confirming the posterior estimate is similar (modulo label switching concerns with the Gibbs model):

# Plot the chains for m from both models.
plot(chn["m[1]"], label = "m[1], Model 1, Gibbs", color = :lightblue)
plot!(chn2["m[1]"], label = "m[1], Model 2, NUTS", color = :blue)
plot!(chn["m[2]"], label = "m[2], Model 1, Gibbs", color = :pink)
plot!(chn2["m[2]"], label = "m[2], Model 2, NUTS", color = :red)
plot!(chn["m[3]"], label = "m[3], Model 1, Gibbs", color = :yellow)
plot!(chn2["m[3]"], label = "m[3], Model 2, NUTS", color = :orange)

Recovering Marginalized Trajectories

We can use the viterbi() algorithm, also from the HiddenMarkovModels package, to recover the most probable state sequence for each parameter draw in our posterior sample:

@model function BayesHmmRecover(y, K, IncludeGenerated = false)
    m ~ Bijectors.ordered(MvNormal([1.0, 2.0, 3.0], 0.5I))
    T ~ filldist(Dirichlet(fill(1/K, K)), K)

    hmm = HMM(softmax(ones(K)), copy(T'), [Normal(m[i], 0.1) for i in 1:K])
    Turing.@addlogprob! logdensityof(hmm, y)

    # Conditional generation of the hidden states.
    if IncludeGenerated
        seq, _ = viterbi(hmm, y)
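        # Map each Viterbi state index to its emission mean so the recovered
        # trajectory can be plotted on the same scale as y.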
        s := [m[s] for s in seq]
    end
end

chn_recover = sample(BayesHmmRecover(y, 3, true), NUTS(), 1000)
Info: Found initial step size
  ϵ = 0.00885009765625
Chains MCMC chain (1000×54×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 7.2 seconds
Compute duration  = 7.2 seconds
parameters        = m[1], m[2], m[3], T[1, 1], T[2, 1], T[3, 1], T[1, 2], T[2, 2], T[3, 2], T[1, 3], T[2, 3], T[3, 3], s[1], s[2], s[3], s[4], s[5], s[6], s[7], s[8], s[9], s[10], s[11], s[12], s[13], s[14], s[15], s[16], s[17], s[18], s[19], s[20], s[21], s[22], s[23], s[24], s[25], s[26], s[27], s[28], s[29], s[30]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat    ⋯
      Symbol   Float64   Float64   Float64     Float64    Float64   Float64    ⋯

        m[1]    1.0003    0.0287    0.0013    481.9269   246.8267    1.0005    ⋯
        m[2]    1.9992    0.0315    0.0012    687.9035   787.2228    0.9994    ⋯
        m[3]    3.0015    0.0371    0.0015    612.8974   726.1939    1.0004    ⋯
     T[1, 1]    0.8715    0.0847    0.0028    869.3594   652.3149    1.0029    ⋯
     T[2, 1]    0.1019    0.0771    0.0027    720.0682   792.7068    1.0027    ⋯
     T[3, 1]    0.0267    0.0457    0.0017    485.3129   257.4259    1.0003    ⋯
     T[1, 2]    0.1247    0.0965    0.0025   1242.5887   423.8811    0.9994    ⋯
     T[2, 2]    0.7558    0.1250    0.0034   1404.8057   625.3366    0.9992    ⋯
     T[3, 2]    0.1195    0.0927    0.0027   1129.8165   743.7114    1.0000    ⋯
     T[1, 3]    0.0442    0.0706    0.0032     58.2980    34.8222    1.0005    ⋯
     T[2, 3]    0.1694    0.1260    0.0038    964.4918   496.6443    1.0005    ⋯
     T[3, 3]    0.7864    0.1394    0.0045    914.0145   553.0357    1.0003    ⋯
        s[1]    1.0003    0.0287    0.0013    481.9269   246.8267    1.0005    ⋯
        s[2]    1.0003    0.0287    0.0013    481.9269   246.8267    1.0005    ⋯
        s[3]    1.0003    0.0287    0.0013    481.9269   246.8267    1.0005    ⋯
        s[4]    1.0003    0.0287    0.0013    481.9269   246.8267    1.0005    ⋯
        s[5]    1.0003    0.0287    0.0013    481.9269   246.8267    1.0005    ⋯
      ⋮           ⋮         ⋮         ⋮          ⋮          ⋮          ⋮       ⋱
                                                    1 column and 25 rows omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        m[1]    0.9412    0.9801    0.9995    1.0188    1.0538
        m[2]    1.9373    1.9779    1.9981    2.0203    2.0580
        m[3]    2.9332    2.9758    3.0008    3.0260    3.0760
     T[1, 1]    0.6690    0.8240    0.8839    0.9398    0.9830
     T[2, 1]    0.0080    0.0426    0.0848    0.1403    0.2921
     T[3, 1]    0.0000    0.0010    0.0077    0.0295    0.1525
     T[1, 2]    0.0095    0.0491    0.1028    0.1810    0.3631
     T[2, 2]    0.4889    0.6730    0.7721    0.8547    0.9514
     T[3, 2]    0.0088    0.0486    0.0965    0.1693    0.3548
     T[1, 3]    0.0000    0.0009    0.0128    0.0582    0.2459
     T[2, 3]    0.0093    0.0717    0.1386    0.2435    0.4717
     T[3, 3]    0.4582    0.7024    0.8147    0.8948    0.9780
        s[1]    0.9412    0.9801    0.9995    1.0188    1.0538
        s[2]    0.9412    0.9801    0.9995    1.0188    1.0538
        s[3]    0.9412    0.9801    0.9995    1.0188    1.0538
        s[4]    0.9412    0.9801    0.9995    1.0188    1.0538
        s[5]    0.9412    0.9801    0.9995    1.0188    1.0538
      ⋮           ⋮         ⋮         ⋮         ⋮         ⋮
                                                 25 rows omitted

Plotting the estimated states, we can see that the results align well with our expectations:

p = plot(xlim=(0, 30), ylim=(-1, 5), size=(500, 250))
for i in 1:100
    ind = rand(DiscreteUniform(1, 1000))
    plot!(MCMCChains.group(chn_recover, :s).value[ind,:], color = :grey, opacity = 0.1, legend = false)
end
scatter!(y, color = :blue)

p