using Turing
using AbstractGPs
using FillArrays
using LaTeXStrings
using Plots
using RDatasets
using ReverseDiff
using StatsBase
using Mooncake
using LinearAlgebra
using Random
Random.seed!(1789);Gaussian Process Latent Variable Models
In a previous tutorial, we have discussed latent variable models, in particular probabilistic principal component analysis (pPCA). Here, we show how we can extend the mapping provided by pPCA to non-linear mappings between input and output. For more details about the Gaussian Process Latent Variable Model (GPLVM), we refer the reader to the original publication and a further extension.
In short, the GPVLM is a dimensionality reduction technique that allows us to embed a high-dimensional dataset in a lower-dimensional embedding. Importantly, it provides the advantage that the linear mappings from the embedded space can be non-linearised through the use of Gaussian Processes.
Setup
Let’s start by loading some dependencies.
We demonstrate the GPLVM with a very small dataset: Fisher’s Iris data set. This is mostly for reasons of run time, so the tutorial can be run quickly. As you will see, one of the major drawbacks of using GPs is their speed, although this is an active area of research. We will briefly touch on some ways to speed things up at the end of this tutorial. We transform the original data with non-linear operations in order to demonstrate the power of GPs to work on non-linear relationships, while keeping the problem reasonably small.
data = dataset("datasets", "iris")
species = data[!, "Species"]
index = shuffle(1:150)
# we extract the four measured quantities,
# so the dimension of the data is only d=4 for this toy example
dat = Matrix(data[index, 1:4])
labels = data[index, "Species"]
# non-linearize data to demonstrate ability of GPs to deal with non-linearity
dat[:, 1] = 0.5 * dat[:, 1] .^ 2 + 0.1 * dat[:, 1] .^ 3
dat[:, 2] = dat[:, 2] .^ 3 + 0.2 * dat[:, 2] .^ 4
dat[:, 3] = 0.1 * exp.(dat[:, 3]) - 0.2 * dat[:, 3] .^ 2
dat[:, 4] = 0.5 * log.(dat[:, 4]) .^ 2 + 0.01 * dat[:, 3] .^ 5
# normalise data
dt = fit(ZScoreTransform, dat; dims=1);
StatsBase.transform!(dt, dat);We will start out by demonstrating the basic similarity between pPCA (see the tutorial on this topic) and the GPLVM model. Indeed, pPCA is basically equivalent to running the GPLVM model with an automatic relevance determination (ARD) linear kernel.
First, we re-introduce the pPCA model (see the tutorial on pPCA for details)
@model function pPCA(x)
# Dimensionality of the problem.
N, D = size(x)
# latent variable z
z ~ filldist(Normal(), D, N)
# weights/loadings W
w ~ filldist(Normal(), D, D)
mu = (w * z)'
for d in 1:D
x[:, d] ~ MvNormal(mu[:, d], I)
end
return nothing
end;We define two different kernels, a simple linear kernel with an Automatic Relevance Determination transform and a squared exponential kernel.
linear_kernel(α) = LinearKernel() ∘ ARDTransform(α)
sekernel(α, σ) = σ * SqExponentialKernel() ∘ ARDTransform(α);And here is the GPLVM model. We create separate models for the two types of kernel.
@model function GPLVM_linear(Y, K)
# Dimensionality of the problem.
N, D = size(Y)
# K is the dimension of the latent space
@assert K <= D
noise = 1e-3
# Priors
α ~ MvLogNormal(MvNormal(Zeros(K), I))
Z ~ filldist(Normal(), K, N)
mu ~ filldist(Normal(), N)
gp = GP(linear_kernel(α))
gpz = gp(ColVecs(Z), noise)
Y ~ filldist(MvNormal(mu, cov(gpz)), D)
return nothing
end;
@model function GPLVM(Y, K)
# Dimensionality of the problem.
N, D = size(Y)
# K is the dimension of the latent space
@assert K <= D
noise = 1e-3
# Priors
α ~ MvLogNormal(MvNormal(Zeros(K), I))
σ ~ LogNormal(0.0, 1.0)
Z ~ filldist(Normal(), K, N)
mu ~ filldist(Normal(), N)
gp = GP(sekernel(α, σ))
gpz = gp(ColVecs(Z), noise)
Y ~ filldist(MvNormal(mu, cov(gpz)), D)
return nothing
end;# Standard GPs don't scale very well in n, so we use a small subsample for
# the purpose of this tutorial.
n_data = 40
# number of features to use from dataset
n_features = 4
# latent dimension for GP case
ndim = 4;ppca = pPCA(dat[1:n_data, 1:n_features])
chain_ppca = sample(ppca, NUTS(; adtype=AutoMooncake()), 1000);Sampling 0%| | ETA: N/A ┌ Info: Found initial step size └ ϵ = 0.4 Sampling 1%|▎ | ETA: 2:33:39 Sampling 1%|▍ | ETA: 1:21:37 Sampling 2%|▋ | ETA: 0:52:57 Sampling 2%|▉ | ETA: 0:40:25 Sampling 3%|█▏ | ETA: 0:31:46 Sampling 3%|█▎ | ETA: 0:26:42 Sampling 4%|█▌ | ETA: 0:22:33 Sampling 4%|█▋ | ETA: 0:19:49 Sampling 5%|█▉ | ETA: 0:17:24 Sampling 5%|██▏ | ETA: 0:15:42 Sampling 6%|██▍ | ETA: 0:14:07 Sampling 6%|██▌ | ETA: 0:12:57 Sampling 7%|██▊ | ETA: 0:11:50 Sampling 7%|███ | ETA: 0:10:59 Sampling 8%|███▏ | ETA: 0:10:09 Sampling 8%|███▍ | ETA: 0:09:31 Sampling 9%|███▋ | ETA: 0:08:52 Sampling 9%|███▊ | ETA: 0:08:22 Sampling 10%|████ | ETA: 0:07:51 Sampling 10%|████▎ | ETA: 0:07:27 Sampling 11%|████▍ | ETA: 0:07:02 Sampling 11%|████▋ | ETA: 0:06:42 Sampling 12%|████▉ | ETA: 0:06:21 Sampling 12%|█████ | ETA: 0:06:04 Sampling 13%|█████▎ | ETA: 0:05:47 Sampling 13%|█████▌ | ETA: 0:05:33 Sampling 14%|█████▋ | ETA: 0:05:18 Sampling 14%|█████▉ | ETA: 0:05:05 Sampling 15%|██████▏ | ETA: 0:04:52 Sampling 15%|██████▎ | ETA: 0:04:42 Sampling 16%|██████▌ | ETA: 0:04:30 Sampling 16%|██████▊ | ETA: 0:04:21 Sampling 17%|███████ | ETA: 0:04:11 Sampling 17%|███████▏ | ETA: 0:04:03 Sampling 18%|███████▍ | ETA: 0:03:54 Sampling 18%|███████▌ | ETA: 0:03:47 Sampling 19%|███████▊ | ETA: 0:03:39 Sampling 19%|████████ | ETA: 0:03:32 Sampling 20%|████████▎ | ETA: 0:03:25 Sampling 20%|████████▍ | ETA: 0:03:19 Sampling 21%|████████▋ | ETA: 0:03:13 Sampling 21%|████████▉ | ETA: 0:03:07 Sampling 22%|█████████ | ETA: 0:03:01 Sampling 22%|█████████▎ | ETA: 0:02:57 Sampling 23%|█████████▌ | ETA: 0:02:51 Sampling 23%|█████████▋ | ETA: 0:02:47 Sampling 24%|█████████▉ | ETA: 0:02:42 Sampling 24%|██████████▏ | ETA: 0:02:38 Sampling 25%|██████████▎ | ETA: 0:02:33 Sampling 25%|██████████▌ | ETA: 0:02:29 Sampling 26%|██████████▊ | ETA: 0:02:25 Sampling 26%|██████████▉ | ETA: 0:02:22 Sampling 27%|███████████▏ | ETA: 0:02:18 Sampling 27%|███████████▍ | ETA: 0:02:15 Sampling 28%|███████████▋ | ETA: 0:02:11 Sampling 28%|███████████▊ | ETA: 0:02:08 Sampling 29%|████████████ | ETA: 0:02:05 Sampling 29%|████████████▏ | ETA: 0:02:02 Sampling 30%|████████████▍ | ETA: 0:01:59 Sampling 30%|████████████▋ | ETA: 0:01:56 Sampling 31%|████████████▉ | ETA: 0:01:53 Sampling 31%|█████████████ | ETA: 0:01:51 Sampling 32%|█████████████▎ | ETA: 0:01:48 Sampling 32%|█████████████▌ | ETA: 0:01:46 Sampling 33%|█████████████▋ | ETA: 0:01:43 Sampling 33%|█████████████▉ | ETA: 0:01:41 Sampling 34%|██████████████▏ | ETA: 0:01:39 Sampling 34%|██████████████▎ | ETA: 0:01:37 Sampling 35%|██████████████▌ | ETA: 0:01:35 Sampling 35%|██████████████▊ | ETA: 0:01:33 Sampling 36%|██████████████▉ | ETA: 0:01:31 Sampling 36%|███████████████▏ | ETA: 0:01:29 Sampling 37%|███████████████▍ | ETA: 0:01:27 Sampling 37%|███████████████▌ | ETA: 0:01:25 Sampling 38%|███████████████▊ | ETA: 0:01:23 Sampling 38%|████████████████ | ETA: 0:01:22 Sampling 39%|████████████████▏ | ETA: 0:01:20 Sampling 39%|████████████████▍ | ETA: 0:01:18 Sampling 40%|████████████████▋ | ETA: 0:01:17 Sampling 40%|████████████████▊ | ETA: 0:01:15 Sampling 41%|█████████████████ | ETA: 0:01:13 Sampling 41%|█████████████████▎ | ETA: 0:01:12 Sampling 42%|█████████████████▌ | ETA: 0:01:11 Sampling 42%|█████████████████▋ | ETA: 0:01:09 Sampling 43%|█████████████████▉ | ETA: 0:01:08 Sampling 43%|██████████████████ | ETA: 0:01:06 Sampling 44%|██████████████████▎ | ETA: 0:01:05 Sampling 44%|██████████████████▌ | ETA: 0:01:04 Sampling 45%|██████████████████▊ | ETA: 0:01:02 Sampling 45%|██████████████████▉ | ETA: 0:01:01 Sampling 46%|███████████████████▏ | ETA: 0:01:00 Sampling 46%|███████████████████▍ | ETA: 0:00:59 Sampling 47%|███████████████████▌ | ETA: 0:00:58 Sampling 47%|███████████████████▊ | ETA: 0:00:57 Sampling 48%|████████████████████ | ETA: 0:00:55 Sampling 48%|████████████████████▏ | ETA: 0:00:54 Sampling 49%|████████████████████▍ | ETA: 0:00:53 Sampling 49%|████████████████████▋ | ETA: 0:00:52 Sampling 50%|████████████████████▊ | ETA: 0:00:51 Sampling 50%|█████████████████████ | ETA: 0:00:50 Sampling 51%|█████████████████████▎ | ETA: 0:00:49 Sampling 51%|█████████████████████▍ | ETA: 0:00:48 Sampling 52%|█████████████████████▋ | ETA: 0:00:47 Sampling 52%|█████████████████████▉ | ETA: 0:00:46 Sampling 53%|██████████████████████▏ | ETA: 0:00:45 Sampling 53%|██████████████████████▎ | ETA: 0:00:45 Sampling 54%|██████████████████████▌ | ETA: 0:00:44 Sampling 54%|██████████████████████▋ | ETA: 0:00:43 Sampling 55%|██████████████████████▉ | ETA: 0:00:42 Sampling 55%|███████████████████████▏ | ETA: 0:00:41 Sampling 56%|███████████████████████▍ | ETA: 0:00:40 Sampling 56%|███████████████████████▌ | ETA: 0:00:39 Sampling 57%|███████████████████████▊ | ETA: 0:00:39 Sampling 57%|████████████████████████ | ETA: 0:00:38 Sampling 58%|████████████████████████▏ | ETA: 0:00:37 Sampling 58%|████████████████████████▍ | ETA: 0:00:36 Sampling 59%|████████████████████████▋ | ETA: 0:00:36 Sampling 59%|████████████████████████▊ | ETA: 0:00:35 Sampling 60%|█████████████████████████ | ETA: 0:00:34 Sampling 60%|█████████████████████████▎ | ETA: 0:00:34 Sampling 61%|█████████████████████████▍ | ETA: 0:00:33 Sampling 61%|█████████████████████████▋ | ETA: 0:00:32 Sampling 62%|█████████████████████████▉ | ETA: 0:00:31 Sampling 62%|██████████████████████████ | ETA: 0:00:31 Sampling 63%|██████████████████████████▎ | ETA: 0:00:30 Sampling 63%|██████████████████████████▌ | ETA: 0:00:30 Sampling 64%|██████████████████████████▋ | ETA: 0:00:29 Sampling 64%|██████████████████████████▉ | ETA: 0:00:28 Sampling 65%|███████████████████████████▏ | ETA: 0:00:28 Sampling 65%|███████████████████████████▎ | ETA: 0:00:27 Sampling 66%|███████████████████████████▌ | ETA: 0:00:26 Sampling 66%|███████████████████████████▊ | ETA: 0:00:26 Sampling 67%|████████████████████████████ | ETA: 0:00:25 Sampling 67%|████████████████████████████▏ | ETA: 0:00:25 Sampling 68%|████████████████████████████▍ | ETA: 0:00:24 Sampling 68%|████████████████████████████▌ | ETA: 0:00:24 Sampling 69%|████████████████████████████▊ | ETA: 0:00:23 Sampling 69%|█████████████████████████████ | ETA: 0:00:23 Sampling 70%|█████████████████████████████▎ | ETA: 0:00:22 Sampling 70%|█████████████████████████████▍ | ETA: 0:00:22 Sampling 71%|█████████████████████████████▋ | ETA: 0:00:21 Sampling 71%|█████████████████████████████▉ | ETA: 0:00:21 Sampling 72%|██████████████████████████████ | ETA: 0:00:20 Sampling 72%|██████████████████████████████▎ | ETA: 0:00:20 Sampling 73%|██████████████████████████████▌ | ETA: 0:00:19 Sampling 73%|██████████████████████████████▋ | ETA: 0:00:19 Sampling 74%|██████████████████████████████▉ | ETA: 0:00:18 Sampling 74%|███████████████████████████████▏ | ETA: 0:00:18 Sampling 75%|███████████████████████████████▎ | ETA: 0:00:17 Sampling 75%|███████████████████████████████▌ | ETA: 0:00:17 Sampling 76%|███████████████████████████████▊ | ETA: 0:00:16 Sampling 76%|███████████████████████████████▉ | ETA: 0:00:16 Sampling 77%|████████████████████████████████▏ | ETA: 0:00:15 Sampling 77%|████████████████████████████████▍ | ETA: 0:00:15 Sampling 78%|████████████████████████████████▋ | ETA: 0:00:15 Sampling 78%|████████████████████████████████▊ | ETA: 0:00:14 Sampling 79%|█████████████████████████████████ | ETA: 0:00:14 Sampling 79%|█████████████████████████████████▏ | ETA: 0:00:13 Sampling 80%|█████████████████████████████████▍ | ETA: 0:00:13 Sampling 80%|█████████████████████████████████▋ | ETA: 0:00:13 Sampling 81%|█████████████████████████████████▉ | ETA: 0:00:12 Sampling 81%|██████████████████████████████████ | ETA: 0:00:12 Sampling 82%|██████████████████████████████████▎ | ETA: 0:00:11 Sampling 82%|██████████████████████████████████▌ | ETA: 0:00:11 Sampling 83%|██████████████████████████████████▋ | ETA: 0:00:11 Sampling 83%|██████████████████████████████████▉ | ETA: 0:00:10 Sampling 84%|███████████████████████████████████▏ | ETA: 0:00:10 Sampling 84%|███████████████████████████████████▎ | ETA: 0:00:10 Sampling 85%|███████████████████████████████████▌ | ETA: 0:00:09 Sampling 85%|███████████████████████████████████▊ | ETA: 0:00:09 Sampling 86%|███████████████████████████████████▉ | ETA: 0:00:09 Sampling 86%|████████████████████████████████████▏ | ETA: 0:00:08 Sampling 87%|████████████████████████████████████▍ | ETA: 0:00:08 Sampling 87%|████████████████████████████████████▌ | ETA: 0:00:08 Sampling 88%|████████████████████████████████████▊ | ETA: 0:00:07 Sampling 88%|█████████████████████████████████████ | ETA: 0:00:07 Sampling 89%|█████████████████████████████████████▏ | ETA: 0:00:07 Sampling 89%|█████████████████████████████████████▍ | ETA: 0:00:06 Sampling 90%|█████████████████████████████████████▋ | ETA: 0:00:06 Sampling 90%|█████████████████████████████████████▊ | ETA: 0:00:06 Sampling 91%|██████████████████████████████████████ | ETA: 0:00:05 Sampling 91%|██████████████████████████████████████▎ | ETA: 0:00:05 Sampling 92%|██████████████████████████████████████▌ | ETA: 0:00:05 Sampling 92%|██████████████████████████████████████▋ | ETA: 0:00:04 Sampling 93%|██████████████████████████████████████▉ | ETA: 0:00:04 Sampling 93%|███████████████████████████████████████ | ETA: 0:00:04 Sampling 94%|███████████████████████████████████████▎ | ETA: 0:00:03 Sampling 94%|███████████████████████████████████████▌ | ETA: 0:00:03 Sampling 95%|███████████████████████████████████████▊ | ETA: 0:00:03 Sampling 95%|███████████████████████████████████████▉ | ETA: 0:00:03 Sampling 96%|████████████████████████████████████████▏ | ETA: 0:00:02 Sampling 96%|████████████████████████████████████████▍ | ETA: 0:00:02 Sampling 97%|████████████████████████████████████████▌ | ETA: 0:00:02 Sampling 97%|████████████████████████████████████████▊ | ETA: 0:00:02 Sampling 98%|█████████████████████████████████████████ | ETA: 0:00:01 Sampling 98%|█████████████████████████████████████████▏| ETA: 0:00:01 Sampling 99%|█████████████████████████████████████████▍| ETA: 0:00:01 Sampling 99%|█████████████████████████████████████████▋| ETA: 0:00:01 Sampling 100%|█████████████████████████████████████████▊| ETA: 0:00:00 Sampling 100%|██████████████████████████████████████████| Time: 0:00:50 Sampling 100%|██████████████████████████████████████████| Time: 0:00:52 ┌ Warning: There were 1 divergent transitions. Consider reparameterising your model or using a smaller step size. For adaptive samplers such as NUTS and HMCDA, consider increasing `target_accept`. └ @ Turing.Inference ~/.julia/packages/Turing/4hMHm/src/mcmc/hmc.jl:483
# we extract the posterior mean estimates of the parameters from the chain
z_mean = mean(chain_ppca[@varname(z)])
scatter(z_mean[1, :], z_mean[2, :]; group=labels[1:n_data], xlabel=L"z_1", ylabel=L"z_2")We can see that the pPCA fails to distinguish the groups. This is due to the non-linearities that we introduced, as without them the groups can be clearly distinguished using pPCA (see the pPCA tutorial).
Let’s try the same with our linear kernel GPLVM model.
gplvm_linear = GPLVM_linear(dat[1:n_data, 1:n_features], ndim)
chain_linear = sample(gplvm_linear, NUTS(; adtype=AutoMooncake()), 500);Sampling 0%| | ETA: N/A ┌ Info: Found initial step size └ ϵ = 0.025 Sampling 1%|▎ | ETA: 2:26:58 Sampling 1%|▌ | ETA: 1:13:14 Sampling 2%|▋ | ETA: 0:48:38 Sampling 2%|▉ | ETA: 0:39:11 Sampling 3%|█▏ | ETA: 0:31:37 Sampling 3%|█▎ | ETA: 0:26:52 Sampling 4%|█▌ | ETA: 0:23:40 Sampling 4%|█▋ | ETA: 0:21:31 Sampling 5%|█▉ | ETA: 0:19:32 Sampling 5%|██▏ | ETA: 0:17:55 Sampling 6%|██▍ | ETA: 0:16:47 Sampling 6%|██▌ | ETA: 0:15:52 Sampling 7%|██▊ | ETA: 0:14:58 Sampling 7%|███ | ETA: 0:14:05 Sampling 8%|███▎ | ETA: 0:13:24 Sampling 8%|███▍ | ETA: 0:13:02 Sampling 9%|███▋ | ETA: 0:12:31 Sampling 9%|███▊ | ETA: 0:11:52 Sampling 10%|████ | ETA: 0:11:21 Sampling 10%|████▎ | ETA: 0:10:56 Sampling 11%|████▍ | ETA: 0:10:30 Sampling 11%|████▋ | ETA: 0:10:05 Sampling 12%|████▉ | ETA: 0:09:43 Sampling 12%|█████ | ETA: 0:09:27 Sampling 13%|█████▎ | ETA: 0:09:05 Sampling 13%|█████▌ | ETA: 0:08:50 Sampling 14%|█████▊ | ETA: 0:08:42 Sampling 14%|█████▉ | ETA: 0:08:31 Sampling 15%|██████▏ | ETA: 0:08:21 Sampling 15%|██████▍ | ETA: 0:08:08 Sampling 16%|██████▌ | ETA: 0:07:57 Sampling 16%|██████▊ | ETA: 0:07:50 Sampling 17%|███████ | ETA: 0:07:41 Sampling 17%|███████▏ | ETA: 0:07:27 Sampling 18%|███████▍ | ETA: 0:07:19 Sampling 18%|███████▌ | ETA: 0:07:12 Sampling 19%|███████▊ | ETA: 0:07:04 Sampling 19%|████████ | ETA: 0:06:55 Sampling 20%|████████▎ | ETA: 0:06:49 Sampling 20%|████████▍ | ETA: 0:06:44 Sampling 21%|████████▋ | ETA: 0:06:37 Sampling 21%|████████▉ | ETA: 0:06:29 Sampling 22%|█████████▏ | ETA: 0:06:23 Sampling 22%|█████████▎ | ETA: 0:06:16 Sampling 23%|█████████▌ | ETA: 0:06:11 Sampling 23%|█████████▊ | ETA: 0:06:08 Sampling 24%|█████████▉ | ETA: 0:06:02 Sampling 24%|██████████▏ | ETA: 0:05:57 Sampling 25%|██████████▎ | ETA: 0:05:54 Sampling 25%|██████████▌ | ETA: 0:05:51 Sampling 26%|██████████▊ | ETA: 0:05:46 Sampling 26%|██████████▉ | ETA: 0:05:41 Sampling 27%|███████████▏ | ETA: 0:05:38 Sampling 27%|███████████▍ | ETA: 0:05:32 Sampling 28%|███████████▋ | ETA: 0:05:26 Sampling 28%|███████████▊ | ETA: 0:05:24 Sampling 29%|████████████ | ETA: 0:05:19 Sampling 29%|████████████▎ | ETA: 0:05:16 Sampling 30%|████████████▍ | ETA: 0:05:10 Sampling 30%|████████████▋ | ETA: 0:05:08 Sampling 31%|████████████▉ | ETA: 0:05:05 Sampling 31%|█████████████ | ETA: 0:05:01 Sampling 32%|█████████████▎ | ETA: 0:04:58 Sampling 32%|█████████████▌ | ETA: 0:04:56 Sampling 33%|█████████████▋ | ETA: 0:04:53 Sampling 33%|█████████████▉ | ETA: 0:04:51 Sampling 34%|██████████████▏ | ETA: 0:04:49 Sampling 34%|██████████████▎ | ETA: 0:04:46 Sampling 35%|██████████████▌ | ETA: 0:04:43 Sampling 35%|██████████████▊ | ETA: 0:04:41 Sampling 36%|███████████████ | ETA: 0:04:37 Sampling 36%|███████████████▏ | ETA: 0:04:35 Sampling 37%|███████████████▍ | ETA: 0:04:31 Sampling 37%|███████████████▋ | ETA: 0:04:28 Sampling 38%|███████████████▊ | ETA: 0:04:25 Sampling 38%|████████████████ | ETA: 0:04:23 Sampling 39%|████████████████▏ | ETA: 0:04:19 Sampling 39%|████████████████▍ | ETA: 0:04:16 Sampling 40%|████████████████▋ | ETA: 0:04:13 Sampling 40%|████████████████▊ | ETA: 0:04:11 Sampling 41%|█████████████████ | ETA: 0:04:09 Sampling 41%|█████████████████▎ | ETA: 0:04:06 Sampling 42%|█████████████████▌ | ETA: 0:04:03 Sampling 42%|█████████████████▋ | ETA: 0:04:01 Sampling 43%|█████████████████▉ | ETA: 0:03:58 Sampling 43%|██████████████████▏ | ETA: 0:03:56 Sampling 44%|██████████████████▎ | ETA: 0:03:53 Sampling 44%|██████████████████▌ | ETA: 0:03:52 Sampling 45%|██████████████████▊ | ETA: 0:03:49 Sampling 45%|██████████████████▉ | ETA: 0:03:45 Sampling 46%|███████████████████▏ | ETA: 0:03:43 Sampling 46%|███████████████████▍ | ETA: 0:03:41 Sampling 47%|███████████████████▌ | ETA: 0:03:39 Sampling 47%|███████████████████▊ | ETA: 0:03:36 Sampling 48%|████████████████████ | ETA: 0:03:33 Sampling 48%|████████████████████▏ | ETA: 0:03:31 Sampling 49%|████████████████████▍ | ETA: 0:03:29 Sampling 49%|████████████████████▋ | ETA: 0:03:28 Sampling 50%|████████████████████▉ | ETA: 0:03:26 Sampling 50%|█████████████████████ | ETA: 0:03:24 Sampling 51%|█████████████████████▎ | ETA: 0:03:22 Sampling 51%|█████████████████████▌ | ETA: 0:03:19 Sampling 52%|█████████████████████▋ | ETA: 0:03:16 Sampling 52%|█████████████████████▉ | ETA: 0:03:15 Sampling 53%|██████████████████████▏ | ETA: 0:03:12 Sampling 53%|██████████████████████▎ | ETA: 0:03:10 Sampling 54%|██████████████████████▌ | ETA: 0:03:08 Sampling 54%|██████████████████████▋ | ETA: 0:03:06 Sampling 55%|██████████████████████▉ | ETA: 0:03:04 Sampling 55%|███████████████████████▏ | ETA: 0:03:02 Sampling 56%|███████████████████████▍ | ETA: 0:03:00 Sampling 56%|███████████████████████▌ | ETA: 0:02:58 Sampling 57%|███████████████████████▊ | ETA: 0:02:55 Sampling 57%|████████████████████████ | ETA: 0:02:53 Sampling 58%|████████████████████████▎ | ETA: 0:02:50 Sampling 58%|████████████████████████▍ | ETA: 0:02:48 Sampling 59%|████████████████████████▋ | ETA: 0:02:46 Sampling 59%|████████████████████████▊ | ETA: 0:02:44 Sampling 60%|█████████████████████████ | ETA: 0:02:41 Sampling 60%|█████████████████████████▎ | ETA: 0:02:39 Sampling 61%|█████████████████████████▍ | ETA: 0:02:37 Sampling 61%|█████████████████████████▋ | ETA: 0:02:35 Sampling 62%|█████████████████████████▉ | ETA: 0:02:32 Sampling 62%|██████████████████████████ | ETA: 0:02:31 Sampling 63%|██████████████████████████▎ | ETA: 0:02:29 Sampling 63%|██████████████████████████▌ | ETA: 0:02:27 Sampling 64%|██████████████████████████▊ | ETA: 0:02:25 Sampling 64%|██████████████████████████▉ | ETA: 0:02:23 Sampling 65%|███████████████████████████▏ | ETA: 0:02:21 Sampling 65%|███████████████████████████▍ | ETA: 0:02:18 Sampling 66%|███████████████████████████▌ | ETA: 0:02:16 Sampling 66%|███████████████████████████▊ | ETA: 0:02:15 Sampling 67%|████████████████████████████ | ETA: 0:02:12 Sampling 67%|████████████████████████████▏ | ETA: 0:02:10 Sampling 68%|████████████████████████████▍ | ETA: 0:02:08 Sampling 68%|████████████████████████████▌ | ETA: 0:02:06 Sampling 69%|████████████████████████████▊ | ETA: 0:02:04 Sampling 69%|█████████████████████████████ | ETA: 0:02:02 Sampling 70%|█████████████████████████████▎ | ETA: 0:02:00 Sampling 70%|█████████████████████████████▍ | ETA: 0:01:58 Sampling 71%|█████████████████████████████▋ | ETA: 0:01:56 Sampling 71%|█████████████████████████████▉ | ETA: 0:01:54 Sampling 72%|██████████████████████████████▏ | ETA: 0:01:51 Sampling 72%|██████████████████████████████▎ | ETA: 0:01:50 Sampling 73%|██████████████████████████████▌ | ETA: 0:01:47 Sampling 73%|██████████████████████████████▊ | ETA: 0:01:45 Sampling 74%|██████████████████████████████▉ | ETA: 0:01:43 Sampling 74%|███████████████████████████████▏ | ETA: 0:01:42 Sampling 75%|███████████████████████████████▎ | ETA: 0:01:39 Sampling 75%|███████████████████████████████▌ | ETA: 0:01:37 Sampling 76%|███████████████████████████████▊ | ETA: 0:01:35 Sampling 76%|███████████████████████████████▉ | ETA: 0:01:33 Sampling 77%|████████████████████████████████▏ | ETA: 0:01:31 Sampling 77%|████████████████████████████████▍ | ETA: 0:01:29 Sampling 78%|████████████████████████████████▋ | ETA: 0:01:27 Sampling 78%|████████████████████████████████▊ | ETA: 0:01:25 Sampling 79%|█████████████████████████████████ | ETA: 0:01:23 Sampling 79%|█████████████████████████████████▎ | ETA: 0:01:21 Sampling 80%|█████████████████████████████████▍ | ETA: 0:01:19 Sampling 80%|█████████████████████████████████▋ | ETA: 0:01:17 Sampling 81%|█████████████████████████████████▉ | ETA: 0:01:15 Sampling 81%|██████████████████████████████████ | ETA: 0:01:13 Sampling 82%|██████████████████████████████████▎ | ETA: 0:01:11 Sampling 82%|██████████████████████████████████▌ | ETA: 0:01:10 Sampling 83%|██████████████████████████████████▋ | ETA: 0:01:07 Sampling 83%|██████████████████████████████████▉ | ETA: 0:01:05 Sampling 84%|███████████████████████████████████▏ | ETA: 0:01:03 Sampling 84%|███████████████████████████████████▎ | ETA: 0:01:02 Sampling 85%|███████████████████████████████████▌ | ETA: 0:01:00 Sampling 85%|███████████████████████████████████▊ | ETA: 0:00:57 Sampling 86%|████████████████████████████████████ | ETA: 0:00:56 Sampling 86%|████████████████████████████████████▏ | ETA: 0:00:54 Sampling 87%|████████████████████████████████████▍ | ETA: 0:00:52 Sampling 87%|████████████████████████████████████▋ | ETA: 0:00:50 Sampling 88%|████████████████████████████████████▊ | ETA: 0:00:48 Sampling 88%|█████████████████████████████████████ | ETA: 0:00:46 Sampling 89%|█████████████████████████████████████▏ | ETA: 0:00:44 Sampling 89%|█████████████████████████████████████▍ | ETA: 0:00:42 Sampling 90%|█████████████████████████████████████▋ | ETA: 0:00:40 Sampling 90%|█████████████████████████████████████▊ | ETA: 0:00:38 Sampling 91%|██████████████████████████████████████ | ETA: 0:00:36 Sampling 91%|██████████████████████████████████████▎ | ETA: 0:00:34 Sampling 92%|██████████████████████████████████████▌ | ETA: 0:00:32 Sampling 92%|██████████████████████████████████████▋ | ETA: 0:00:31 Sampling 93%|██████████████████████████████████████▉ | ETA: 0:00:29 Sampling 93%|███████████████████████████████████████▏ | ETA: 0:00:27 Sampling 94%|███████████████████████████████████████▎ | ETA: 0:00:24 Sampling 94%|███████████████████████████████████████▌ | ETA: 0:00:23 Sampling 95%|███████████████████████████████████████▊ | ETA: 0:00:21 Sampling 95%|███████████████████████████████████████▉ | ETA: 0:00:19 Sampling 96%|████████████████████████████████████████▏ | ETA: 0:00:17 Sampling 96%|████████████████████████████████████████▍ | ETA: 0:00:15 Sampling 97%|████████████████████████████████████████▌ | ETA: 0:00:13 Sampling 97%|████████████████████████████████████████▊ | ETA: 0:00:11 Sampling 98%|█████████████████████████████████████████ | ETA: 0:00:09 Sampling 98%|█████████████████████████████████████████▏| ETA: 0:00:08 Sampling 99%|█████████████████████████████████████████▍| ETA: 0:00:06 Sampling 99%|█████████████████████████████████████████▋| ETA: 0:00:04 Sampling 100%|█████████████████████████████████████████▉| ETA: 0:00:02 Sampling 100%|██████████████████████████████████████████| Time: 0:06:23 Sampling 100%|██████████████████████████████████████████| Time: 0:06:24
# we extract the posterior mean estimates of the parameters from the chain
z_mean = mean(chain_linear[@varname(Z)])
alpha_mean = mean(chain_linear[@varname(α)])
alpha1, alpha2 = partialsortperm(alpha_mean, 1:2; rev=true)
scatter(
z_mean[alpha1, :],
z_mean[alpha2, :];
group=labels[1:n_data],
xlabel=L"z_{\mathrm{ard}_1}",
ylabel=L"z_{\mathrm{ard}_2}",
)We can see that similar to the pPCA case, the linear kernel GPLVM fails to distinguish between the groups.
Finally, we demonstrate that by changing the kernel to a non-linear function, we are able to separate the data again.
gplvm = GPLVM(dat[1:n_data, 1:n_features], ndim)
chain_gplvm = sample(gplvm, NUTS(; adtype=AutoMooncake()), 500);Sampling 0%| | ETA: N/A ┌ Info: Found initial step size └ ϵ = 0.2 Sampling 1%|▎ | ETA: 0:53:43 Sampling 1%|▌ | ETA: 0:27:11 Sampling 2%|▋ | ETA: 0:18:44 Sampling 2%|▉ | ETA: 0:15:08 Sampling 3%|█▏ | ETA: 0:12:22 Sampling 3%|█▎ | ETA: 0:10:31 Sampling 4%|█▌ | ETA: 0:09:11 Sampling 4%|█▋ | ETA: 0:08:21 Sampling 5%|█▉ | ETA: 0:07:26 Sampling 5%|██▏ | ETA: 0:06:48 Sampling 6%|██▍ | ETA: 0:06:23 Sampling 6%|██▌ | ETA: 0:06:00 Sampling 7%|██▊ | ETA: 0:05:39 Sampling 7%|███ | ETA: 0:05:20 Sampling 8%|███▎ | ETA: 0:05:01 Sampling 8%|███▍ | ETA: 0:04:48 Sampling 9%|███▋ | ETA: 0:04:35 Sampling 9%|███▊ | ETA: 0:04:22 Sampling 10%|████ | ETA: 0:04:12 Sampling 10%|████▎ | ETA: 0:04:04 Sampling 11%|████▍ | ETA: 0:03:55 Sampling 11%|████▋ | ETA: 0:03:45 Sampling 12%|████▉ | ETA: 0:03:37 Sampling 12%|█████ | ETA: 0:03:32 Sampling 13%|█████▎ | ETA: 0:03:27 Sampling 13%|█████▌ | ETA: 0:03:21 Sampling 14%|█████▊ | ETA: 0:03:16 Sampling 14%|█████▉ | ETA: 0:03:11 Sampling 15%|██████▏ | ETA: 0:03:06 Sampling 15%|██████▍ | ETA: 0:03:01 Sampling 16%|██████▌ | ETA: 0:02:58 Sampling 16%|██████▊ | ETA: 0:02:56 Sampling 17%|███████ | ETA: 0:02:50 Sampling 17%|███████▏ | ETA: 0:02:46 Sampling 18%|███████▍ | ETA: 0:02:42 Sampling 18%|███████▌ | ETA: 0:02:40 Sampling 19%|███████▊ | ETA: 0:02:37 Sampling 19%|████████ | ETA: 0:02:33 Sampling 20%|████████▎ | ETA: 0:02:29 Sampling 20%|████████▍ | ETA: 0:02:28 Sampling 21%|████████▋ | ETA: 0:02:24 Sampling 21%|████████▉ | ETA: 0:02:22 Sampling 22%|█████████▏ | ETA: 0:02:19 Sampling 22%|█████████▎ | ETA: 0:02:17 Sampling 23%|█████████▌ | ETA: 0:02:14 Sampling 23%|█████████▊ | ETA: 0:02:11 Sampling 24%|█████████▉ | ETA: 0:02:08 Sampling 24%|██████████▏ | ETA: 0:02:06 Sampling 25%|██████████▎ | ETA: 0:02:04 Sampling 25%|██████████▌ | ETA: 0:02:02 Sampling 26%|██████████▊ | ETA: 0:01:59 Sampling 26%|██████████▉ | ETA: 0:01:58 Sampling 27%|███████████▏ | ETA: 0:01:55 Sampling 27%|███████████▍ | ETA: 0:01:53 Sampling 28%|███████████▋ | ETA: 0:01:50 Sampling 28%|███████████▊ | ETA: 0:01:50 Sampling 29%|████████████ | ETA: 0:01:48 Sampling 29%|████████████▎ | ETA: 0:01:46 Sampling 30%|████████████▍ | ETA: 0:01:46 Sampling 30%|████████████▋ | ETA: 0:01:45 Sampling 31%|████████████▉ | ETA: 0:01:44 Sampling 31%|█████████████ | ETA: 0:01:42 Sampling 32%|█████████████▎ | ETA: 0:01:40 Sampling 32%|█████████████▌ | ETA: 0:01:39 Sampling 33%|█████████████▋ | ETA: 0:01:38 Sampling 33%|█████████████▉ | ETA: 0:01:37 Sampling 34%|██████████████▏ | ETA: 0:01:35 Sampling 34%|██████████████▎ | ETA: 0:01:34 Sampling 35%|██████████████▌ | ETA: 0:01:33 Sampling 35%|██████████████▊ | ETA: 0:01:31 Sampling 36%|███████████████ | ETA: 0:01:30 Sampling 36%|███████████████▏ | ETA: 0:01:29 Sampling 37%|███████████████▍ | ETA: 0:01:27 Sampling 37%|███████████████▋ | ETA: 0:01:26 Sampling 38%|███████████████▊ | ETA: 0:01:25 Sampling 38%|████████████████ | ETA: 0:01:24 Sampling 39%|████████████████▏ | ETA: 0:01:22 Sampling 39%|████████████████▍ | ETA: 0:01:21 Sampling 40%|████████████████▋ | ETA: 0:01:20 Sampling 40%|████████████████▊ | ETA: 0:01:19 Sampling 41%|█████████████████ | ETA: 0:01:18 Sampling 41%|█████████████████▎ | ETA: 0:01:17 Sampling 42%|█████████████████▌ | ETA: 0:01:16 Sampling 42%|█████████████████▋ | ETA: 0:01:15 Sampling 43%|█████████████████▉ | ETA: 0:01:14 Sampling 43%|██████████████████▏ | ETA: 0:01:13 Sampling 44%|██████████████████▎ | ETA: 0:01:13 Sampling 44%|██████████████████▌ | ETA: 0:01:12 Sampling 45%|██████████████████▊ | ETA: 0:01:11 Sampling 45%|██████████████████▉ | ETA: 0:01:10 Sampling 46%|███████████████████▏ | ETA: 0:01:09 Sampling 46%|███████████████████▍ | ETA: 0:01:08 Sampling 47%|███████████████████▌ | ETA: 0:01:07 Sampling 47%|███████████████████▊ | ETA: 0:01:06 Sampling 48%|████████████████████ | ETA: 0:01:05 Sampling 48%|████████████████████▏ | ETA: 0:01:05 Sampling 49%|████████████████████▍ | ETA: 0:01:04 Sampling 49%|████████████████████▋ | ETA: 0:01:03 Sampling 50%|████████████████████▉ | ETA: 0:01:02 Sampling 50%|█████████████████████ | ETA: 0:01:01 Sampling 51%|█████████████████████▎ | ETA: 0:01:01 Sampling 51%|█████████████████████▌ | ETA: 0:01:00 Sampling 52%|█████████████████████▋ | ETA: 0:00:59 Sampling 52%|█████████████████████▉ | ETA: 0:00:58 Sampling 53%|██████████████████████▏ | ETA: 0:00:57 Sampling 53%|██████████████████████▎ | ETA: 0:00:56 Sampling 54%|██████████████████████▌ | ETA: 0:00:56 Sampling 54%|██████████████████████▋ | ETA: 0:00:55 Sampling 55%|██████████████████████▉ | ETA: 0:00:54 Sampling 55%|███████████████████████▏ | ETA: 0:00:53 Sampling 56%|███████████████████████▍ | ETA: 0:00:53 Sampling 56%|███████████████████████▌ | ETA: 0:00:52 Sampling 57%|███████████████████████▊ | ETA: 0:00:51 Sampling 57%|████████████████████████ | ETA: 0:00:50 Sampling 58%|████████████████████████▎ | ETA: 0:00:50 Sampling 58%|████████████████████████▍ | ETA: 0:00:49 Sampling 59%|████████████████████████▋ | ETA: 0:00:48 Sampling 59%|████████████████████████▊ | ETA: 0:00:47 Sampling 60%|█████████████████████████ | ETA: 0:00:47 Sampling 60%|█████████████████████████▎ | ETA: 0:00:46 Sampling 61%|█████████████████████████▍ | ETA: 0:00:45 Sampling 61%|█████████████████████████▋ | ETA: 0:00:45 Sampling 62%|█████████████████████████▉ | ETA: 0:00:44 Sampling 62%|██████████████████████████ | ETA: 0:00:44 Sampling 63%|██████████████████████████▎ | ETA: 0:00:43 Sampling 63%|██████████████████████████▌ | ETA: 0:00:42 Sampling 64%|██████████████████████████▊ | ETA: 0:00:41 Sampling 64%|██████████████████████████▉ | ETA: 0:00:41 Sampling 65%|███████████████████████████▏ | ETA: 0:00:40 Sampling 65%|███████████████████████████▍ | ETA: 0:00:40 Sampling 66%|███████████████████████████▌ | ETA: 0:00:39 Sampling 66%|███████████████████████████▊ | ETA: 0:00:38 Sampling 67%|████████████████████████████ | ETA: 0:00:38 Sampling 67%|████████████████████████████▏ | ETA: 0:00:37 Sampling 68%|████████████████████████████▍ | ETA: 0:00:36 Sampling 68%|████████████████████████████▌ | ETA: 0:00:36 Sampling 69%|████████████████████████████▊ | ETA: 0:00:35 Sampling 69%|█████████████████████████████ | ETA: 0:00:34 Sampling 70%|█████████████████████████████▎ | ETA: 0:00:34 Sampling 70%|█████████████████████████████▍ | ETA: 0:00:33 Sampling 71%|█████████████████████████████▋ | ETA: 0:00:33 Sampling 71%|█████████████████████████████▉ | ETA: 0:00:32 Sampling 72%|██████████████████████████████▏ | ETA: 0:00:31 Sampling 72%|██████████████████████████████▎ | ETA: 0:00:31 Sampling 73%|██████████████████████████████▌ | ETA: 0:00:30 Sampling 73%|██████████████████████████████▊ | ETA: 0:00:30 Sampling 74%|██████████████████████████████▉ | ETA: 0:00:29 Sampling 74%|███████████████████████████████▏ | ETA: 0:00:28 Sampling 75%|███████████████████████████████▎ | ETA: 0:00:28 Sampling 75%|███████████████████████████████▌ | ETA: 0:00:27 Sampling 76%|███████████████████████████████▊ | ETA: 0:00:27 Sampling 76%|███████████████████████████████▉ | ETA: 0:00:26 Sampling 77%|████████████████████████████████▏ | ETA: 0:00:25 Sampling 77%|████████████████████████████████▍ | ETA: 0:00:25 Sampling 78%|████████████████████████████████▋ | ETA: 0:00:24 Sampling 78%|████████████████████████████████▊ | ETA: 0:00:24 Sampling 79%|█████████████████████████████████ | ETA: 0:00:23 Sampling 79%|█████████████████████████████████▎ | ETA: 0:00:22 Sampling 80%|█████████████████████████████████▍ | ETA: 0:00:22 Sampling 80%|█████████████████████████████████▋ | ETA: 0:00:21 Sampling 81%|█████████████████████████████████▉ | ETA: 0:00:21 Sampling 81%|██████████████████████████████████ | ETA: 0:00:20 Sampling 82%|██████████████████████████████████▎ | ETA: 0:00:20 Sampling 82%|██████████████████████████████████▌ | ETA: 0:00:19 Sampling 83%|██████████████████████████████████▋ | ETA: 0:00:19 Sampling 83%|██████████████████████████████████▉ | ETA: 0:00:18 Sampling 84%|███████████████████████████████████▏ | ETA: 0:00:17 Sampling 84%|███████████████████████████████████▎ | ETA: 0:00:17 Sampling 85%|███████████████████████████████████▌ | ETA: 0:00:16 Sampling 85%|███████████████████████████████████▊ | ETA: 0:00:16 Sampling 86%|████████████████████████████████████ | ETA: 0:00:15 Sampling 86%|████████████████████████████████████▏ | ETA: 0:00:15 Sampling 87%|████████████████████████████████████▍ | ETA: 0:00:14 Sampling 87%|████████████████████████████████████▋ | ETA: 0:00:13 Sampling 88%|████████████████████████████████████▊ | ETA: 0:00:13 Sampling 88%|█████████████████████████████████████ | ETA: 0:00:12 Sampling 89%|█████████████████████████████████████▏ | ETA: 0:00:12 Sampling 89%|█████████████████████████████████████▍ | ETA: 0:00:11 Sampling 90%|█████████████████████████████████████▋ | ETA: 0:00:11 Sampling 90%|█████████████████████████████████████▊ | ETA: 0:00:10 Sampling 91%|██████████████████████████████████████ | ETA: 0:00:10 Sampling 91%|██████████████████████████████████████▎ | ETA: 0:00:09 Sampling 92%|██████████████████████████████████████▌ | ETA: 0:00:09 Sampling 92%|██████████████████████████████████████▋ | ETA: 0:00:08 Sampling 93%|██████████████████████████████████████▉ | ETA: 0:00:08 Sampling 93%|███████████████████████████████████████▏ | ETA: 0:00:07 Sampling 94%|███████████████████████████████████████▎ | ETA: 0:00:06 Sampling 94%|███████████████████████████████████████▌ | ETA: 0:00:06 Sampling 95%|███████████████████████████████████████▊ | ETA: 0:00:06 Sampling 95%|███████████████████████████████████████▉ | ETA: 0:00:05 Sampling 96%|████████████████████████████████████████▏ | ETA: 0:00:04 Sampling 96%|████████████████████████████████████████▍ | ETA: 0:00:04 Sampling 97%|████████████████████████████████████████▌ | ETA: 0:00:03 Sampling 97%|████████████████████████████████████████▊ | ETA: 0:00:03 Sampling 98%|█████████████████████████████████████████ | ETA: 0:00:02 Sampling 98%|█████████████████████████████████████████▏| ETA: 0:00:02 Sampling 99%|█████████████████████████████████████████▍| ETA: 0:00:01 Sampling 99%|█████████████████████████████████████████▋| ETA: 0:00:01 Sampling 100%|█████████████████████████████████████████▉| ETA: 0:00:00 Sampling 100%|██████████████████████████████████████████| Time: 0:01:39 Sampling 100%|██████████████████████████████████████████| Time: 0:01:40 ┌ Warning: There were 3 divergent transitions. Consider reparameterising your model or using a smaller step size. For adaptive samplers such as NUTS and HMCDA, consider increasing `target_accept`. └ @ Turing.Inference ~/.julia/packages/Turing/4hMHm/src/mcmc/hmc.jl:483
z_mean = mean(chain_gplvm[@varname(Z)])
alpha_mean = mean(chain_gplvm[@varname(α)])
alpha1, alpha2 = partialsortperm(alpha_mean, 1:2; rev=true)
scatter(
z_mean[alpha1, :],
z_mean[alpha2, :];
group=labels[1:n_data],
xlabel=L"z_{\mathrm{ard}_1}",
ylabel=L"z_{\mathrm{ard}_2}",
)let
@assert abs(
mean(z_mean[alpha1, labels[1:n_data] .== "setosa"]) -
mean(z_mean[alpha1, labels[1:n_data] .!= "setosa"]),
) > 0.5
endNow, the split between the two groups is visible again.