Callbacks

AbstractMCMC provides a unified callback API for monitoring and logging MCMC sampling.

Basic Usage

The mcmc_callback function is the main entry point for creating callbacks:

using AbstractMCMC

struct MyModel <: AbstractMCMC.AbstractModel end

struct MySampler <: AbstractMCMC.AbstractSampler end

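# Giving state a default of nothing lets this single method cover both the
# initial step (no state yet) and every subsequent step.
function AbstractMCMC.step(rng, ::MyModel, ::MySampler, state=nothing; kwargs...)
    # all samples are zero
    return 0.0, state
end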

model, sampler = MyModel(), MySampler()

# Simple callback with a function
cb = mcmc_callback() do rng, model, sampler, transition, state, iteration
    println("Iteration: $iteration")
end

chain = sample(model, sampler, 10; callback=cb)
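To make the calling convention concrete, here is a deliberately simplified sampling loop in plain Julia. The names toy_sample, toy_step, ToyModel, ToySampler and record_iteration are hypothetical, and this is not AbstractMCMC's actual sample implementation; it only shows that the callback is invoked once per iteration, after the step, with the arguments in the same positional order as the do-block above.

```julia
using Random

# Hypothetical stand-ins for a model, a sampler and a step function.
struct ToyModel end
struct ToySampler end
toy_step(rng, ::ToyModel, ::ToySampler, state) = (0.0, state)  # every sample is zero

# Simplified loop (not AbstractMCMC's): take a step, then hand everything to the callback.
function toy_sample(rng, model, sampler, n; callback=nothing)
    transitions = Float64[]
    state = nothing
    for iteration in 1:n
        transition, state = toy_step(rng, model, sampler, state)
        push!(transitions, transition)
        callback === nothing || callback(rng, model, sampler, transition, state, iteration)
    end
    return transitions
end

# A callback that records which iterations it was called for.
seen = Int[]
record_iteration(rng, model, sampler, transition, state, iteration) = push!(seen, iteration)

chain = toy_sample(Random.default_rng(), ToyModel(), ToySampler(), 10; callback=record_iteration)
```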

Combining Multiple Callbacks

Pass multiple callbacks to mcmc_callback to combine them:

cb1 = (args...; kwargs...) -> println("Callback 1")
cb2 = (args...; kwargs...) -> println("Callback 2")

cb = mcmc_callback(cb1, cb2)

You can also add callbacks dynamically using BangBang.push!!:

using BangBang

cb = mcmc_callback(cb1)
cb = push!!(cb, cb2)
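Conceptually, a combined callback forwards the same arguments to each of its members in order. The sketch below models that with a hypothetical ComposedCallback struct wrapping a tuple, a hypothetical compose_callbacks constructor, and a non-mutating add_callback that returns a new object, in the spirit of push!! (which may return a new value rather than mutate). It is an illustration under those assumptions, not the type mcmc_callback actually returns.

```julia
# Hypothetical container type; not the object mcmc_callback actually returns.
struct ComposedCallback{T<:Tuple}
    callbacks::T
end

compose_callbacks(fs...) = ComposedCallback(fs)

# Calling the container calls every member, in order, with the same arguments.
function (c::ComposedCallback)(args...; kwargs...)
    for f in c.callbacks
        f(args...; kwargs...)
    end
    return nothing
end

# Non-mutating "push": build a new container with f appended.
add_callback(c::ComposedCallback, f) = ComposedCallback((c.callbacks..., f))

messages = String[]
cb1 = (args...; kwargs...) -> push!(messages, "Callback 1")
cb2 = (args...; kwargs...) -> push!(messages, "Callback 2")

cb = add_callback(compose_callbacks(cb1), cb2)
cb(nothing, nothing, nothing, 0.0, nothing, 1)  # simulate a single iteration
```

Because the members live in a tuple, appending produces a container of a new concrete type, which is why a push!!-style "return the result" API fits better than in-place mutation here.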

TensorBoard Logging

TensorBoard logging requires TensorBoardLogger. Statistics collection also requires OnlineStats.

Basic Logging (No Statistics)

using AbstractMCMC
using TensorBoardLogger

logger = TBLogger("runs/experiment1")
cb = mcmc_callback(logger=logger)

chain = sample(model, sampler, 1000; callback=cb)

Logging with Statistics

To collect running statistics (mean, variance, histograms), load OnlineStats and use the stats argument:

using AbstractMCMC
using TensorBoardLogger
using OnlineStats

logger = TBLogger("runs/experiment1")

# Use default statistics (Mean, Variance, KHist)
cb = mcmc_callback(logger=logger, stats=true)

# Or specify custom statistics
cb = mcmc_callback(
    logger=logger,
    stats=(Mean(), Variance(), KHist(50)),
)

Note

If you request statistics without loading OnlineStats, you will get a helpful error: "Statistics collection requires OnlineStats.jl. Please load OnlineStats before enabling statistics."
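Mean and Variance are online estimators: they update a small summary with each new sample instead of storing the whole chain. As an illustration of that idea only, here is Welford's single-pass algorithm for the running mean and unbiased variance in plain Julia. RunningMeanVar, update! and variance are hypothetical names, and this is not OnlineStats' implementation.

```julia
# Hypothetical running summary using Welford's algorithm.
mutable struct RunningMeanVar
    n::Int
    mean::Float64
    m2::Float64  # sum of squared deviations from the current mean
end
RunningMeanVar() = RunningMeanVar(0, 0.0, 0.0)

# Fold one new observation into the summary.
function update!(s::RunningMeanVar, x::Real)
    s.n += 1
    delta = x - s.mean
    s.mean += delta / s.n
    s.m2 += delta * (x - s.mean)
    return s
end

# Unbiased sample variance; undefined (NaN) for fewer than two samples.
variance(s::RunningMeanVar) = s.n > 1 ? s.m2 / (s.n - 1) : NaN

s = RunningMeanVar()
for x in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
    update!(s, x)
end
```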

Stats Processing Options

Control how samples are processed before computing statistics with stats_options:

cb = mcmc_callback(
    logger=logger,
    stats=true,
    stats_options=(
        skip=100,    # Skip first 100 samples (burn-in)
        thin=5,      # Use every 5th sample
        window=1000, # Rolling window of 1000 samples
    ),
)
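To see which draws end up feeding the statistics, the batch sketch below applies the three options in the order listed (skip, then thin, then window) to a vector of samples. The function selected_samples is hypothetical. Treating thin=0 as "keep all" follows the defaults table in this page; the thinning phase (keeping the first post-burn-in draw and every thin-th draw after it) and the ordering are assumptions, and the real wrappers work online, one observation at a time.

```julia
# Hypothetical batch version of skip -> thin -> window selection.
function selected_samples(xs::AbstractVector; skip::Int=0, thin::Int=0, window::Int=typemax(Int))
    every = max(thin, 1)                  # thin = 0 (the default) keeps every sample
    kept = xs[(skip + 1):every:end]       # drop burn-in, then thin
    first_kept = max(length(kept) - window + 1, 1)
    return kept[first_kept:end]           # keep only the most recent `window` draws
end

selected_samples(1:20; skip=5, thin=5)   # draws 6, 11, 16
selected_samples(1:20; window=3)         # draws 18, 19, 20
```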

Options merge with defaults, so you only need to specify what you want to change:

# Only thin changes; skip and window keep their defaults (0 and typemax(Int))
cb = mcmc_callback(logger=logger, stats=true, stats_options=(thin=10,))
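The merge behaviour described here matches Julia's Base.merge on NamedTuples, where fields from the second argument override those of the first. The sketch assumes the defaults listed in the Default Values table; DEFAULT_STATS_OPTIONS is a hypothetical name, not an AbstractMCMC constant, and whether the library uses merge internally is an assumption.

```julia
# Hypothetical copy of the documented stats_options defaults.
const DEFAULT_STATS_OPTIONS = (skip=0, thin=0, window=typemax(Int))

# Later fields win, so only thin changes; field order follows the defaults.
opts = merge(DEFAULT_STATS_OPTIONS, (thin=10,))
```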

Name Filtering

Use name_filter to control which parameters and statistics are logged:

cb = mcmc_callback(
    logger=logger,
    name_filter=(
        include=["mu", "sigma"],  # Only log these parameters
        exclude=["_internal"],     # Exclude matching names
        extras=true,               # Include extra stats (log density, etc.)
        hyperparams=true,          # Include hyperparameters (logged once)
    ),
)

Complete Example

using AbstractMCMC
using TensorBoardLogger
using OnlineStats

logger = TBLogger("runs/full_example")

cb = mcmc_callback(
    logger=logger,
    stats=true,
    stats_options=(skip=50, thin=2),
    name_filter=(
        exclude=["_internal"],
        extras=true,
        hyperparams=true,
    ),
)

chain = sample(model, sampler, 10000; callback=cb)

Then view in TensorBoard:

tensorboard --logdir=runs/full_example

Open localhost:6006 in your browser to see the dashboard (6006 is TensorBoard's default port). Plots of parameter traces, distributions, histograms, and other statistics update as sampling progresses.

TensorBoard Time Series Tab

The Time Series tab provides detailed traces of parameter values throughout the sampling process.

TensorBoard Scalars Tab

The Scalars tab shows time series of parameter values and statistics over the sampling iterations.

TensorBoard Distributions Tab

The Distributions tab displays the marginal distributions of each parameter.

TensorBoard Histograms Tab

The Histograms tab shows the evolution of parameter distributions over time.

API Reference

Main Functions

AbstractMCMC.mcmc_callback (Function)
mcmc_callback(;
    logger,
    stats = nothing,
    stats_options = nothing,
    name_filter = nothing,
)

Create a TensorBoard logging callback. Requires TensorBoardLogger.jl to be loaded.

Arguments

  • logger: An AbstractLogger instance (e.g., TBLogger from TensorBoardLogger.jl)
  • stats: Statistics to collect. Can be:
    • nothing: No statistics (default)
    • true or :default: Use default statistics (Mean, Variance, KHist) - requires OnlineStats
    • An OnlineStat or tuple of OnlineStats - requires OnlineStats
  • stats_options: NamedTuple with any of the fields thin, skip, window
  • name_filter: NamedTuple with any of the fields include, exclude, extras, hyperparams

Examples

using TensorBoardLogger
lg = TBLogger("runs/exp")
cb = mcmc_callback(logger=lg)

# With default stats (requires OnlineStats)
using TensorBoardLogger, OnlineStats
lg = TBLogger("runs/exp")
cb = mcmc_callback(logger=lg, stats=true)

Note

This method is defined in the TensorBoardLogger extension. You must load TensorBoardLogger before using it: using TensorBoardLogger

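Because the logger argument is an ordinary AbstractLogger, a hand-written callback can also route values through any logger using the standard Logging library. The sketch below uses Test.TestLogger as a stand-in for TBLogger so it runs without TensorBoardLogger. make_logging_callback and the "mcmc" message are hypothetical, and this is not how the TensorBoardLogger extension structures its own log records.

```julia
using Logging
using Test: TestLogger

# Hypothetical factory: returns a callback that logs each transition to `logger`.
function make_logging_callback(logger::AbstractLogger)
    return function (rng, model, sampler, transition, state, iteration; kwargs...)
        with_logger(logger) do
            @info "mcmc" transition iteration
        end
    end
end

test_logger = TestLogger()  # collects log records in memory
cb = make_logging_callback(test_logger)
for i in 1:3
    cb(nothing, nothing, nothing, Float64(i), nothing, i)  # simulate three iterations
end
```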

Default Values

stats_options defaults

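| Option | Default      | Description                         |
|--------|--------------|-------------------------------------|
| skip   | 0            | Skip the first n samples (burn-in)  |
| thin   | 0            | Use every nth sample (0 = all)      |
| window | typemax(Int) | Window size for rolling statistics  |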

name_filter defaults

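| Option      | Default  | Description                                  |
|-------------|----------|----------------------------------------------|
| include     | String[] | Only log these names (empty = all)           |
| exclude     | String[] | Never log these names                        |
| extras      | false    | Include extra statistics (log density, etc.) |
| hyperparams | false    | Include hyperparameters                      |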

Implementing Custom Callbacks

Any callable with the following signature can be used as a callback:

function my_callback(rng, model, sampler, transition, state, iteration; kwargs...)
    # Your callback logic here
end
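Since any callable works, a callback can also be a struct that carries its own state. The hypothetical TraceRecorder below stores every transition it sees; its call method accepts and ignores kwargs..., so it tolerates extra keyword arguments from the caller. Here it is invoked by hand to simulate a few iterations.

```julia
# Hypothetical callable struct that records transitions.
struct TraceRecorder
    values::Vector{Float64}
end
TraceRecorder() = TraceRecorder(Float64[])

# Matches the callback signature; only the transition is used.
function (r::TraceRecorder)(rng, model, sampler, transition, state, iteration; kwargs...)
    push!(r.values, transition)
    return nothing
end

recorder = TraceRecorder()
for (i, x) in enumerate([0.5, 1.5, -0.25])
    recorder(nothing, nothing, nothing, x, nothing, i)  # simulate one iteration
end
```

With AbstractMCMC loaded, the same object could be passed directly, for example sample(model, sampler, 10; callback=recorder), and inspected afterwards through recorder.values.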

Internals

Note

These types and methods are used internally. They are not part of the public API and may change or break at any time without notice.

Types

AbstractMCMC.NameFilter (Type)
NameFilter(; include=Set{String}(), exclude=Set{String}())

A filter for variable names.

  • If include is non-empty, only names in include will pass the filter.
  • Names in exclude will be excluded.
  • Throws an error if include and exclude have overlapping elements.
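The three rules above can be sketched in plain Julia. SketchNameFilter and passes are hypothetical names used only here, not the internal type; the sketch uses exact string membership, as the docstring describes, and rejects overlapping include and exclude sets at construction time.

```julia
# Hypothetical re-creation of the documented NameFilter rules.
struct SketchNameFilter
    include::Set{String}
    exclude::Set{String}
    function SketchNameFilter(; include=Set{String}(), exclude=Set{String}())
        inc, exc = Set{String}(include), Set{String}(exclude)
        overlap = intersect(inc, exc)
        isempty(overlap) || throw(ArgumentError("include and exclude overlap: $overlap"))
        return new(inc, exc)
    end
end

# A name passes if it is not excluded and, when include is non-empty, is listed there.
passes(f::SketchNameFilter, name::AbstractString) =
    !(name in f.exclude) && (isempty(f.include) || name in f.include)

f = SketchNameFilter(include=["mu", "sigma"], exclude=["_internal"])
passes(f, "mu")   # true
passes(f, "tau")  # false
```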

OnlineStats Wrappers

When using statistics, AbstractMCMC provides wrappers that modify how samples are processed:

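| Wrapper             | Description                                        |
|---------------------|----------------------------------------------------|
| Skip(n, stat)       | Skip the first n observations before fitting stat  |
| Thin(n, stat)       | Fit only every n-th observation to stat            |
| WindowStat(n, stat) | Use a rolling window of the last n observations    |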

These are applied automatically via stats_options, but can also be used directly if needed.
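The wrappers compose because each one only decides whether to forward an observation to the statistic it wraps. The plain-Julia sketch below mimics Skip and Thin with hypothetical SkipSketch, ThinSketch and CountSum types and a hypothetical fit_one! function. It is not AbstractMCMC's implementation, and the thinning phase (forwarding the 1st, (n+1)-th, (2n+1)-th, ... observation) is an assumption.

```julia
# Hypothetical innermost statistic: count and sum of the values it receives.
mutable struct CountSum
    n::Int
    total::Float64
end
CountSum() = CountSum(0, 0.0)
fit_one!(s::CountSum, x) = (s.n += 1; s.total += x; s)

# Drop the first `skip` observations, then forward the rest.
mutable struct SkipSketch{S}
    skip::Int
    seen::Int
    stat::S
end
SkipSketch(skip, stat) = SkipSketch(skip, 0, stat)
function fit_one!(w::SkipSketch, x)
    w.seen += 1
    w.seen > w.skip && fit_one!(w.stat, x)
    return w
end

# Forward the 1st, (thin+1)-th, (2*thin+1)-th, ... observation it receives.
mutable struct ThinSketch{S}
    thin::Int
    seen::Int
    stat::S
end
ThinSketch(thin, stat) = ThinSketch(thin, 0, stat)
function fit_one!(w::ThinSketch, x)
    w.seen += 1
    (w.seen - 1) % max(w.thin, 1) == 0 && fit_one!(w.stat, x)
    return w
end

# Skip burn-in first, then thin: only 6, 11 and 16 reach the inner statistic.
inner = CountSum()
wrapped = SkipSketch(5, ThinSketch(5, inner))
for x in 1:20
    fit_one!(wrapped, x)
end
```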

Internal Functions

The internal _names_and_values function extracts (name, value) pairs from a transition and sampler state; its keyword arguments select parameters, hyperparameters, and extra statistics:

for (name, value) in AbstractMCMC._names_and_values(
    model, sampler, transition, state;
    params=true,
    hyperparams=false,
    extra=false,
)
    println("$name = $value")
end

Samplers can override AbstractMCMC.getparams(state) and AbstractMCMC.getstats(state) to provide custom information extraction.