Stochastic Gradient Samplers

Turing.jl provides stochastic gradient-based MCMC samplers: Stochastic Gradient Langevin Dynamics (SGLD) and Stochastic Gradient Hamiltonian Monte Carlo (SGHMC).

Current Capabilities

The current implementation in Turing.jl is primarily useful for:

  • Research purposes: Studying stochastic gradient MCMC methods
  • Streaming data: When data arrives continuously
  • Experimental applications: Testing stochastic sampling approaches

Important: The current implementation computes full gradients with added stochastic noise rather than true mini-batch stochastic gradients. As a result, these samplers do not currently provide the computational benefits typically associated with stochastic gradient methods on large datasets. They also require very careful hyperparameter tuning and are often slower than standard samplers such as HMC or NUTS for most practical applications.

Setup

using Turing
using Distributions
using StatsPlots
using Random
using LinearAlgebra

Random.seed!(123)

# Disable progress bars for cleaner output
Turing.setprogress!(false)
[ Info: [Turing]: progress logging is disabled globally
false

SGLD (Stochastic Gradient Langevin Dynamics)

SGLD adds properly scaled noise to gradient descent steps to enable MCMC sampling. The key insight is that the right amount of noise transforms optimization into sampling from the posterior distribution.
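To make the update concrete, here is a minimal, self-contained sketch of the SGLD update rule on a one-dimensional standard-normal target. This is an illustrative toy, not Turing's implementation; the names grad_logp and sgld_toy are local to this example.

# Toy SGLD on a standard-normal target, where ∇ log p(θ) = -θ
grad_logp(θ) = -θ

function sgld_toy(n_steps; a=0.1, b=10.0, γ=0.55)
    θ = 0.0
    samples = Vector{Float64}(undef, n_steps)
    for t in 1:n_steps
        ε = a / (b + t)^γ                              # decaying step size
        θ += ε / 2 * grad_logp(θ) + sqrt(ε) * randn()  # half gradient step + N(0, ε) noise
        samples[t] = θ
    end
    return samples
end

toy_samples = sgld_toy(10_000)

Early in the run the injected noise dominates the gradient term; as the step size decays, the iterates behave increasingly like draws from the target distribution.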

Let’s start with a simple Gaussian model:

# Generate synthetic data
true_μ = 2.0
true_σ = 1.5
N = 100
data = rand(Normal(true_μ, true_σ), N)

# Define a simple Gaussian model
@model function gaussian_model(x)
    μ ~ Normal(0, 10)
    σ ~ truncated(Normal(0, 5); lower=0)
    
    for i in 1:length(x)
        x[i] ~ Normal(μ, σ)
    end
end

model = gaussian_model(data)
DynamicPPL.Model{typeof(gaussian_model), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DynamicPPL.DefaultContext}(gaussian_model, (x = [3.21243189269745, 0.31689123782873996, 0.3430458465060562, 1.3745110472525999, 2.431381970935784, 2.3447280470778016, 1.367347003400461, -0.033385931651795486, 2.1041887116378404, 1.82401579320378  …  1.9524391544746311, 2.6525022164657783, 3.260442669093176, 2.5685766805908874, 1.837062568926331, 0.9818378577483255, 0.3676722290955695, 3.0556374886884523, 2.2149888398562707, 2.222563050014123],), NamedTuple(), DynamicPPL.DefaultContext())

SGLD requires very small step sizes to ensure stability, so we use a PolynomialStepsize schedule that decreases over time. Note that PolynomialStepsize is currently the primary step-size schedule available in Turing for SGLD.
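To get a feel for how quickly this schedule shrinks, the sketch below evaluates the formula directly in plain Julia (it does not use Turing internals; poly_step is just a local helper):

# stepsize(t) = a / (b + t)^γ for the values used below
poly_step(t; a=0.0001, b=10_000, γ=0.55) = a / (b + t)^γ

for t in (1, 100, 1_000, 5_000)
    println("t = ", t, ":  stepsize ≈ ", poly_step(t))
end

With those magnitudes in mind, configure and run SGLD: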

# SGLD with polynomial stepsize schedule
# stepsize(t) = a / (b + t)^γ
sgld_stepsize = Turing.PolynomialStepsize(0.0001, 10000, 0.55)
chain_sgld = sample(model, SGLD(stepsize=sgld_stepsize), 5000)

summarystats(chain_sgld)
Summary Statistics
  parameters       mean       std      mcse   ess_bulk   ess_tail      rhat    ⋯
      Symbol    Float64   Float64   Float64    Float64    Float64   Float64    ⋯

           μ   -17.0516    0.0458    0.0143    10.8065    19.3314    2.1228    ⋯
           σ     6.5796    2.0641    0.6402    10.5632    18.3607    2.1230    ⋯
                                                                1 column omitted
plot(chain_sgld)

SGHMC (Stochastic Gradient Hamiltonian Monte Carlo)

SGHMC extends HMC to the stochastic gradient setting by incorporating a friction term that counteracts the extra noise from stochastic gradients.
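Conceptually, each step updates a momentum variable with the (noisy) gradient, applies friction, injects matching noise, and then moves the parameter along the momentum. The toy sketch below illustrates that update on a standard-normal target, with the gradient-noise estimate set to zero; it is not Turing's implementation, and sghmc_toy is a local name.

# Toy SGHMC on a standard-normal target (gradient-noise estimate taken as zero)
grad_logp(θ) = -θ

function sghmc_toy(n_steps; lr=0.01, α=0.1)   # lr: learning rate, α: momentum decay (friction)
    θ, v = 0.0, 0.0
    samples = Vector{Float64}(undef, n_steps)
    for t in 1:n_steps
        v = (1 - α) * v + lr * grad_logp(θ) + sqrt(2 * α * lr) * randn()  # friction + injected noise
        θ += v                                                            # move along the momentum
        samples[t] = θ
    end
    return samples
end

In Turing, the corresponding knobs are the learning rate and the momentum decay (friction):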

# SGHMC with very small learning rate
chain_sghmc = sample(model, SGHMC(learning_rate=0.00001, momentum_decay=0.1), 5000)

summarystats(chain_sghmc)
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           μ    2.0993    0.5436    0.1510    13.5879    18.3607    1.1960     ⋯
           σ    2.5895   16.1407    1.2293    20.7905    18.4349    1.1225     ⋯
                                                                1 column omitted
plot(chain_sghmc)

Comparison with Standard HMC

For comparison, let’s sample the same model using standard HMC:

chain_hmc = sample(model, HMC(0.01, 10), 1000)

println("True values: μ = ", true_μ, ", σ = ", true_σ)
summarystats(chain_hmc)
True values: μ = 2.0, σ = 1.5
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           μ   -5.9663    0.0000    0.0000        NaN        NaN       NaN     ⋯
           σ    0.2069    0.0000    0.0000        NaN        NaN       NaN     ⋯
                                                                1 column omitted

Compare the trace plots to see how the different samplers explore the posterior:

p1 = plot(chain_sgld[:μ], label="SGLD", title="μ parameter traces")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

p2 = plot(chain_sghmc[:μ], label="SGHMC")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

p3 = plot(chain_hmc[:μ], label="HMC")
hline!([true_μ], label="True value", linestyle=:dash, color=:red)

plot(p1, p2, p3, layout=(3,1), size=(800,600))

The comparison shows that:

  • SGLD exhibits slower convergence and higher variance due to the injected noise, requiring longer chains to achieve stable estimates
  • SGHMC shows slightly better mixing than SGLD thanks to the momentum term, but still requires careful tuning
  • HMC (or NUTS) with well-tuned settings typically converges quickly and explores the posterior efficiently, which is why it is preferred for small to medium-sized problems

Bayesian Linear Regression Example

Here’s a more complex example using Bayesian linear regression:

# Generate regression data
n_features = 3
n_samples = 100
X = randn(n_samples, n_features)
true_β = [0.5, -1.2, 2.1]
true_σ_noise = 0.3
y = X * true_β + true_σ_noise * randn(n_samples)

@model function linear_regression(X, y)
    n_features = size(X, 2)
    
    # Priors
    β ~ MvNormal(zeros(n_features), 3 * I)
    σ ~ truncated(Normal(0, 1); lower=0)
    
    # Likelihood
    y ~ MvNormal(X * β, σ^2 * I)
end

lr_model = linear_regression(X, y)
DynamicPPL.Model{typeof(linear_regression), (:X, :y), (), (), Tuple{Matrix{Float64}, Vector{Float64}}, Tuple{}, DynamicPPL.DefaultContext}(linear_regression, (X = [-0.08993884887496832 1.2694180094557772 -0.45068406344161077; -0.23528025045836815 -1.0348870573833149 -1.2512585407119565; … ; -0.5815563239702138 -0.19790550383157401 -0.7201291845682822; 0.29678442882680006 0.6426754256642815 -0.8729317283503407], y = [-2.5001125493734633, -1.3582233483639436, -3.8825717018806856, -0.2345200635330288, 1.4937176261849854, 2.8659122069995644, -0.5833355856450775, 4.642283548210101, 0.14909888834210028, 1.3335900592696839  …  5.9741160301704, -1.5777125963436005, 3.9896734979440236, -1.0204264890982526, -1.6606828145645047, 1.76720805427176, -0.20620159329470383, -1.9121131995245513, -0.9431065705584871, -2.3648743995748114]), NamedTuple(), DynamicPPL.DefaultContext())

Sample using the stochastic gradient methods:

# Very conservative parameters for stability
sgld_lr_stepsize = Turing.PolynomialStepsize(0.00005, 10000, 0.55)
chain_lr_sgld = sample(lr_model, SGLD(stepsize=sgld_lr_stepsize), 5000)

chain_lr_sghmc = sample(lr_model, SGHMC(learning_rate=0.00005, momentum_decay=0.1), 5000)

chain_lr_hmc = sample(lr_model, HMC(0.01, 10), 1000)
Chains MCMC chain (1000×14×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 1.52 seconds
Compute duration  = 1.52 seconds
parameters        = β[1], β[2], β[3], σ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

        β[1]    0.4978    0.1697    0.0502    41.6343    12.4475    1.0211     ⋯
        β[2]   -1.1118    0.2820    0.0775    47.3436    12.2657    1.0143     ⋯
        β[3]    1.7880    1.4977    0.4645    38.1562    12.2549    1.0595     ⋯
           σ    0.6152    1.1602    0.3603    39.1620    12.2561    1.0191     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        β[1]    0.4014    0.4409    0.4600    0.4798    1.1463
        β[2]   -1.2268   -1.1908   -1.1716   -1.1504    0.0196
        β[3]   -4.5180    2.1310    2.1603    2.1863    2.2384
           σ    0.2811    0.3083    0.3248    0.3442    5.6003

Compare the results to evaluate the performance of stochastic gradient samplers on a more complex model:

println("True β values: ", true_β)
println("True σ value: ", true_σ_noise)
println()

println("SGLD estimates:")
summarystats(chain_lr_sgld)
True β values: [0.5, -1.2, 2.1]
True σ value: 0.3

SGLD estimates:
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

        β[1]    1.4193    0.0092    0.0026    13.7923    33.1792    1.1167     ⋯
        β[2]   -0.1851    0.0195    0.0060    11.5311    25.2985    1.4079     ⋯
        β[3]   -0.1172    0.0220    0.0068    10.8609    20.7479    1.9659     ⋯
           σ    1.4488    0.0931    0.0291    10.6692    19.4335    2.0535     ⋯
                                                                1 column omitted

The linear regression example shows that, even on a well-behaved model, the stochastic gradient samplers need substantial care:

  • They require significantly longer chains (5000 draws here versus 1000 for HMC)
  • The estimates may have higher variance; with the conservative step sizes used here, the SGLD chain has not yet converged to the true values
  • Convergence diagnostics should be examined carefully before trusting the results (a simple check is sketched below)
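As a simple sanity check (assuming MCMCChains' usual naming of vector-valued parameters, e.g. Symbol("β[1]")), one can compare the SGLD posterior means against the values used to generate the data:

# Compare SGLD posterior means against the generating values
for (i, βi) in enumerate(true_β)
    est = mean(chain_lr_sgld[Symbol("β[$i]")])
    println("β[", i, "]: true = ", βi, ",  SGLD posterior mean ≈ ", round(est; digits=3))
end
println("σ: true = ", true_σ_noise, ",  SGLD posterior mean ≈ ",
        round(mean(chain_lr_sgld[:σ]); digits=3))

Large discrepancies, together with rhat values well above 1, indicate that the chain needs more iterations or different hyperparameters.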

Automatic Differentiation Backends

Both samplers support different AD backends. For more information about automatic differentiation in Turing, see the Automatic Differentiation documentation.

using ADTypes

# ForwardDiff (default) - good for few parameters
sgld_forward = SGLD(stepsize=sgld_stepsize, adtype=AutoForwardDiff())

# ReverseDiff - better for many parameters  
sgld_reverse = SGLD(stepsize=sgld_stepsize, adtype=AutoReverseDiff())

# Zygote - good for complex models
sgld_zygote = SGLD(stepsize=sgld_stepsize, adtype=AutoZygote())
SGLD{AutoZygote, PolynomialStepsize{Float64}}(PolynomialStepsize{Float64}(0.0001, 10000.0, 0.55), AutoZygote())

Best Practices and Recommendations

When to Consider Stochastic Gradient Samplers

  • Streaming data: When data arrives continuously and you need online inference
  • Research: For studying stochastic gradient MCMC methods
  • Educational purposes: For understanding stochastic gradient MCMC algorithms

Critical Hyperparameters

For SGLD:

  • Use PolynomialStepsize with very small initial values (≤ 0.0001)
  • Larger b values in PolynomialStepsize(a, b, γ) provide more stability
  • The stepsize decreases as a / (b + t)^γ

For SGHMC (illustrated below):

  • Use extremely small learning rates (≤ 0.00001)
  • Momentum decay (friction) typically between 0.1 and 0.5
  • Higher momentum decay improves stability but slows convergence
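For illustration, the guidance above translates into constructor calls like these (the variable names are just labels for this sketch; only the keyword values differ):

# More friction is more stable but mixes more slowly; less friction mixes faster
# but is less forgiving of a learning rate that is slightly too large.
sghmc_more_friction = SGHMC(learning_rate=1e-5, momentum_decay=0.5)
sghmc_less_friction = SGHMC(learning_rate=1e-5, momentum_decay=0.1)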

Current Limitations

  1. No mini-batching: Full gradients are computed despite the “stochastic” name
  2. Hyperparameter sensitivity: Requires extensive tuning
  3. Computational overhead: Often slower than HMC/NUTS for small-medium datasets
  4. Convergence: Typically requires longer chains

General Recommendations

  • Start conservatively: Use very small step sizes initially
  • Monitor convergence: Check trace plots and diagnostics carefully
  • Compare with HMC/NUTS: Validate results when possible
  • Consider alternatives: For most applications, HMC or NUTS will be more efficient

Summary

Stochastic gradient samplers in Turing.jl provide an interface to gradient-based MCMC methods with added stochasticity. While designed for large-scale problems, the current implementation uses full gradients, making them primarily useful for research or specialized applications. For most practical Bayesian inference tasks, standard samplers like HMC or NUTS will be more efficient and easier to tune.
