# Parallel and Distributed Sampling
AbstractMCMC and AdvancedHMC support both parallel and distributed sampling.
## Parallel Sampling (Multi-threaded)

To perform multi-threaded sampling of multiple chains, start Julia with the `-t <n_threads>` argument.
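For example (the script name here is a placeholder for your own file):

```shell
# Start Julia with 4 threads
julia -t 4 my_script.jl

# Or let Julia choose a thread count based on available CPU cores
julia -t auto my_script.jl
```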
```julia
n_chains = 4
samples_and_stats = AbstractMCMC.sample(
    model,
    AdvancedHMC.NUTS(0.65),
    AbstractMCMC.MCMCThreads(),
    n_samples,
    n_chains;
    chain_type = Chains,
    n_adapts = n_adapts,
    init_params = [initial_θ for _ = 1:n_chains],
    discard_initial = n_adapts,
)
```

The key differences from single-chain sampling:

- `AbstractMCMC.MCMCThreads()`: enables multi-threaded sampling
- `n_chains`: number of chains to sample in parallel
- `init_params`: vector of initial parameters (one per chain)
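The snippet above assumes `model`, `n_samples`, `n_adapts`, and `initial_θ` are already defined. A minimal sketch of such a setup follows; the model, data, and dimensions are hypothetical, and depending on your JuliaBUGS version you may additionally need to wrap the model with an AD gradient backend (e.g. via LogDensityProblemsAD) before HMC sampling:

```julia
using JuliaBUGS, AbstractMCMC, AdvancedHMC, MCMCChains, LogDensityProblems

# Hypothetical model: replace with your own BUGS program and data.
# Note that dnorm in BUGS is parameterized by precision, not variance.
model_def = @bugs begin
    mu ~ dnorm(0, 0.001)
    for i in 1:N
        y[i] ~ dnorm(mu, 1)
    end
end
model = compile(model_def, (; N = 5, y = [1.0, 2.0, 3.0, 4.0, 5.0]))

n_samples = 2000
n_adapts = 1000

# One initial parameter vector; it is copied once per chain in the call above
D = LogDensityProblems.dimension(model)
initial_θ = rand(D)
```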
## Distributed Sampling (Multi-process)

To perform distributed sampling, start Julia with the `-p <n_processes>` argument.

Ensure all functions and modules are available on all processes using `@everywhere`:
```julia
@everywhere begin
    using JuliaBUGS, LogDensityProblems, AbstractMCMC, AdvancedHMC, MCMCChains, ADTypes, ReverseDiff
    # Define any custom functions here
    # Use `@bugs_primitive` to register functions for use in the model
end
```
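Alternatively, worker processes can be added from within a running session using the `Distributed` standard library, which is equivalent to starting Julia with `-p`:

```julia
using Distributed

# Add 4 worker processes (equivalent to `julia -p 4`)
addprocs(4)

# Code loading via @everywhere must happen *after* the workers exist,
# otherwise the new processes will not have the packages loaded
@everywhere using JuliaBUGS, AbstractMCMC, AdvancedHMC, MCMCChains
```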
```julia
n_chains = nprocs() - 1  # use all worker processes
samples_and_stats = AbstractMCMC.sample(
    model,
    AdvancedHMC.NUTS(0.65),
    AbstractMCMC.MCMCDistributed(),
    n_samples,
    n_chains;
    chain_type = Chains,
    n_adapts = n_adapts,
    init_params = [initial_θ for _ = 1:n_chains],
    discard_initial = n_adapts,
    progress = false,  # progress logging can cause issues in distributed mode
)
```

The key differences:

- `AbstractMCMC.MCMCDistributed()`: enables distributed sampling
- `progress = false`: recommended to avoid TTY issues in distributed settings
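In both the threaded and distributed cases, the result is a single `Chains` object containing all chains. A sketch of inspecting it, assuming MCMCChains is loaded:

```julia
using MCMCChains

# samples_and_stats is an MCMCChains.Chains object
println(size(samples_and_stats))  # (iterations, parameters, chains)

# Per-parameter summary statistics (mean, std, ESS, R-hat)
summarize(samples_and_stats)
```

Checking that R-hat is close to 1 across the parallel chains is a standard convergence diagnostic that single-chain sampling cannot provide.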