Callbacks

AbstractMCMC provides a unified callback API for monitoring and logging MCMC sampling.

Basic Usage

The mcmc_callback function is the main entry point for creating callbacks:

using AbstractMCMC

struct MyModel <: AbstractMCMC.AbstractModel end

struct MySampler <: AbstractMCMC.AbstractSampler end

function AbstractMCMC.step(rng, ::MyModel, ::MySampler, state=nothing; kwargs...)
    # all samples are zero
    return 0.0, state
end

model, sampler = MyModel(), MySampler()

# Simple callback with a function
cb = mcmc_callback() do rng, model, sampler, transition, state, iteration
    println("Iteration: $iteration")
end

chain = sample(model, sampler, 10; callback=cb)
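The callback receives the RNG, the model, the sampler, the current transition and state, and the iteration number. The sketch below makes that calling convention concrete with a driver loop that invokes a callback once per step. `toy_step` and `toy_sample` are hypothetical stand-ins and are not AbstractMCMC's `sample` implementation. The trailing `kwargs...` in the callback mirrors the signature given under Implementing Custom Callbacks.

```julia
using Random

# Hypothetical stand-in for a sampler step: returns (transition, new_state).
toy_step(rng, model, sampler, state) = (randn(rng), state)

# Hypothetical driver loop (NOT AbstractMCMC.sample): after every step it calls
# callback(rng, model, sampler, transition, state, iteration).
function toy_sample(rng, model, sampler, N; callback=nothing)
    transitions = Float64[]
    state = nothing
    for i in 1:N
        transition, state = toy_step(rng, model, sampler, state)
        push!(transitions, transition)
        callback === nothing || callback(rng, model, sampler, transition, state, i)
    end
    return transitions
end

# Record which iterations the callback was called for.
seen = Int[]
samples = toy_sample(MersenneTwister(1), :model, :sampler, 5;
                     callback=(rng, m, s, t, st, i; kwargs...) -> push!(seen, i))
```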

Combining Multiple Callbacks

Pass multiple callbacks to mcmc_callback to combine them:

cb1 = (args...; kwargs...) -> println("Callback 1")
cb2 = (args...; kwargs...) -> println("Callback 2")

cb = mcmc_callback(cb1, cb2)
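Conceptually, a combined callback forwards the same arguments to each member in turn. The sketch below illustrates that idea with a hypothetical `Combined` type. It is not the object that `mcmc_callback(cb1, cb2)` actually returns, and the assumption that members run in the order given belongs to the sketch.

```julia
# Hypothetical combined callback: stores callbacks in a tuple and calls each in order.
struct Combined{T<:Tuple}
    callbacks::T
end
Combined(cbs...) = Combined(cbs)

# Forward every positional and keyword argument to each wrapped callback.
function (c::Combined)(args...; kwargs...)
    for cb in c.callbacks
        cb(args...; kwargs...)
    end
    return nothing
end

calls = String[]
cb1 = (args...; kwargs...) -> push!(calls, "Callback 1")
cb2 = (args...; kwargs...) -> push!(calls, "Callback 2")
cb = Combined(cb1, cb2)
cb(nothing, nothing, nothing, 0.0, nothing, 1)   # dummy rng, model, sampler, transition, state, iteration
```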

You can also add callbacks dynamically using BangBang.push!!:

using BangBang

cb = mcmc_callback(cb1)
cb = push!!(cb, cb2)
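BangBang's `!!` functions mutate their argument when possible and otherwise return a new object. That is why the result of `push!!` is assigned back to `cb`. The sketch below shows why the reassignment matters for an immutable, tuple-backed container. `CallbackList` and `add_callback` are hypothetical names used only for illustration and are not BangBang or AbstractMCMC code.

```julia
# Hypothetical immutable, tuple-backed callback container.
struct CallbackList{T<:Tuple}
    callbacks::T
end

# Like push!! on an immutable value: a tuple cannot grow in place,
# so this returns a NEW container that also holds `f`.
add_callback(c::CallbackList, f) = CallbackList((c.callbacks..., f))

cb = CallbackList((identity,))
cb2 = add_callback(cb, println)   # `cb` itself is unchanged; use the returned value
```

Calling `add_callback(cb, f)` without rebinding the result leaves `cb` with its original members.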

TensorBoard Logging

TensorBoard logging requires TensorBoardLogger. Statistics collection also requires OnlineStats.

Basic Logging (No Statistics)

using AbstractMCMC
using TensorBoardLogger

logger = TBLogger("runs/experiment1")
cb = mcmc_callback(logger=logger)

chain = sample(model, sampler, 1000; callback=cb)

Logging with Statistics

To collect running statistics (mean, variance, histograms), load OnlineStats and use the stats argument:

using AbstractMCMC
using TensorBoardLogger
using OnlineStats

logger = TBLogger("runs/experiment1")

# Use default statistics (Mean, Variance, KHist)
cb = mcmc_callback(logger=logger, stats=true)

# Or specify custom statistics
cb = mcmc_callback(
    logger=logger,
    stats=(Mean(), Variance(), KHist(50)),
)

Note

If you request statistics without loading OnlineStats, you will get a helpful error: "Statistics collection requires OnlineStats.jl. Please load OnlineStats before enabling statistics."
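Running statistics such as the mean and variance are updated one sample at a time and are not recomputed from a stored chain. The sketch below illustrates that idea with Welford's running mean/variance update in plain Julia. It is not OnlineStats' code. `RunningMoments`, `observe!`, and `sample_variance` are hypothetical names, and the sketch uses the unbiased (n - 1) variance denominator.

```julia
# Hypothetical running mean/variance accumulator (Welford's algorithm).
mutable struct RunningMoments
    n::Int
    mean::Float64
    m2::Float64   # sum of squared deviations from the current mean
end
RunningMoments() = RunningMoments(0, 0.0, 0.0)

# Fold one new sample into the running moments.
function observe!(s::RunningMoments, x::Real)
    s.n += 1
    delta = x - s.mean
    s.mean += delta / s.n
    s.m2 += delta * (x - s.mean)
    return s
end

# Unbiased sample variance; NaN until at least two samples have been seen.
sample_variance(s::RunningMoments) = s.n < 2 ? NaN : s.m2 / (s.n - 1)

s = RunningMoments()
for x in (1.0, 2.0, 3.0, 4.0)
    observe!(s, x)
end
```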

Stats Processing Options

Control how samples are processed before computing statistics with stats_options:

cb = mcmc_callback(
    logger=logger,
    stats=true,
    stats_options=(
        skip=100,    # Skip first 100 samples (burn-in)
        thin=5,      # Use every 5th sample
        window=1000, # Rolling window of 1000 samples
    ),
)

Options merge with defaults, so you only need to specify what you want to change:

# Only thin changes; skip and window keep their defaults (0 and typemax(Int))
cb = mcmc_callback(logger=logger, stats=true, stats_options=(thin=10,))
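The sketch below applies the three options to a stored vector of samples to make them concrete. It makes two assumptions: the options are applied in the order skip, thin, window, and thinning keeps the first remaining sample plus every thin-th one after it. `select_samples` is a hypothetical helper. The callback itself works online through the wrappers described under Internals and does not operate on a stored vector.

```julia
# Hypothetical batch version of stats_options, assuming the order skip -> thin -> window.
# thin = 0 (or 1) keeps every sample; window = typemax(Int) keeps everything.
function select_samples(xs::AbstractVector; skip::Int=0, thin::Int=0,
                        window::Int=typemax(Int))
    kept = xs[(skip + 1):end]                 # drop the burn-in prefix
    if thin > 1
        kept = kept[1:thin:end]               # assumption: keep the 1st, (1+thin)th, ...
    end
    start = max(1, length(kept) - window + 1)
    return kept[start:end]                    # most recent `window` samples
end

xs = collect(1:20)
select_samples(xs; skip=5, thin=5, window=2)   # skip -> 6:20, thin -> [6, 11, 16], window -> last 2
```

Under these assumptions, `stats_options=(thin=10,)` on 20 samples would keep samples 1 and 11.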

Name Filtering

Use name_filter to control which parameters and statistics are logged:

cb = mcmc_callback(
    logger=logger,
    name_filter=(
        include=["mu", "sigma"],  # Only log these parameters
        exclude=["_internal"],     # Exclude matching names
        stats=true,                # Include step-level statistics
        extras=true,               # Include extra diagnostics
    ),
)
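The sketch below shows one reading of these options. `include` and `exclude` act on parameter names by exact match, while `stats` and `extras` switch whole groups on or off. That split is an assumption based on the comments above and does not describe the callback's internals. `filtered_names` and its `log_stats`/`log_extras` keywords are hypothetical.

```julia
# Hypothetical sketch: include/exclude filter parameter names (exact match);
# log_stats / log_extras toggle the whole stats / extras groups.
function filtered_names(params::NamedTuple, stats::NamedTuple, extras::NamedTuple;
                        include=String[], exclude=String[],
                        log_stats::Bool=false, log_extras::Bool=false)
    # A parameter is kept if include is empty or lists it, and exclude does not list it.
    keep(name) = (isempty(include) || name in include) && !(name in exclude)
    names = String[string(k) for k in keys(params) if keep(string(k))]
    log_stats && append!(names, string.(keys(stats)))
    log_extras && append!(names, string.(keys(extras)))
    return names
end

params = (mu = 0.1, sigma = 1.2, _internal = 3.0)
step_stats = (lp = -10.5,)
diag_extras = (n_steps = 7,)
filtered_names(params, step_stats, diag_extras; include=["mu", "sigma"], log_stats=true)
```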

Complete Example

using AbstractMCMC
using TensorBoardLogger
using OnlineStats

logger = TBLogger("runs/full_example")

cb = mcmc_callback(
    logger=logger,
    stats=true,
    stats_options=(skip=50, thin=2),
    name_filter=(
        exclude=["_internal"],
        extras=true,
        hyperparams=true,
    ),
)

chain = sample(model, sampler, 10000; callback=cb)

Then view in TensorBoard:

tensorboard --logdir=runs/full_example

Navigate to localhost:6006 in your browser to see the dashboard. You'll see real-time plots of your parameter distributions, histograms, and other statistics as sampling progresses.

TensorBoard Time Series Tab

The Time Series tab provides detailed traces of parameter values throughout the sampling process.

TensorBoard Scalars Tab

The Scalars tab shows time series of parameter values and statistics over the sampling iterations.

TensorBoard Distributions Tab

The Distributions tab displays the marginal distributions of each parameter.

TensorBoard Histograms Tab

The Histograms tab shows the evolution of parameter distributions over time.

API Reference

AbstractMCMC.mcmc_callback (Function)
mcmc_callback(;
    logger,
    stats = nothing,
    stats_options = nothing,
    name_filter = nothing,
)

Create a TensorBoard logging callback. Requires TensorBoardLogger.jl to be loaded.

Arguments

  • logger: An AbstractLogger instance (e.g., TBLogger from TensorBoardLogger.jl)
  • stats: Statistics to collect. Can be:
    • nothing: No statistics (default)
    • true or :default: Use default statistics (Mean, Variance, KHist) - requires OnlineStats
    • An OnlineStat or tuple of OnlineStats - requires OnlineStats
  • stats_options: NamedTuple with thin, skip, window
  • name_filter: NamedTuple with include, exclude, stats, hyperparams

Examples

using TensorBoardLogger
lg = TBLogger("runs/exp")
cb = mcmc_callback(logger=lg)

# With default stats (requires OnlineStats)
using TensorBoardLogger, OnlineStats
lg = TBLogger("runs/exp")
cb = mcmc_callback(logger=lg, stats=true)

Note

This method is defined in the TensorBoardLogger extension. You must load TensorBoardLogger before using it: using TensorBoardLogger

AbstractMCMC.ParamsWithStats (Type)
ParamsWithStats{P,S,E}

A container for MCMC parameters, statistics, and extras.

All fields are stored as NamedTuples to ensure a tight, well-defined interface. Use Base.pairs(pws) to iterate over (name, value) pairs.

Fields

  • params::P: Parameter values as a NamedTuple
  • stats::S: Statistics as a NamedTuple (e.g., (lp=...,))
  • extras::E: Extra diagnostics as a NamedTuple

Example

pws = ParamsWithStats(model, sampler, transition, state; params=true, stats=true)
for (name, value) in Base.pairs(pws)
    println("$name: $value")
end

# Re-select to exclude stats:
pws2 = ParamsWithStats(pws; params=true, stats=false)

Default Values

stats_options defaults

Option   Default        Description
skip     0              Skip the first n samples (burn-in)
thin     0              Use every nth sample (0 = use all samples)
window   typemax(Int)   Window size for rolling statistics

name_filter defaults

Option    Default    Description
include   String[]   Only log these names (empty = log all)
exclude   String[]   Don't log these names
stats     false      Include step-level statistics
extras    false      Include extra diagnostics

Implementing Custom Callbacks

Any callable with the following signature can be used as a callback:

function my_callback(rng, model, sampler, transition, state, iteration; kwargs...)
    # Your callback logic here
end
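Any callable works, so a callback can also be a struct that carries its own state. The sketch below assumes only the positional signature shown above. `EveryN` is a hypothetical name. It records the iteration number and transition every `n` iterations and ignores any keyword arguments.

```julia
# Hypothetical callable-struct callback: records (iteration, transition) every `n` steps.
struct EveryN
    n::Int
    record::Vector{Tuple{Int,Float64}}
end
EveryN(n::Int) = EveryN(n, Tuple{Int,Float64}[])

function (cb::EveryN)(rng, model, sampler, transition, state, iteration; kwargs...)
    if iteration % cb.n == 0
        push!(cb.record, (iteration, Float64(transition)))
    end
    return nothing
end

# Drive it by hand with dummy arguments to see which iterations are kept.
cb = EveryN(3)
for i in 1:10
    cb(nothing, nothing, nothing, i / 10, nothing, i)
end
```

You could then pass such an object to `sample` via `callback=cb`, just like a plain function.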

ParamsWithStats

ParamsWithStats is a container for extracting and iterating over MCMC parameters, statistics, and extras.

Basic Usage

# Extract params and stats from state
pws = ParamsWithStats(model, sampler, transition, state; params=true, stats=true)

# Iterate using Base.pairs
for (name, value) in Base.pairs(pws)
    @info name value
end

# Re-select to get only params
pws_params = ParamsWithStats(pws; params=true, stats=false, extras=false)
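The sketch below is a stripped-down model of the same idea. It has three NamedTuple groups, a `pairs` method that walks them in order, and a re-selection helper that drops whole groups. `MiniPWS` and `reselect` are hypothetical names, and the params, stats, extras iteration order is an assumption rather than a statement about `ParamsWithStats` itself.

```julia
# Hypothetical miniature of the ParamsWithStats idea.
struct MiniPWS{P<:NamedTuple,S<:NamedTuple,E<:NamedTuple}
    params::P
    stats::S
    extras::E
end

# Iterate (name => value) pairs across all three groups: params, then stats, then extras.
Base.pairs(p::MiniPWS) =
    Iterators.flatten((pairs(p.params), pairs(p.stats), pairs(p.extras)))

# Re-selection: keep or empty whole groups, mirroring the params/stats/extras flags.
function reselect(p::MiniPWS; params=true, stats=true, extras=true)
    return MiniPWS(params ? p.params : NamedTuple(),
                   stats ? p.stats : NamedTuple(),
                   extras ? p.extras : NamedTuple())
end

pws = MiniPWS((mu = 0.3, sigma = 1.1), (lp = -4.2,), (n_steps = 5,))
for (name, value) in pairs(pws)
    println("$name: $value")
end
```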

Overriding for Your Package

To provide meaningful variable names, override the extraction hooks:

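using AbstractMCMC

# Example state type (fields chosen to match the accessors used below)
struct MyState
    mu::Float64
    sigma::Float64
    logp::Float64
    accept_rate::Float64
end

# Define getparams in only ONE of the two ways below; a second method with
# the same signature replaces the first.

# Option 1: Return Vector{<:Real} - default names (θ[1], θ[2], ...) will be used
function AbstractMCMC.getparams(state::MyState)
    return [state.mu, state.sigma]
end

# Option 2: Return named pairs - will be converted to NamedTuple
function AbstractMCMC.getparams(state::MyState)
    return ["μ" => state.mu, "σ" => state.sigma]
end

# Override getstats to return step-level statistics as NamedTuple
function AbstractMCMC.getstats(state::MyState)
    return (lp=state.logp, acceptance_rate=state.accept_rate)
end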

The ParamsWithStats constructors normalize all inputs to NamedTuple:

  • Vector{<:Real} gets default θ[i] names
  • Vector{Pair} is converted to NamedTuple with the provided names
  • NamedTuple is used directly
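The three normalization rules can be sketched with one method per input shape. `to_namedtuple` is a hypothetical helper and not the constructor's actual code. The `θ[i]` symbol spelling follows the description above, but the real naming may differ in detail.

```julia
# Hypothetical normalization helper mirroring the three rules above.
to_namedtuple(x::NamedTuple) = x                      # used directly

function to_namedtuple(x::AbstractVector{<:Real})     # default θ[i] names
    names = Tuple(Symbol("θ[$i]") for i in eachindex(x))
    return NamedTuple{names}(Tuple(x))
end

function to_namedtuple(x::AbstractVector{<:Pair})    # names taken from the pairs
    names = Tuple(Symbol(first(p)) for p in x)
    return NamedTuple{names}(Tuple(last(p) for p in x))
end

to_namedtuple([0.5, 2.0])
to_namedtuple(["μ" => 0.5, "σ" => 2.0])
```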

stats vs extras

Use stats for values that change once per MCMC iteration (e.g., log probability, acceptance rate). Use extras for values that are constant across iterations (e.g., preconditioning matrix, number of particles) or that change multiple times within a single iteration (e.g., leapfrog phase points).

Usage in TensorBoard Callback

The TensorBoard callback uses ParamsWithStats with Base.pairs:

pws = ParamsWithStats(model, sampler, t, state; params=true, stats=true)
for (k, val) in Base.pairs(pws)
    @info "$k" val
end
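`TBLogger` is an `AbstractLogger`, so the `@info "$k" val` pattern above is ordinary Julia logging routed to whatever logger is active. The sketch below swaps in a hypothetical `CaptureLogger` (built on the `Logging` standard library) that records each event instead of writing TensorBoard files. This shows what the loop emits without TensorBoardLogger installed.

```julia
using Logging

# Hypothetical logger that records each event's message and keyword values,
# standing in for TBLogger (which is also an AbstractLogger).
struct CaptureLogger <: AbstractLogger
    events::Vector{Tuple{String,Dict{Symbol,Any}}}
end
CaptureLogger() = CaptureLogger(Tuple{String,Dict{Symbol,Any}}[])

# Minimal AbstractLogger interface: accept Info and above, don't swallow errors.
Logging.min_enabled_level(::CaptureLogger) = Logging.Info
Logging.shouldlog(::CaptureLogger, level, _module, group, id) = true
Logging.catch_exceptions(::CaptureLogger) = false

function Logging.handle_message(lg::CaptureLogger, level, message, _module, group, id,
                                file, line; kwargs...)
    # `@info "$k" val` arrives here as message "$k" plus the keyword `val = val`.
    push!(lg.events, (string(message), Dict{Symbol,Any}(kwargs)))
    return nothing
end

lg = CaptureLogger()
nt = (mu = 0.1, sigma = 2.0)
with_logger(lg) do
    for (k, val) in pairs(nt)
        @info "$k" val
    end
end
```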

Internals

Note

These types and methods are used internally. They are not part of the public API and may change or break at any time without notice.

Types

AbstractMCMC.NameFilter (Type)
NameFilter(; include=Set{String}(), exclude=Set{String}())

A filter for variable names.

  • If include is non-empty, only names in include will pass the filter.
  • Names in exclude will be excluded.
  • Throws an error if include and exclude have overlapping elements.
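The rules above can be sketched as a small standalone type. `SketchNameFilter` and `passes` are hypothetical names. Because `NameFilter` is internal, its actual fields, error type, and method names may differ.

```julia
# Hypothetical sketch of the NameFilter rules (not the internal implementation).
struct SketchNameFilter
    include::Set{String}
    exclude::Set{String}
    function SketchNameFilter(; include=Set{String}(), exclude=Set{String}())
        inc, exc = Set{String}(include), Set{String}(exclude)
        overlap = intersect(inc, exc)
        # Reject filters that both include and exclude the same name.
        isempty(overlap) ||
            throw(ArgumentError("names in both include and exclude: $(sort!(collect(overlap)))"))
        return new(inc, exc)
    end
end

# A name passes if include is empty or contains it, and exclude does not contain it.
passes(f::SketchNameFilter, name::AbstractString) =
    (isempty(f.include) || name in f.include) && !(name in f.exclude)

f = SketchNameFilter(include=["mu", "sigma"])
passes(f, "mu"), passes(f, "lp")
```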

OnlineStats Wrappers

When using statistics, AbstractMCMC provides wrappers that modify how samples are processed:

Wrapper               Description
Skip(n, stat)         Skip the first n observations before fitting stat
Thin(n, stat)         Only fit every n-th observation to stat
WindowStat(n, stat)   Use a rolling window of the last n observations

These are applied automatically via stats_options, but can also be used directly if needed.
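Wrappers like these compose: each one decides whether to pass an observation on to the statistic it wraps. The sketch below illustrates that composition online with a running sum. `SketchStat`, `RunSum`, `SkipFirst`, `ThinEvery`, and `observe!` are hypothetical names and are not the package's `Skip`/`Thin` types. The thinning phase (keep the 1st, (1+n)th, ... observation) is an assumption.

```julia
# Hypothetical online wrappers illustrating Skip / Thin composition.
abstract type SketchStat end

mutable struct RunSum <: SketchStat
    n::Int
    total::Float64
end
RunSum() = RunSum(0, 0.0)

function observe!(s::RunSum, x::Real)
    s.n += 1
    s.total += x
    return s
end

mutable struct SkipFirst{S<:SketchStat} <: SketchStat
    remaining::Int   # observations still to drop
    inner::S
end

function observe!(s::SkipFirst, x::Real)
    if s.remaining > 0
        s.remaining -= 1          # still inside the skipped prefix
    else
        observe!(s.inner, x)
    end
    return s
end

mutable struct ThinEvery{S<:SketchStat} <: SketchStat
    n::Int
    seen::Int        # observations received so far
    inner::S
end
ThinEvery(n::Int, inner::SketchStat) = ThinEvery(n, 0, inner)

function observe!(s::ThinEvery, x::Real)
    s.seen += 1
    if (s.seen - 1) % s.n == 0    # assumption: keep the 1st, (1+n)th, ...
        observe!(s.inner, x)
    end
    return s
end

# Skip 2 observations, then keep every 3rd of the rest (3, 6, 9) and sum them.
wrapped = SkipFirst(2, ThinEvery(3, RunSum()))
for x in 1:10
    observe!(wrapped, x)
end
```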