Part of the API of DynamicPPL is defined in the more lightweight interface package AbstractPPL.jl and reexported here.
A core component of DynamicPPL is the @model macro. It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with ~ statements. These statements are rewritten by @model as calls of internal functions for sampling the variables and computing their log densities.
— Macro @model(expr[, warn = false])
Macro to specify a probabilistic model.
If warn is true, a warning is displayed if internal variable names are used in the model definition.
Model definition:
@model function model(x, y = 42)
    ...
end
To generate a Model, call model(xvalue) or model(xvalue, yvalue).
A Model can be created by calling the model function, as defined by @model.
— Type struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
A Model struct with model evaluation function of type F, arguments of names argnames with types Targs, default arguments of names defaultnames with types Tdefaults, missing arguments missings, and evaluation context of type Ctx.
Here argnames, defaultnames, and missings are tuples of symbols, e.g. (:a, :b). context is by default DefaultContext().
An argument with a type of Missing will be in missings by default. However, in non-traditional use-cases missings can be defined differently. All variables in missings are treated as random variables rather than observations.
The default arguments are used internally when constructing instances of the same model with different arguments.
julia> Model(f, (x = 1.0, y = 2.0))
Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple())
julia> Model(f, (x = 1.0, y = 2.0), (x = 42,))
Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
Models are callable structs.
— Method (model::Model)([rng, varinfo, sampler, context])
Sample from the model using the sampler with random number generator rng and the context, and store the sample and log joint probability in varinfo.
The method resets the log joint probability of varinfo and increases the evaluation number of sampler.
Basic properties of a model can be accessed with getargnames, getmissings, and nameof.
— Method nameof(model::Model)
Get the name of the model as Symbol.
— Function getargnames(model::Model)
Get a tuple of the argument names of the model.
— Function getmissings(model::Model)
Get a tuple of the names of the missing arguments of the model.
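As a quick illustration, the sketch below queries these accessors on a hypothetical two-argument model, demo, which is not part of the original text. getargnames and getmissings are qualified with DynamicPPL in case they are not exported in your version. Per the description above, an argument passed as missing should show up in the missings tuple.

```julia
using DynamicPPL, Distributions

# Hypothetical toy model: x and y are observed unless passed as `missing`.
@model function demo(x, y)
    m ~ Normal()
    x ~ Normal(m, 1)
    y ~ Normal(m, 1)
end

# Passing `missing` for y makes it a random variable rather than an observation.
model = demo(1.0, missing)

model_name = nameof(model)                     # the model's name as a Symbol
arg_names = DynamicPPL.getargnames(model)      # all argument names
missing_args = DynamicPPL.getmissings(model)   # names of arguments passed as `missing`
```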
With rand, one can draw samples from the prior distribution of a Model.
— Function rand([rng=Random.default_rng()], [T=NamedTuple], model::Model)
Generate a sample of type T from the prior distribution of the model.
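The following is a minimal sketch of prior sampling with rand. It assumes a hypothetical model demo with two latent variables and no data. Seeding a fresh MersenneTwister (from the Random standard library) should make a draw reproducible.

```julia
using DynamicPPL, Distributions, Random

# Hypothetical model with two latent variables and no data.
@model function demo()
    m ~ Normal()
    x ~ Normal(m, 1)
end

model = demo()

# One draw from the prior, returned as a NamedTuple by default.
draw = rand(model)

# Seeding the RNG explicitly makes the draw reproducible.
draw1 = rand(MersenneTwister(42), model)
draw2 = rand(MersenneTwister(42), model)
```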
One can also evaluate the log prior, log likelihood, and log joint probability.
— Function logprior(model::Model, varinfo::AbstractVarInfo)
Return the log prior probability of variables varinfo for the probabilistic model.
See also logjoint and loglikelihood.
logprior(model::Model, chain::AbstractMCMC.AbstractChains)
Return an array of log prior probabilities evaluated at each sample in an MCMC chain.
julia> using MCMCChains, Distributions
julia> @model function demo_model(x)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           for i in eachindex(x)
               x[i] ~ Normal(m, sqrt(s))
           end
       end;
julia> # construct a chain of samples using MCMCChains
       chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> logprior(demo_model([1., 2.]), chain);
logprior(model::Model, θ)
Return the log prior probability of variables θ for the probabilistic model.
See also logjoint and loglikelihood.
julia> using Distributions, OrderedCollections
julia> @model function demo(x)
           m ~ Normal()
           for i in eachindex(x)
               x[i] ~ Normal(m, 1.0)
           end
       end
demo (generic function with 2 methods)
julia> # Using a `NamedTuple`.
       logprior(demo([1.0]), (m = 100.0, ))
julia> # Using an `OrderedDict`.
       logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0))
julia> # Truth.
       logpdf(Normal(), 100.0)
— Function loglikelihood(model::Model, varinfo::AbstractVarInfo)
Return the log likelihood of variables varinfo for the probabilistic model.
loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
Return an array of log likelihoods evaluated at each sample in an MCMC chain.
julia> using MCMCChains, Distributions
julia> @model function demo_model(x)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           for i in eachindex(x)
               x[i] ~ Normal(m, sqrt(s))
           end
       end;
julia> # construct a chain of samples using MCMCChains
       chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> loglikelihood(demo_model([1., 2.]), chain);
loglikelihood(model::Model, θ)
Return the log likelihood of variables θ for the probabilistic model.
See also logjoint and logprior.
julia> using Distributions, OrderedCollections
julia> @model function demo(x)
           m ~ Normal()
           for i in eachindex(x)
               x[i] ~ Normal(m, 1.0)
           end
       end
demo (generic function with 2 methods)
julia> # Using a `NamedTuple`.
       loglikelihood(demo([1.0]), (m = 100.0, ))
julia> # Using an `OrderedDict`.
       loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0))
julia> # Truth.
       logpdf(Normal(100.0, 1.0), 1.0)
— Function logjoint(model::Model, varinfo::AbstractVarInfo)
Return the log joint probability of variables varinfo for the probabilistic model.
See logprior and loglikelihood.
logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
Return an array of log joint probabilities evaluated at each sample in an MCMC chain.
julia> using MCMCChains, Distributions
julia> @model function demo_model(x)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           for i in eachindex(x)
               x[i] ~ Normal(m, sqrt(s))
           end
       end;
julia> # construct a chain of samples using MCMCChains
       chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> logjoint(demo_model([1., 2.]), chain);
logjoint(model::Model, θ)
Return the log joint probability of variables θ for the probabilistic model.
See logprior and loglikelihood.
julia> using Distributions, OrderedCollections
julia> @model function demo(x)
           m ~ Normal()
           for i in eachindex(x)
               x[i] ~ Normal(m, 1.0)
           end
       end
demo (generic function with 2 methods)
julia> # Using a `NamedTuple`.
       logjoint(demo([1.0]), (m = 100.0, ))
julia> # Using an `OrderedDict`.
       logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0))
julia> # Truth.
       logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0)
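The joint density splits into a prior term and a likelihood term. So at any point θ, the three functions should satisfy logjoint = logprior + loglikelihood. The sketch below checks this on the demo model from the examples above, using a hypothetical two-point data set and θ = (m = 0.5,). It also compares each term with a direct logpdf computation.

```julia
using DynamicPPL, Distributions

@model function demo(x)
    m ~ Normal()
    for i in eachindex(x)
        x[i] ~ Normal(m, 1.0)
    end
end

# Hypothetical data set and parameter value.
model = demo([1.0, 2.0])
θ = (m = 0.5,)

lp = logprior(model, θ)       # log p(m)
ll = loglikelihood(model, θ)  # log p(x | m)
lj = logjoint(model, θ)       # log p(m, x)
```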
LogDensityProblems.jl interface
The LogDensityProblems.jl interface is also supported by wrapping a Model in a DynamicPPL.LogDensityFunction.
— Type LogDensityFunction(
A struct which contains a model, along with all the information necessary to:
- calculate its log density at a given point;
- and, if adtype is provided, calculate the gradient of the log density at that point.
At its most basic level, a LogDensityFunction wraps the model together with the type of varinfo to be used, as well as the evaluation context. These must be known in order to calculate the log density (using DynamicPPL.evaluate!!).
If the adtype keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the gradient of the log density. Note that preparing a LogDensityFunction with an AD type AutoBackend() requires the AD backend itself to have been loaded (e.g. with import Backend).
LogDensityFunction implements the LogDensityProblems.jl interface. If adtype is nothing, then only logdensity is implemented. If adtype is a concrete AD backend type, then logdensity_and_gradient is also implemented.
Fields:
- model: model used for evaluation
- varinfo: varinfo used for evaluation
- context: context used for evaluation; if nothing, … will be used when applicable
- adtype: AD type used for evaluation of log density gradient. If nothing, no gradient can be calculated
- prep: (internal use only) gradient preparation object for the model
julia> using Distributions
julia> using DynamicPPL: LogDensityFunction, contextualize
julia> @model function demo(x)
           m ~ Normal()
           x ~ Normal(m, 1)
       end
demo (generic function with 2 methods)
julia> model = demo(1.0);
julia> f = LogDensityFunction(model);
julia> # It implements the interface of LogDensityProblems.jl.
using LogDensityProblems
julia> LogDensityProblems.logdensity(f, [0.0])
julia> LogDensityProblems.dimension(f)
julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
f = LogDensityFunction(model, SimpleVarInfo(model));
julia> LogDensityProblems.logdensity(f, [0.0])
julia> # This also respects the context in `model`.
f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model));
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
julia> # If we also need to calculate the gradient, we can specify an AD backend.
import ForwardDiff, ADTypes
julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
(-2.3378770664093453, [1.0])
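The gradient returned by logdensity_and_gradient can drive a simple optimizer. The sketch below runs plain gradient ascent on the log joint of demo(1.0). The helper gradient_ascent and its lr and iters keywords are hypothetical and not part of DynamicPPL. For this model the posterior of m is Normal(0.5, sqrt(0.5)), so the iterate should approach the mode 0.5.

```julia
using DynamicPPL, Distributions, LogDensityProblems
import ADTypes, ForwardDiff

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end

# Wrap the model so that gradients are available via ForwardDiff.
f = DynamicPPL.LogDensityFunction(demo(1.0); adtype=ADTypes.AutoForwardDiff())

# Hypothetical helper: fixed-step gradient ascent on the log density.
function gradient_ascent(f, θ0; lr=0.1, iters=200)
    θ = copy(θ0)
    for _ in 1:iters
        _, grad = LogDensityProblems.logdensity_and_gradient(f, θ)
        θ .+= lr .* grad
    end
    return θ
end

θ_map = gradient_ascent(f, [0.0])
```

A fixed step size keeps the sketch short. Here the gradient is 1 - 2m, so each step contracts the error by a factor of 0.8.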
Condition and decondition
A Model can be conditioned on a set of observations with AbstractPPL.condition or its alias |.
— Method model | (x = 1.0, ...)
Return a Model which now treats variables on the right-hand side as observations.
See condition for more information and examples.
— Function condition(model::Model; values...)
condition(model::Model, values::NamedTuple)
Return a Model which now treats the variables in values as observations.
See also: decondition, conditioned.
This currently does not work with variables that are provided to the model as arguments: for example, @model function demo(x) ... end means that condition will not affect the variable x. Therefore, if one wants to make use of condition and decondition, one should not specify any random variables as arguments.
This is done for the sake of backwards compatibility.
Simple univariate model
julia> using Distributions
julia> @model function demo()
           m ~ Normal()
           x ~ Normal(m, 1)
           return (; m=m, x=x)
       end
demo (generic function with 2 methods)
julia> model = demo();
julia> m, x = model(); (m ≠ 1.0 && x ≠ 100.0)
julia> # Create a new instance which treats `x` as observed
# with value `100.0`, and similarly for `m=1.0`.
conditioned_model = condition(model, x=100.0, m=1.0);
julia> m, x = conditioned_model(); (m == 1.0 && x == 100.0)
julia> # Let's only condition on `x = 100.0`.
conditioned_model = condition(model, x = 100.0);
julia> m, x = conditioned_model(); (m ≠ 1.0 && x == 100.0)
julia> # We can also use the nicer `|` syntax.
conditioned_model = model | (x = 100.0, );
julia> m, x = conditioned_model(); (m ≠ 1.0 && x == 100.0)
The above uses a NamedTuple to hold the conditioning variables, which allows us to perform some additional optimizations; in many cases, the above has zero runtime overhead.
But we can also use a Dict, which offers more flexibility in the conditioning (see examples further below) but generally has worse performance than the NamedTuple approach:
julia> conditioned_model_dict = condition(model, Dict(@varname(x) => 100.0));
julia> m, x = conditioned_model_dict(); (m ≠ 1.0 && x == 100.0)
julia> # There's also an option using `|` by letting the right-hand side be a tuple
# with elements of type `Pair{<:VarName}`, i.e. `vn => value` with `vn isa VarName`.
conditioned_model_dict = model | (@varname(x) => 100.0, );
julia> m, x = conditioned_model_dict(); (m ≠ 1.0 && x == 100.0)
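The next sketch shows what conditioning changes numerically. It is a hypothetical check on the demo model above, not taken from the original docs. It conditions on x and evaluates at θ = (m = 0.5,). The density of x now enters loglikelihood, while logprior only contains the term for m.

```julia
using DynamicPPL, Distributions

@model function demo()
    m ~ Normal()
    x ~ Normal(m, 1)
    return (; m=m, x=x)
end

# Treat x as observed with value 1.0.
cm = condition(demo(), x = 1.0)
θ = (m = 0.5,)

lp = logprior(cm, θ)       # contains only the term for m
ll = loglikelihood(cm, θ)  # contains only the term for the observed x
```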
Condition only a part of a multivariate variable
Not only can we condition on multivariate random variables, but we can also use the standard mechanism of setting something to missing in the call to condition to only condition on a part of the variable.
julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
           m = Vector{TV}(undef, 2)
           m[1] ~ Normal()
           m[2] ~ Normal()
           return m
       end
demo_mv (generic function with 4 methods)
julia> model = demo_mv();
julia> conditioned_model = condition(model, m = [missing, 1.0]);
julia> # (✓) `m[1]` sampled while `m[2]` is fixed
m = conditioned_model(); (m[1] ≠ 1.0 && m[2] == 1.0)
Intuitively one might also expect to be able to write model | (m[1] = 1.0, ). Unfortunately this is not supported, as it has the potential of increasing compilation times without offering any benefit with respect to runtime:
julia> # (×) `m[2]` is not set to 1.0.
m = condition(model, var"m[2]" = 1.0)(); m[2] == 1.0
But you can do this if you use a Dict as the underlying storage instead:
julia> # Alternatives:
       # - `model | (@varname(m[2]) => 1.0,)`
       # - `condition(model, Dict(@varname(m[2]) => 1.0))`
       # (✓) `m[2]` is set to 1.0.
       m = condition(model, @varname(m[2]) => 1.0)(); (m[1] ≠ 1.0 && m[2] == 1.0)
Nested models
condition of course also supports the use of nested models through the use of to_submodel.
julia> @model demo_inner() = m ~ Normal()
demo_inner (generic function with 2 methods)
julia> @model function demo_outer()
           # By default, `to_submodel` prefixes the variables using the left-hand side of `~`.
           inner ~ to_submodel(demo_inner())
           return inner
       end
demo_outer (generic function with 2 methods)
julia> model = demo_outer();
julia> model() ≠ 1.0
julia> # To condition the variable inside `demo_inner` we need to refer to it as `inner.m`.
conditioned_model = model | (var"inner.m" = 1.0, );
julia> conditioned_model()
julia> # However, it's not possible to condition `inner` directly.
conditioned_model_fail = model | (inner = 1.0, );
julia> conditioned_model_fail()
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
And similarly when using a Dict:
julia> conditioned_model_dict = model | (@varname(var"inner.m") => 1.0);
julia> conditioned_model_dict()
— Function conditioned(model::Model)
Return the conditioned values in model.
julia> using Distributions
julia> using DynamicPPL: conditioned, contextualize
julia> @model function demo()
           m ~ Normal()
           x ~ Normal(m, 1)
       end
demo (generic function with 2 methods)
julia> m = demo();
julia> # Returns all the variables we have conditioned on + their values.
conditioned(condition(m, x=100.0, m=1.0))
(x = 100.0, m = 1.0)
julia> # Nested ones also work (note that `PrefixContext` does nothing to the result).
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
julia> conditioned(cm)
(x = 100.0, m = 1.0)
julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed,
       # `a.m` is treated as a random variable.
       keys(VarInfo(cm))
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
 a.m
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);
julia> conditioned(cm).x
julia> conditioned(cm).var"a.m"
julia> keys(VarInfo(cm)) # No variables are sampled
When applied to a context, conditioned returns a NamedTuple of the values that are conditioned on under that context.
Note that this will recursively traverse the context stack and return a merged version of the condition values.
Similarly, one can specify with AbstractPPL.decondition that certain, or all, random variables are not observed.
— Function decondition(model::Model)
decondition(model::Model, variables...)
Return a Model for which variables... are not considered observations. If no variables are provided, then all variables currently considered observations will no longer be.
This is essentially the inverse of condition. This also means that it suffers from the same limitations.
Note that currently we only support variables to take on explicit values provided to condition.
julia> using Distributions
julia> @model function demo()
           m ~ Normal()
           x ~ Normal(m, 1)
           return (; m=m, x=x)
       end
demo (generic function with 2 methods)
julia> conditioned_model = condition(demo(), m = 1.0, x = 10.0);
julia> conditioned_model()
(m = 1.0, x = 10.0)
julia> # By specifying the `VarName` to `decondition`.
model = decondition(conditioned_model, @varname(m));
julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0)
julia> # When `NamedTuple` is used as the underlying storage, you can also provide
       # the symbol directly (though the `@varname` approach is preferable if
       # the variable is known at compile-time).
model = decondition(conditioned_model, :m);
julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0)
julia> # `decondition` multiple at once:
(m, x) = decondition(model, :m, :x)(); (m ≠ 1.0 && x ≠ 10.0)
julia> # `decondition` without any symbols will `decondition` all variables.
(m, x) = decondition(model)(); (m ≠ 1.0 && x ≠ 10.0)
julia> # Usage of `Val` to perform `decondition` at compile-time if possible
# is also supported.
model = decondition(conditioned_model, Val{:m}());
julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0)
Similarly when using a Dict:
julia> conditioned_model_dict = condition(demo(), @varname(m) => 1.0, @varname(x) => 10.0);
julia> conditioned_model_dict()
(m = 1.0, x = 10.0)
julia> deconditioned_model_dict = decondition(conditioned_model_dict, @varname(m));
julia> (m, x) = deconditioned_model_dict(); m ≠ 1.0 && x == 10.0
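After deconditioning, DynamicPPL.conditioned can be used to check which values are still treated as observations. Here is a small sketch. It assumes the NamedTuple-based conditioning from the examples above and calls conditioned through the module because it may not be exported.

```julia
using DynamicPPL, Distributions

@model function demo()
    m ~ Normal()
    x ~ Normal(m, 1)
    return (; m=m, x=x)
end

conditioned_model = condition(demo(), m = 1.0, x = 10.0)
deconditioned_model = decondition(conditioned_model, @varname(m))

# Values that are still treated as observations after deconditioning m.
remaining = DynamicPPL.conditioned(deconditioned_model)
```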
But, as mentioned, decondition is only supported for variables explicitly provided to condition:
julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
           m = Vector{TV}(undef, 2)
           m[1] ~ Normal()
           m[2] ~ Normal()
           return m
       end
demo_mv (generic function with 4 methods)
julia> model = demo_mv();
julia> conditioned_model = condition(model, @varname(m) => [1.0, 2.0]);
julia> conditioned_model()
2-element Vector{Float64}:
 1.0
 2.0
julia> deconditioned_model = decondition(conditioned_model, @varname(m[1]));
julia> deconditioned_model() # (×) `m[1]` is still conditioned
2-element Vector{Float64}:
 1.0
 2.0
julia> # (✓) this works though
deconditioned_model_2 = deconditioned_model | (@varname(m[1]) => missing);
julia> m = deconditioned_model_2(); (m[1] ≠ 1.0 && m[2] == 2.0)
Fixing and unfixing
We can also fix a collection of variables in a Model to certain values using fix.
This might seem quite similar to the aforementioned condition and its siblings, but they are indeed different operations:
- Variables passed to condition are considered to be observations, and are thus included in the computation of logjoint and loglikelihood, but not in logprior.
- Variables passed to fix are considered to be constant, and are thus not included in any log-probability computations.
The differences are more clearly spelled out in the docstring of fix.
— Function fix(model::Model; values...)
fix(model::Model, values::NamedTuple)
Return a Model which now treats the variables in values as fixed.
Simple univariate model
julia> using Distributions
julia> @model function demo()
           m ~ Normal()
           x ~ Normal(m, 1)
           return (; m=m, x=x)
       end
demo (generic function with 2 methods)
julia> model = demo();
julia> m, x = model(); (m ≠ 1.0 && x ≠ 100.0)
julia> # Create a new instance which treats `x` as fixed
       # with value `100.0`, and similarly for `m=1.0`.
       fixed_model = fix(model, x=100.0, m=1.0);
julia> m, x = fixed_model(); (m == 1.0 && x == 100.0)
julia> # Let's only fix `x = 100.0`.
       fixed_model = fix(model, x = 100.0);
julia> m, x = fixed_model(); (m ≠ 1.0 && x == 100.0)
The above uses a NamedTuple to hold the fixed variables, which allows us to perform some additional optimizations; in many cases, the above has zero runtime overhead.
But we can also use a Dict, which offers more flexibility in the fixing (see examples further below) but generally has worse performance than the NamedTuple approach:
julia> fixed_model_dict = fix(model, Dict(@varname(x) => 100.0));
julia> m, x = fixed_model_dict(); (m ≠ 1.0 && x == 100.0)
julia> # Alternative: pass `Pair{<:VarName}` as positional argument.
fixed_model_dict = fix(model, @varname(x) => 100.0, );
julia> m, x = fixed_model_dict(); (m ≠ 1.0 && x == 100.0)
Fix only a part of a multivariate variable
Not only can we fix multivariate random variables, but we can also use the standard mechanism of setting something to missing in the call to fix to only fix a part of the variable.
julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
           m = Vector{TV}(undef, 2)
           m[1] ~ Normal()
           m[2] ~ Normal()
           return m
       end
demo_mv (generic function with 4 methods)
julia> model = demo_mv();
julia> fixed_model = fix(model, m = [missing, 1.0]);
julia> # (✓) `m[1]` sampled while `m[2]` is fixed
m = fixed_model(); (m[1] ≠ 1.0 && m[2] == 1.0)
Intuitively one might also expect to be able to write something like fix(model, var"m[1]" = 1.0, ). Unfortunately this is not supported, as it has the potential of increasing compilation times without offering any benefit with respect to runtime:
julia> # (×) `m[2]` is not set to 1.0.
m = fix(model, var"m[2]" = 1.0)(); m[2] == 1.0
But you can do this if you use a Dict as the underlying storage instead:
julia> # Alternative: `fix(model, Dict(@varname(m[2]) => 1.0))`
       # (✓) `m[2]` is set to 1.0.
       m = fix(model, @varname(m[2]) => 1.0)(); (m[1] ≠ 1.0 && m[2] == 1.0)
Nested models
fix of course also supports the use of nested models through the use of to_submodel, similar to condition.
julia> @model demo_inner() = m ~ Normal()
demo_inner (generic function with 2 methods)
julia> @model function demo_outer()
           inner ~ to_submodel(demo_inner())
           return inner
       end
demo_outer (generic function with 2 methods)
julia> model = demo_outer();
julia> model() ≠ 1.0
julia> fixed_model = fix(model, var"inner.m" = 1.0, );
julia> fixed_model()
However, unlike condition, fix can also be used to fix the return-value of the submodel:
julia> fixed_model = fix(model, inner = 2.0,);
julia> fixed_model()
And similarly when using a Dict:
julia> fixed_model_dict = fix(model, @varname(var"inner.m") => 1.0);
julia> fixed_model_dict()
julia> fixed_model_dict = fix(model, @varname(inner) => 2.0);
julia> fixed_model_dict()
Difference from condition
A very similar functionality is also provided by condition which, not surprisingly, conditions variables instead of fixing them. The only difference between fixing and conditioning is as follows:
- Variables passed to condition are considered to be observations, and are thus included in the computation of logjoint and loglikelihood, but not in logprior.
- Variables passed to fix are considered to be constant, and are thus not included in any log-probability computations.
julia> @model function demo()
           m ~ Normal()
           x ~ Normal(m, 1)
           return (; m=m, x=x)
       end
demo (generic function with 2 methods)
julia> model = demo();
julia> model_fixed = fix(model, m = 1.0);
julia> model_conditioned = condition(model, m = 1.0);
julia> logjoint(model_fixed, (x=1.0,))
julia> # Different!
logjoint(model_conditioned, (x=1.0,))
julia> # And the difference is the missing log-probability of `m`:
logjoint(model_fixed, (x=1.0,)) + logpdf(Normal(), 1.0) == logjoint(model_conditioned, (x=1.0,))
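The same difference shows up when the log density is split into its parts. In this sketch, which uses the demo model above, fixing m leaves nothing in loglikelihood. Conditioning on m instead moves logpdf(Normal(), 1.0) into it. The logprior term for x is identical in both cases.

```julia
using DynamicPPL, Distributions

@model function demo()
    m ~ Normal()
    x ~ Normal(m, 1)
    return (; m=m, x=x)
end

model = demo()
θ = (x = 1.0,)
model_fixed = fix(model, m = 1.0)
model_conditioned = condition(model, m = 1.0)

# The prior term of x is the same in both cases.
lp_fixed = logprior(model_fixed, θ)
lp_cond = logprior(model_conditioned, θ)

# A fixed m contributes nothing; a conditioned m is an observation.
ll_fixed = loglikelihood(model_fixed, θ)
ll_cond = loglikelihood(model_conditioned, θ)
```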
fix([context::AbstractContext,] values::NamedTuple)
fix([context::AbstractContext]; values...)
Return FixedContext with values and context if values is non-empty, otherwise return context, which is DefaultContext by default.
See also: unfix.
— Function fixed(model::Model)
Return the fixed values in model.
julia> using Distributions
julia> using DynamicPPL: fixed, contextualize
julia> @model function demo()
           m ~ Normal()
           x ~ Normal(m, 1)
       end
demo (generic function with 2 methods)
julia> m = demo();
julia> # Returns all the variables we have fixed on + their values.
fixed(fix(m, x=100.0, m=1.0))
(x = 100.0, m = 1.0)
julia> # Nested ones also work (note that `PrefixContext` does nothing to the result).
cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0);
julia> fixed(cm)
(x = 100.0, m = 1.0)
julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed,
       # `a.m` is treated as a random variable.
       keys(VarInfo(cm))
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
 a.m
julia> # If we instead fix on `a.m`, `m` in the model will be treated as fixed.
cm = fix(contextualize(m, PrefixContext{:a}(fix(var"a.m"=1.0))), x=100.0);
julia> fixed(cm).x
julia> fixed(cm).var"a.m"
julia> keys(VarInfo(cm)) # <= no variables are sampled
When applied to a context, fixed returns the values that are fixed under that context.
Note that this will recursively traverse the context stack and return a merged version of the fix values.
The difference between fix and condition is described in the docstring of fix.
Similarly, we can unfix variables, i.e. return them to their original meaning:
— Function unfix(model::Model)
unfix(model::Model, variables...)
Return a Model for which variables... are not considered fixed. If no variables are provided, then all variables currently considered fixed will no longer be.
This is essentially the inverse of fix. This also means that it suffers from the same limitations.
Note that currently we only support variables to take on explicit values provided to fix.
julia> using Distributions
julia> @model function demo()
           m ~ Normal()
           x ~ Normal(m, 1)
           return (; m=m, x=x)
       end
demo (generic function with 2 methods)
julia> fixed_model = fix(demo(), m = 1.0, x = 10.0);
julia> fixed_model()
(m = 1.0, x = 10.0)
julia> # By specifying the `VarName` to `unfix`.
model = unfix(fixed_model, @varname(m));
julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0)
julia> # When `NamedTuple` is used as the underlying storage, you can also provide
       # the symbol directly (though the `@varname` approach is preferable if
       # the variable is known at compile-time).
model = unfix(fixed_model, :m);
julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0)
julia> # `unfix` multiple at once:
(m, x) = unfix(model, :m, :x)(); (m ≠ 1.0 && x ≠ 10.0)
julia> # `unfix` without any symbols will `unfix` all variables.
(m, x) = unfix(model)(); (m ≠ 1.0 && x ≠ 10.0)
julia> # Usage of `Val` to perform `unfix` at compile-time if possible
# is also supported.
model = unfix(fixed_model, Val{:m}());
julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0)
Similarly when using a Dict:
julia> fixed_model_dict = fix(demo(), @varname(m) => 1.0, @varname(x) => 10.0);
julia> fixed_model_dict()
(m = 1.0, x = 10.0)
julia> unfixed_model_dict = unfix(fixed_model_dict, @varname(m));
julia> (m, x) = unfixed_model_dict(); m ≠ 1.0 && x == 10.0
But, as mentioned, unfix is only supported for variables explicitly provided to fix:
julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
           m = Vector{TV}(undef, 2)
           m[1] ~ Normal()
           m[2] ~ Normal()
           return m
       end
demo_mv (generic function with 4 methods)
julia> model = demo_mv();
julia> fixed_model = fix(model, @varname(m) => [1.0, 2.0]);
julia> fixed_model()
2-element Vector{Float64}:
 1.0
 2.0
julia> unfixed_model = unfix(fixed_model, @varname(m[1]));
julia> unfixed_model() # (×) `m[1]` is still fixed
2-element Vector{Float64}:
 1.0
 2.0
julia> # (✓) this works though
unfixed_model_2 = fix(unfixed_model, @varname(m[1]) => missing);
julia> m = unfixed_model_2(); (m[1] ≠ 1.0 && m[2] == 2.0)
unfix(context::AbstractContext, syms...)
Return context but with syms no longer fixed.
Note that this recursively traverses contexts, unfixing all along the way.
See also: fix.
DynamicPPL provides functionality for generating samples from the posterior predictive distribution through the predict function. This allows you to use posterior parameter samples to generate predictions for unobserved data points.
The predict function has two main methods:
- For AbstractVector{<:AbstractVarInfo}: useful when you have a collection of VarInfo objects representing posterior samples.
- For MCMCChains.Chains (only available when MCMCChains.jl is loaded): useful when you have posterior samples in the form of an MCMCChains.Chains object.
— Function predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
Generate samples from the posterior predictive distribution by evaluating model at each set of parameter values provided in chain. The number of posterior predictive samples matches the length of chain. The returned AbstractVarInfos will contain both the posterior parameter values and the predicted values.
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
Sample from the posterior predictive distribution by executing model with parameters fixed to each sample in chain, and return the resulting Chains.
The model passed to predict is often different from the one used to generate chain. Typically, the model from which chain originated treats certain variables as observed (i.e., data points), while the model you pass to predict may mark these same variables as missing or unobserved. Calling predict then leverages the previously inferred parameter values to simulate what new, unobserved data might look like, given your posterior beliefs.
For each parameter configuration in chain:
- All random variables present in chain are fixed to their sampled values.
- Any variables not included in chain are sampled from their prior distributions.
If include_all is false, the returned Chains will contain only those variables that were not fixed by the samples in chain. This is useful when you want to sample only new variables from the posterior predictive distribution.
using AbstractMCMC, Distributions, DynamicPPL, MCMCChains, Random, Statistics
@model function linear_reg(x, y, σ = 0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end
# Generate synthetic chain using known ground truth parameter
ground_truth_β = 2.0
# Create chain of samples from a normal distribution centered on ground truth
β_chain = MCMCChains.Chains(
    rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
)
# Generate predictions for two test points
xs_test = [10.1, 10.2]
m_train = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.AbstractPPL.predict(
    Random.default_rng(), m_train, β_chain
)
ys_pred = vec(mean(Array(predictions); dims=1))
# Check if predictions match expected values within tolerance
(
    isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
    isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
)
# output
(true, true)
Basic Usage
The typical workflow for posterior prediction involves:
- Fitting a model to observed data to obtain posterior samples
- Creating a new model instance with some variables marked as missing (unobserved)
- Using predict to generate samples for these missing variables based on the posterior parameter samples
When using predict with MCMCChains.Chains, you can control which variables are included in the output with the include_all parameter:
- include_all=false (default): Include only newly predicted variables
- include_all=true: Include both parameters from the original chain and predicted variables
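Here is a sketch of the AbstractVector{<:AbstractVarInfo} method, under stated assumptions. No inference is run, so the "posterior" VarInfos are simply prior draws from a training model. That model is built with hypothetical data xs_train and ys_train, and its draws only stand in for real posterior samples. Each returned VarInfo should hold both β and the predicted y[1] and y[2].

```julia
using DynamicPPL, Distributions, Random

@model function linear_reg(x, y, σ = 0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

# Hypothetical training data; the VarInfos below are prior draws of β,
# used only as a stand-in for genuine posterior samples.
xs_train = [1.0, 2.0, 3.0]
ys_train = [2.1, 3.9, 6.0]
m_train = linear_reg(xs_train, ys_train)
rng = MersenneTwister(1)
posterior_vis = [VarInfo(rng, m_train) for _ in 1:5]

# Test model: the outputs are marked as missing so they get predicted.
xs_test = [10.1, 10.2]
m_test = linear_reg(xs_test, fill(missing, length(xs_test)))

pred_vis = DynamicPPL.AbstractPPL.predict(rng, m_test, posterior_vis)
```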
Models within models
One can include models and call another model inside the model function with left ~ to_submodel(model).
— Function to_submodel(model::Model[, auto_prefix::Bool])
Return a model wrapper indicating that it is a sampleable model over the return-values.
This is mainly meant to be used on the right-hand side of a ~ operator to indicate that the model can be sampled from but not necessarily evaluated for its log density.
Note that some other operations that one typically associates with expressions of the form left ~ right, such as condition, will also not work with to_submodel.
To avoid variable names clashing between models, it is recommended to leave the argument auto_prefix equal to true. If one does not use automatic prefixing, then it's recommended to prefix the variables manually with prefix(::Model, input).
Arguments:
- model: the model to wrap.
- auto_prefix::Bool: whether to automatically prefix the variables in the model using the left-hand side of the ~ statement. Default: true.
Simple example
julia> @model function demo1(x)
           x ~ Normal()
           return 1 + abs(x)
       end;
julia> @model function demo2(x, y)
           a ~ to_submodel(demo1(x))
           return y ~ Uniform(0, a)
       end;
When we sample from the model demo2(missing, 0.4), the random variable x will be sampled:
julia> vi = VarInfo(demo2(missing, 0.4));
julia> @varname(var"a.x") in keys(vi)
The variable a is not tracked. However, it will be assigned the return value of demo1, and can be used in subsequent lines of the model, as shown above.
julia> @varname(a) in keys(vi)
We can check that the log joint probability of the model accumulated in vi is correct:
julia> x = vi[@varname(var"a.x")];
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
Without automatic prefixing
As mentioned earlier, the auto_prefix argument specifies whether to automatically prefix the variables in the submodel; by default it is true. If auto_prefix=false, then the variables in the submodel will not be prefixed.
julia> @model function demo1(x)
           x ~ Normal()
           return 1 + abs(x)
       end;
julia> @model function demo2_no_prefix(x, z)
           a ~ to_submodel(demo1(x), false)
           return z ~ Uniform(-a, 1)
       end;
julia> vi = VarInfo(demo2_no_prefix(missing, 0.4));
julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x`
However, not using prefixing is generally not recommended, as it can lead to variable name clashes unless one is careful. For example, if we're re-using the same model twice in a model, not using prefixing will lead to variable name clashes. However, one can manually prefix using prefix(::Model, input):
julia> @model function demo2(x, y, z)
           a ~ to_submodel(prefix(demo1(x), :sub1), false)
           b ~ to_submodel(prefix(demo1(y), :sub2), false)
           return z ~ Uniform(-a, b)
       end;
julia> vi = VarInfo(demo2(missing, missing, 0.4));
julia> @varname(var"sub1.x") in keys(vi)
true
julia> @varname(var"sub2.x") in keys(vi)
true
Variables a and b are not tracked, but are assigned the return values of the respective calls to demo1:
julia> @varname(a) in keys(vi)
false
julia> @varname(b) in keys(vi)
false
We can check that the log joint probability of the model accumulated in vi is correct:
julia> sub1_x = vi[@varname(var"sub1.x")];
julia> sub2_x = vi[@varname(var"sub2.x")];
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);
julia> getlogp(vi) ≈ logprior + loglikelihood
true
Usage as likelihood is illegal
Note that it is illegal to use a to_submodel model as a likelihood in another model:
julia> @model inner() = x ~ Normal()
inner (generic function with 2 methods)
julia> @model illegal_likelihood() = a ~ to_submodel(inner())
illegal_likelihood (generic function with 2 methods)
julia> model = illegal_likelihood() | (a = 1.0,);
julia> model()
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
Note that a to_submodel model is only sampleable; one cannot compute logpdf for its realizations.
In the past, one would instead embed sub-models using @submodel, which has been deprecated since the introduction of to_submodel(model).
— Macro @submodel model
@submodel ... = model
Run a Turing model nested inside of a Turing model.
This is deprecated and will be removed in a future release. Use left ~ to_submodel(model) instead (see to_submodel).
julia> @model function demo1(x)
           x ~ Normal()
           return 1 + abs(x)
       end;
julia> @model function demo2(x, y)
           @submodel a = demo1(x)
           return y ~ Uniform(0, a)
       end;
When we sample from the model demo2(missing, 0.4), the random variable x will be sampled:
julia> vi = VarInfo(demo2(missing, 0.4));
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
│ caller = ip:0x0
└ @ Core :-1
julia> @varname(x) in keys(vi)
true
Variable a is not tracked since it can be computed from the random variable x that was tracked when running demo1:
julia> @varname(a) in keys(vi)
false
We can check that the log joint probability of the model accumulated in vi is correct:
julia> x = vi[@varname(x)];
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
true
@submodel prefix=... model
@submodel prefix=... ... = model
Run a Turing model nested inside of a Turing model and add "prefix." as a prefix to all random variables inside of the model.
Valid expressions for prefix=... are:
- prefix=false: no prefix is used.
- prefix=true: attempt to automatically determine the prefix from the left-hand side of ... = model by first converting it into a VarName, and then calling Symbol on this.
- prefix=expression: results in the prefix Symbol(expression).
The prefix makes it possible to run the same Turing model multiple times while keeping track of all random variables correctly.
This is deprecated and will be removed in a future release. Use left ~ to_submodel(model) instead (see to_submodel(model)).
Example models
julia> @model function demo1(x)
           x ~ Normal()
           return 1 + abs(x)
       end;
julia> @model function demo2(x, y, z)
           @submodel prefix="sub1" a = demo1(x)
           @submodel prefix="sub2" b = demo1(y)
           return z ~ Uniform(-a, b)
       end;
When we sample from the model demo2(missing, missing, 0.4), the random variables sub1.x and sub2.x will be sampled:
julia> vi = VarInfo(demo2(missing, missing, 0.4));
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
│ caller = ip:0x0
└ @ Core :-1
julia> @varname(var"sub1.x") in keys(vi)
true
julia> @varname(var"sub2.x") in keys(vi)
true
Variables a and b are not tracked since they can be computed from the random variables sub1.x and sub2.x that were tracked when running demo1:
julia> @varname(a) in keys(vi)
false
julia> @varname(b) in keys(vi)
false
We can check that the log joint probability of the model accumulated in vi is correct:
julia> sub1_x = vi[@varname(var"sub1.x")];
julia> sub2_x = vi[@varname(var"sub2.x")];
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);
julia> getlogp(vi) ≈ logprior + loglikelihood
true
Different ways of setting the prefix
julia> @model inner() = x ~ Normal()
inner (generic function with 2 methods)
julia> # When `prefix` is unspecified, no prefix is used.
@model submodel_noprefix() = @submodel a = inner()
submodel_noprefix (generic function with 2 methods)
julia> @varname(x) in keys(VarInfo(submodel_noprefix()))
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
│ caller = ip:0x0
└ @ Core :-1
true
julia> # Explicitly don't use any prefix.
@model submodel_prefix_false() = @submodel prefix=false a = inner()
submodel_prefix_false (generic function with 2 methods)
julia> @varname(x) in keys(VarInfo(submodel_prefix_false()))
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
│ caller = ip:0x0
└ @ Core :-1
true
julia> # Automatically determined from `a`.
@model submodel_prefix_true() = @submodel prefix=true a = inner()
submodel_prefix_true (generic function with 2 methods)
julia> @varname(var"a.x") in keys(VarInfo(submodel_prefix_true()))
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
│ caller = ip:0x0
└ @ Core :-1
true
julia> # Using a static string.
@model submodel_prefix_string() = @submodel prefix="my prefix" a = inner()
submodel_prefix_string (generic function with 2 methods)
julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string()))
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
│ caller = ip:0x0
└ @ Core :-1
true
julia> # Using string interpolation.
@model submodel_prefix_interpolation() = @submodel prefix="$(nameof(inner()))" a = inner()
submodel_prefix_interpolation (generic function with 2 methods)
julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation()))
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
│ caller = ip:0x0
└ @ Core :-1
true
julia> # Or using some arbitrary expression.
@model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner()
submodel_prefix_expr (generic function with 2 methods)
julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr()))
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
│ caller = ip:0x0
└ @ Core :-1
true
julia> # (×) Automatic prefixing without a left-hand side expression does not work!
@model submodel_prefix_error() = @submodel prefix=true inner()
ERROR: LoadError: cannot automatically prefix with no left-hand side
- The choice prefix=expression means that the prefixing will incur a runtime cost. This is also the case for prefix=true, depending on whether the expression on the left-hand side of ... = model requires runtime information or not. For example, x = model will result in the static prefix x, while x[i] = model will be resolved at runtime.
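As a rough illustration of the naming scheme above, the following sketch joins a prefix and a variable name with a dot. The helper prefixed_name is hypothetical and works on plain Symbols. DynamicPPL itself operates on VarNames, so treat this only as a model of the resulting names, not of the implementation.

```julia
# Hypothetical helper mimicking how a prefix and a variable name are joined.
# `Symbol(args...)` concatenates the string representations of its arguments.
prefixed_name(prefix, name::Symbol) = Symbol(prefix, ".", name)

# prefix="my prefix"               -> var"my prefix.x"
prefixed_name("my prefix", :x)
# prefix=1 + 2                     -> var"3.x" (the expression is evaluated first)
prefixed_name(1 + 2, :x)
# prefix=true with `a = inner()`   -> the left-hand side `a` becomes the prefix
prefixed_name(:a, :x)
```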
In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable name clashes:
— Function prefix(model::Model, x)
Return model but with all random variables prefixed by x.
If x is known at compile-time, use Val{x}() to avoid runtime overheads for prefixing.
julia> using DynamicPPL: prefix
julia> @model demo() = x ~ Dirac(1)
demo (generic function with 2 methods)
julia> rand(prefix(demo(), :my_prefix))
(var"my_prefix.x" = 1,)
julia> # One can also use `Val` to avoid runtime overheads.
rand(prefix(demo(), Val(:my_prefix)))
(var"my_prefix.x" = 1,)
Under the hood, to_submodel makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else:
— Method returned(model)
Return a model wrapper indicating that it is a model over its return-values.
It is possible to manually increase (or decrease) the accumulated log density from within a model function.
— Macro @addlogprob!(ex)
Add the result of the evaluation of ex to the joint log probability.
This macro allows you to include arbitrary terms in the likelihood:
julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x);
julia> @model function demo(x)
           μ ~ Normal()
           @addlogprob! myloglikelihood(x, μ)
       end;
julia> x = [1.3, -2.1];
julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2)
true
and to reject samples:
julia> using LinearAlgebra: I, dot
julia> @model function demo(x)
           m ~ MvNormal(zero(x), I)
           if dot(m, x) < 0
               @addlogprob! -Inf
               # Exit the model evaluation early
               return nothing
           end
           x ~ MvNormal(m, I)
           return nothing
       end;
julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf
true
The @addlogprob! macro increases the accumulated log probability regardless of the evaluation context, i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density. If you would like to avoid this behaviour you should check the evaluation context. It can be accessed with the internal variable __context__. For instance, in the following example the log density is not accumulated when only the log prior is computed:
julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x);
julia> @model function demo(x)
           μ ~ Normal()
           if DynamicPPL.leafcontext(__context__) !== PriorContext()
               @addlogprob! myloglikelihood(x, μ)
           end
       end;
julia> x = [1.3, -2.1];
julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2)
true
julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2)
true
Return values of the model function for a collection of samples can be obtained with returned(model, chain).
— Method returned(model::Model, parameters::NamedTuple)
returned(model::Model, values, keys)
Execute model with variables keys set to values and return the values returned by the model.
If a NamedTuple is given, keys=keys(parameters) and values=values(parameters).
julia> using DynamicPPL, Distributions
julia> @model function demo(xs)
           s ~ InverseGamma(2, 3)
           m_shifted ~ Normal(10, √s)
           m = m_shifted - 10
           for i in eachindex(xs)
               xs[i] ~ Normal(m, √s)
           end
           return (m, )
       end
demo (generic function with 2 methods)
julia> model = demo(randn(10));
julia> parameters = (; s = 1.0, m_shifted=10.0);
julia> returned(model, parameters)
(0.0,)
julia> returned(model, values(parameters), keys(parameters))
(0.0,)
For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with pointwise_loglikelihoods. Similarly, the log-densities of the priors can be computed using pointwise_prior_logdensities, or both, i.e. all variables, using pointwise_logdensities.
— Function pointwise_logdensities(model::Model, chain::Chains, keytype = String)
Runs model on each sample in chain, returning an OrderedDict{String, Matrix{Float64}} with keys corresponding to symbols of the variables, and values being matrices of shape (num_samples, num_chains).
keytype specifies the type of the keys used in the returned OrderedDict. Currently, only String and VarName are supported.
Say y is a Vector of n i.i.d. Normal(μ, σ) variables, with μ and σ both being <:Real. Then the observe (i.e. when the left-hand side is an observation) statements can be implemented in three ways:
- (1) using a for loop:
    for i in eachindex(y)
        y[i] ~ Normal(μ, σ)
    end
- (2) using .~:
    y .~ Normal(μ, σ)
- (3) using MvNormal:
    y ~ MvNormal(fill(μ, n), σ^2 * I)
In (1) and (2), y will be treated as a collection of n i.i.d. 1-dimensional variables, while in (3) y will be treated as a single n-dimensional observation.
This is important to keep in mind, in particular if the pointwise log-densities are used in downstream computations.
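To see the difference concretely, the sketch below writes both densities out by hand. It uses hypothetical helpers and Base Julia only. Under the stated i.i.d. assumption the totals agree. However, (1) and (2) yield n separate terms while (3) yields a single one, and the single term is what a pointwise computation would report for y.

```julia
# Hypothetical univariate normal log density (stand-in for Distributions.logpdf).
normal_logpdf(x, μ, σ) = -log(σ) - 0.5 * log(2π) - 0.5 * ((x - μ) / σ)^2

# Log density of MvNormal(fill(μ, n), σ^2 * I), written out by hand:
# -n/2 * log(2πσ^2) - ||y - μ||^2 / (2σ^2).
function isotropic_mvnormal_logpdf(y, μ, σ)
    n = length(y)
    return -n / 2 * log(2π * σ^2) - sum(abs2, y .- μ) / (2σ^2)
end

y = [1.0, 2.0, 3.0]
μ, σ = 0.5, 2.0

# (1) and (2): one log density per element of y.
pointwise = [normal_logpdf(yi, μ, σ) for yi in y]
# (3): a single log density for the whole vector.
joint = isotropic_mvnormal_logpdf(y, μ, σ)
```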
From chain
julia> using MCMCChains
julia> @model function demo(xs, y)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, √s)
           for i in eachindex(xs)
               xs[i] ~ Normal(m, √s)
           end
           y ~ Normal(m, √s)
       end
demo (generic function with 2 methods)
julia> # Example observations.
model = demo([1.0, 2.0, 3.0], [4.0]);
julia> # A chain with 3 iterations.
chain = Chains(
    reshape(1.:6., 3, 2),
    [:s, :m]
);
julia> pointwise_logdensities(model, chain)
OrderedDict{String, Matrix{Float64}} with 6 entries:
"s" => [-0.802775; -1.38222; -2.09861;;]
"m" => [-8.91894; -7.51551; -7.46824;;]
"xs[1]" => [-5.41894; -5.26551; -5.63491;;]
"xs[2]" => [-2.91894; -3.51551; -4.13491;;]
"xs[3]" => [-1.41894; -2.26551; -2.96824;;]
"y" => [-0.918939; -1.51551; -2.13491;;]
julia> pointwise_logdensities(model, chain, String)
OrderedDict{String, Matrix{Float64}} with 6 entries:
"s" => [-0.802775; -1.38222; -2.09861;;]
"m" => [-8.91894; -7.51551; -7.46824;;]
"xs[1]" => [-5.41894; -5.26551; -5.63491;;]
"xs[2]" => [-2.91894; -3.51551; -4.13491;;]
"xs[3]" => [-1.41894; -2.26551; -2.96824;;]
"y" => [-0.918939; -1.51551; -2.13491;;]
julia> pointwise_logdensities(model, chain, VarName)
OrderedDict{VarName, Matrix{Float64}} with 6 entries:
s => [-0.802775; -1.38222; -2.09861;;]
m => [-8.91894; -7.51551; -7.46824;;]
xs[1] => [-5.41894; -5.26551; -5.63491;;]
xs[2] => [-2.91894; -3.51551; -4.13491;;]
xs[3] => [-1.41894; -2.26551; -2.96824;;]
y => [-0.918939; -1.51551; -2.13491;;]
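The entries above can be reproduced by hand. This sketch makes the following assumptions:
- The chain stores s = 1, 2, 3 and m = 4, 5, 6, which is the column-major layout of reshape(1.:6., 3, 2).
- The observation y is treated as the scalar 4.0.
- Hypothetical closed-form log densities replace Distributions.logpdf. The InverseGamma formula used here is only valid for an integer shape parameter.

```julia
# Hypothetical hand-written log densities (stand-ins for Distributions.logpdf).
normal_logpdf(x, μ, σ) = -log(σ) - 0.5 * log(2π) - 0.5 * ((x - μ) / σ)^2
# InverseGamma(α, β) log density; log Γ(α) is computed as log(factorial(α - 1)),
# which is only valid for integer shape α.
invgamma_logpdf(x, α::Integer, β) =
    α * log(β) - log(factorial(α - 1)) - (α + 1) * log(x) - β / x

# reshape(1.:6., 3, 2) gives s = 1, 2, 3 and m = 4, 5, 6 across the three iterations.
s_draws = [1.0, 2.0, 3.0]
m_draws = [4.0, 5.0, 6.0]
xs, y = [1.0, 2.0, 3.0], 4.0

draws = collect(zip(s_draws, m_draws))
pointwise = Dict(
    "s" => [invgamma_logpdf(s, 2, 3) for s in s_draws],
    "m" => [normal_logpdf(m, 0, √s) for (s, m) in draws],
    "xs[1]" => [normal_logpdf(xs[1], m, √s) for (s, m) in draws],
    "y" => [normal_logpdf(y, m, √s) for (s, m) in draws],
)
```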
Note that x .~ Dist() will treat x as a collection of independent observations rather than as a single observation.
julia> @model function demo(x)
           x .~ Normal()
       end;
julia> m = demo([1.0, ]);
julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])])
-1.4189385332046727
julia> m = demo([1.0; 1.0]);
julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
(-1.4189385332046727, -1.4189385332046727)
— Function pointwise_loglikelihoods(model, chain[, keytype, context])
Compute the pointwise log-likelihoods of the model given the chain. This is the same as pointwise_logdensities(model, chain, context), but only including the likelihood terms. See also: pointwise_logdensities.
— Function pointwise_prior_logdensities(model, chain[, keytype, context])
Compute the pointwise log-prior-densities of the model given the chain. This is the same as pointwise_logdensities(model, chain, context), but only including the prior terms. See also: pointwise_logdensities.
For converting a chain into a format that can more easily be fed into a Model again, for example using condition, you can use value_iterator_from_chain.
— Function value_iterator_from_chain(model::Model, chain)
value_iterator_from_chain(varinfo::AbstractVarInfo, chain)
Return an iterator over the values in chain for each variable in model.
julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs
julia> rng = StableRNG(42);
julia> @model function demo_model(x)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           for i in eachindex(x)
               x[i] ~ Normal(m, sqrt(s))
           end
           return s, m
       end
demo_model (generic function with 2 methods)
julia> model = demo_model([1.0, 2.0]);
julia> chain = Chains(rand(rng, 10, 2, 3), [:s, :m]);
julia> iter = value_iterator_from_chain(model, chain);
julia> first(iter)
OrderedDict{VarName, Any} with 2 entries:
s => 0.580515
m => 0.739328
julia> collect(iter)
10×3 Matrix{OrderedDict{VarName, Any}}:
OrderedDict(s=>0.580515, m=>0.739328) … OrderedDict(s=>0.186047, m=>0.402423)
OrderedDict(s=>0.191241, m=>0.627342) OrderedDict(s=>0.776277, m=>0.166342)
OrderedDict(s=>0.971133, m=>0.637584) OrderedDict(s=>0.651655, m=>0.712044)
OrderedDict(s=>0.74345, m=>0.110359) OrderedDict(s=>0.469214, m=>0.104502)
OrderedDict(s=>0.170969, m=>0.598514) OrderedDict(s=>0.853546, m=>0.185399)
OrderedDict(s=>0.704776, m=>0.322111) … OrderedDict(s=>0.638301, m=>0.853802)
OrderedDict(s=>0.441044, m=>0.162285) OrderedDict(s=>0.852959, m=>0.0956922)
OrderedDict(s=>0.803972, m=>0.643369) OrderedDict(s=>0.245049, m=>0.871985)
OrderedDict(s=>0.772384, m=>0.646323) OrderedDict(s=>0.906603, m=>0.385502)
OrderedDict(s=>0.70882, m=>0.253105) OrderedDict(s=>0.413222, m=>0.953288)
julia> # This can be used to `condition` a `Model`.
conditioned_model = model | first(iter);
julia> conditioned_model() # <= results in same values as the `first(iter)` above
(0.5805148626851955, 0.7393275279160691)
Sometimes it can be useful to extract the priors of a model. This is possible using extract_priors.
— Function extract_priors([rng::Random.AbstractRNG, ]model::Model)
Extract the priors from a model.
This is done by sampling from the model and recording the distributions that are used to generate the samples.
Because the extraction is done by execution of the model, there are several caveats:
- If one variable, say, y ~ Normal(x, 1), depends on another random variable x ~ Normal(), then the extracted prior of y will have different parameters in every extraction!
- If the model does not have static support, say, n ~ Categorical(ones(10) ./ 10); x ~ MvNormal(zeros(n), I), then the extracted priors themselves will be different between extractions, not just their parameters.
Both of these caveats are demonstrated below.
Changing parameters
julia> using Distributions, StableRNGs
julia> rng = StableRNG(42);
julia> @model function model_dynamic_parameters()
           x ~ Normal(0, 1)
           y ~ Normal(x, 1)
       end;
julia> model = model_dynamic_parameters();
julia> extract_priors(rng, model)[@varname(y)]
Normal{Float64}(μ=-0.6702516921145671, σ=1.0)
julia> extract_priors(rng, model)[@varname(y)]
Normal{Float64}(μ=1.3736306979834252, σ=1.0)
Changing support
julia> using LinearAlgebra, Distributions, StableRNGs
julia> rng = StableRNG(42);
julia> @model function model_dynamic_support()
           n ~ Categorical(ones(10) ./ 10)
           x ~ MvNormal(zeros(n), I)
       end;
julia> model = model_dynamic_support();
julia> length(extract_priors(rng, model)[@varname(x)])
julia> length(extract_priors(rng, model)[@varname(x)])
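The first caveat can be mimicked without DynamicPPL. In this hypothetical sketch, record_priors runs the generative process of model_dynamic_parameters once and stores the (μ, σ) parameters used at each tilde statement. It is a stand-in for extract_priors, not its implementation.

```julia
using Random

# Hypothetical forward run of model_dynamic_parameters that records the
# (μ, σ) parameters of the distribution used at each tilde statement.
function record_priors(rng::AbstractRNG)
    priors = Dict{Symbol,Tuple{Float64,Float64}}()
    priors[:x] = (0.0, 1.0)   # x ~ Normal(0, 1): fixed parameters
    x = randn(rng)
    priors[:y] = (x, 1.0)     # y ~ Normal(x, 1): μ depends on the sampled x
    return priors
end

rng = MersenneTwister(42)
first_extraction = record_priors(rng)
second_extraction = record_priors(rng)
```

The recorded prior of x is identical across runs, while the location of the prior of y changes with every draw of x.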
extract_priors(model::Model, varinfo::AbstractVarInfo)
Extract the priors from a model.
This is done by evaluating the model at the values present in varinfo and recording the distributions that are present at each tilde statement.
Safe extraction of values from a given AbstractVarInfo as they are seen in the model can be done using values_as_in_model.
— Function values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
Get the values of varinfo as they would be seen in the model.
More specifically, this method attempts to extract the realization as seen in the model. For example, x[1] ~ truncated(Normal(); lower=0) will result in a realization that is compatible with truncated(Normal(); lower=0), i.e. one where the value of x[1] is positive, regardless of whether varinfo is working in unconstrained space.
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost of additional model evaluations.
Arguments
- model::Model: model to extract realizations from.
- include_colon_eq::Bool: whether to also include variables on the LHS of :=.
- varinfo::AbstractVarInfo: variable information to use for the extraction.
- context::AbstractContext: base context to use for the extraction. Defaults to DynamicPPL.DefaultContext().
When VarInfo
The following demonstrates a common pitfall when working with VarInfo and constrained variables.
julia> using Distributions, StableRNGs
julia> rng = StableRNG(42);
julia> @model function model_changing_support()
           x ~ Bernoulli(0.5)
           y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
       end;
julia> model = model_changing_support();
julia> # Construct initial type-stable `VarInfo`.
varinfo = VarInfo(rng, model);
julia> # Link it so it works in unconstrained space.
varinfo_linked = DynamicPPL.link(varinfo, model);
julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
# Flip `x` so we hit the other support of `y`.
θ = [!varinfo[@varname(x)], rand(rng)];
julia> # Update the `VarInfo` with the new values.
varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);
julia> # Determine the expected support of `y`.
lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
(0, 1)
julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);
julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
# used in the very first model evaluation, hence the support of `y`
# is not updated even though `x` has changed.
lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub
false
julia> # Approach 2: Extract realizations using `values_as_in_model`.
# (✓) `values_as_in_model` will re-run the model and extract
# the correct realization of `y` given the new values of `x`.
lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub
true
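The failure mode of Approach 1 can be reproduced in a few lines. This sketch assumes that a Uniform(a, b) variable is mapped to unconstrained space through a scaled logistic transform. to_constrained and support_of_y are hypothetical helpers, and the initial value of x is fixed to false for illustration.

```julia
# Assumed bijection for Uniform(a, b): unconstrained θ ↦ a + (b - a) * logistic(θ).
logistic(θ) = 1 / (1 + exp(-θ))
to_constrained(θ, (a, b)) = a + (b - a) * logistic(θ)

# Support of y in model_changing_support, given the value of x.
support_of_y(x) = x == 1 ? (0.0, 1.0) : (11.0, 12.0)

x_initial = false                          # x as seen in the first evaluation
cached_support = support_of_y(x_initial)   # what a cached distribution remembers

x_new = !x_initial                         # flip x, as in the example above
θ_y = 0.3                                  # some unconstrained value for y

# Reusing the cached support maps θ_y into the old interval (11, 12)...
y_cached = to_constrained(θ_y, cached_support)
# ...while re-evaluating the model uses the support implied by the current x.
y_rerun = to_constrained(θ_y, support_of_y(x_new))
```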
— Type
A named distribution that carries the name of the random variable with it.
Testing Utilities
DynamicPPL provides several demo models and helpers for testing samplers in the DynamicPPL.TestUtils module.
— Function test_sampler(models, sampler, args...; kwargs...)
Test that sampler produces correct marginal posterior means on each model in models.
In short, this method iterates through models, calls AbstractMCMC.sample on the model and sampler to produce a chain, and then checks marginal_mean_of_samples(chain, vn) for every (leaf) varname vn against the corresponding value returned by posterior_mean for each model.
To change how comparison is done for a particular chain type, one can overload marginal_mean_of_samples for the corresponding type.
Arguments
- models: A collection of instances of DynamicPPL.Model to test on.
- sampler: The AbstractMCMC.AbstractSampler to test.
- args...: Arguments forwarded to sample.
Keyword arguments
- A filter to apply to varnames(model), allowing comparison for only a subset of the varnames.
- atol=1e-1: Absolute tolerance used in @test.
- Relative tolerance used in @test.
- kwargs...: Keyword arguments forwarded to sample.
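The kind of check test_sampler performs can be sketched with an exact sampler standing in for an MCMC sampler. The sketch assumes the univariate demo model and its Normal-InverseGamma posterior, s ~ InverseGamma(3, 49/12) and m | s ~ Normal(7/6, sqrt(s/3)). sample_posterior is hypothetical, and the comparison only mirrors the atol=1e-1 default rather than calling test_sampler itself.

```julia
using Random, Statistics

# Hypothetical exact posterior sampler for the univariate demo model.
function sample_posterior(rng::AbstractRNG, n::Int)
    s = Vector{Float64}(undef, n)
    m = Vector{Float64}(undef, n)
    for i in 1:n
        # Gamma(3, rate 49/12) as a sum of three Exponential(rate 49/12) draws;
        # its reciprocal is distributed as InverseGamma(3, 49/12).
        g = (randexp(rng) + randexp(rng) + randexp(rng)) / (49 / 12)
        s[i] = 1 / g
        m[i] = 7 / 6 + sqrt(s[i] / 3) * randn(rng)
    end
    return (s = s, m = m)
end

# Compare marginal sample means with the known posterior means, atol = 1e-1.
samples = sample_posterior(MersenneTwister(1), 200_000)
checks = Dict(
    :s => isapprox(mean(samples.s), 49 / 24; atol = 1e-1),
    :m => isapprox(mean(samples.m), 7 / 6; atol = 1e-1),
)
```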
— Function test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...)
Test sampler on every model in DEMO_MODELS.
This is just a proxy for test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...).
— Function test_sampler_continuous(sampler, args...; kwargs...)
Test that sampler produces the correct marginal posterior means on all models in demo_models.
As of right now, this is just an alias for test_sampler_on_demo_models.
— Function marginal_mean_of_samples(chain, varname)
Return the mean of the variable represented by varname in chain.
— Constant
A collection of models corresponding to the posterior distribution defined by the generative process
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, √s)
    1.5 ~ Normal(m, √s)
    2.0 ~ Normal(m, √s)
or by
    s[1] ~ InverseGamma(2, 3)
    s[2] ~ InverseGamma(2, 3)
    m[1] ~ Normal(0, √s[1])
    m[2] ~ Normal(0, √s[2])
    1.5 ~ Normal(m[1], √s[1])
    2.0 ~ Normal(m[2], √s[2])
These are examples of a Normal-InverseGamma conjugate prior with Normal likelihood, for which the posterior is known in closed form.
In particular, for the univariate model (the former one):
mean(s) == 49 / 24
mean(m) == 7 / 6
And for the multivariate one (the latter one):
mean(s[1]) == 19 / 8
mean(m[1]) == 3 / 4
mean(s[2]) == 8 / 3
mean(m[2]) == 1
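These numbers follow from the standard Normal-InverseGamma conjugate update. The following is a minimal sketch. It assumes the prior s ~ InverseGamma(α, β), m | s ~ Normal(μ0, sqrt(s/κ0)) with α = 2, β = 3, μ0 = 0 and κ0 = 1, as in the models above. nig_posterior_means is a hypothetical helper.

```julia
# Normal-InverseGamma conjugate update for
# s ~ InverseGamma(α, β), m | s ~ Normal(μ0, sqrt(s / κ0)), x[i] ~ Normal(m, sqrt(s)).
function nig_posterior_means(x::Vector{<:Real}; α = 2, β = 3, μ0 = 0, κ0 = 1)
    n = length(x)
    x̄ = sum(x) / n
    κn = κ0 + n
    μn = (κ0 * μ0 + n * x̄) / κn
    αn = α + n / 2
    βn = β + sum(abs2, x .- x̄) / 2 + κ0 * n * (x̄ - μ0)^2 / (2κn)
    # Posterior means: E[s] = βn / (αn - 1) and E[m] = μn.
    return (s = βn / (αn - 1), m = μn)
end

univariate = nig_posterior_means([1.5, 2.0])   # both observations share s and m
multivariate_1 = nig_posterior_means([1.5])    # s[1], m[1] see only 1.5
multivariate_2 = nig_posterior_means([2.0])    # s[2], m[2] see only 2.0
```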
For every demo model, one can define the true log prior, log likelihood, and log joint probabilities.
— Function logprior_true(model, args...)
Return the logprior of model for args.
This should generally be implemented by hand for every specific model.
See also: logjoint_true, loglikelihood_true.
— Function loglikelihood_true(model, args...)
Return the loglikelihood of model for args.
This should generally be implemented by hand for every specific model.
See also: logjoint_true, logprior_true.
— Function logjoint_true(model, args...)
Return the logjoint of model for args.
Defaults to logprior_true(model, args...) + loglikelihood_true(model, args...).
This should generally be implemented by hand for every specific model so that the returned value can be used as a ground-truth for testing things like:
- Validity of evaluation of model using a particular implementation of AbstractVarInfo.
- Validity of a sampler when combined with DynamicPPL by running the sampler twice: once targeting ground-truth functions such as logjoint_true, and once targeting model.
And more.
See also: logprior_true, loglikelihood_true.
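For the univariate demo model, hand-written ground truths might look as follows. This is a sketch that assumes the model is given by the generative process above. The my_-prefixed names are hypothetical and deliberately distinct from the TestUtils functions. The densities are written out in closed form rather than taken from Distributions.

```julia
normal_logpdf(x, μ, σ) = -log(σ) - 0.5 * log(2π) - 0.5 * ((x - μ) / σ)^2
# InverseGamma(2, 3) log density; log Γ(2) = 0, so no gamma function is needed.
invgamma23_logpdf(s) = 2 * log(3) - 3 * log(s) - 3 / s

# Hypothetical ground truths for
# s ~ InverseGamma(2, 3); m ~ Normal(0, √s); 1.5 ~ Normal(m, √s); 2.0 ~ Normal(m, √s).
my_logprior_true(s, m) = invgamma23_logpdf(s) + normal_logpdf(m, 0, √s)
my_loglikelihood_true(s, m) = sum(normal_logpdf(x, m, √s) for x in (1.5, 2.0))
my_logjoint_true(s, m) = my_logprior_true(s, m) + my_loglikelihood_true(s, m)
```

At s = 1 and m = 4, the prior term equals the sum of the "s" and "m" entries in the first row of the pointwise_logdensities output shown earlier, which gives a quick cross-check.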
And in the case where the model includes constrained variables, it can also be useful to define
— Function logprior_true_with_logabsdet_jacobian(model::Model, args...)
Return a tuple (args_unconstrained, logprior_unconstrained) of model for args....
Unlike logprior_true, the returned logprior computation includes the log-absdet-jacobian adjustment, thus computing logprior for the unconstrained variables.
Note that args are assumed to be in the support of model, while args_unconstrained will be unconstrained.
See also: logprior_true.
— Function logjoint_true_with_logabsdet_jacobian(model::Model, args...)
Return a tuple (args_unconstrained, logjoint) of model for args.
Unlike logjoint_true, the returned logjoint computation includes the log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables.
Note that args are assumed to be in the support of model, while args_unconstrained will be unconstrained.
This should generally not be implemented directly; instead, one should implement logprior_true_with_logabsdet_jacobian for a given model.
See also: logjoint_true, logprior_true_with_logabsdet_jacobian.
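Here is a minimal sketch of the Jacobian adjustment for the positive variable s of the demo models. It assumes s is mapped to unconstrained space by y = log(s), so that log|ds/dy| = y. logprior_unconstrained and trapezoid are hypothetical helpers. The numerical integration only checks that the adjusted density assigns the same probability mass as the original one.

```julia
# InverseGamma(2, 3) log density of s > 0 (log Γ(2) = 0).
invgamma23_logpdf(s) = 2 * log(3) - 3 * log(s) - 3 / s

# Assumed transform for a positive variable: y = log(s), so s = exp(y) and
# log|ds/dy| = y. The unconstrained log density adds this log-abs-det-Jacobian.
function logprior_unconstrained(s)
    y = log(s)
    return (y, invgamma23_logpdf(s) + y)
end

# Composite trapezoidal rule on [a, b] with n subintervals.
function trapezoid(f, a, b; n = 200_000)
    h = (b - a) / n
    return h * (sum(f(a + k * h) for k in 1:(n - 1)) + (f(a) + f(b)) / 2)
end

a, b = 0.5, 4.0
mass_constrained = trapezoid(s -> exp(invgamma23_logpdf(s)), a, b)
mass_unconstrained = trapezoid(y -> exp(logprior_unconstrained(exp(y))[2]), log(a), log(b))
```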
Finally, the following methods can also be of use:
— Function varnames(model::Model)
Return a collection of VarName as they are expected to appear in the model.
Even though it is recommended to implement this by hand for a particular Model, a default implementation using SimpleVarInfo{<:Dict} is provided.
— Function posterior_mean(model::Model)
Return a NamedTuple compatible with varnames(model) where the values represent the posterior mean under model.
"Compatible" means that a varname from varnames(model) can be used to extract the corresponding value using get, e.g. get(posterior_mean(model), varname).
— Function setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false)
Return a tuple of instances for different implementations of AbstractVarInfo, with each vi, supposedly, satisfying vi[vn] == get(example_values, vn) for vn in varnames.
If include_threadsafe is true, then the returned tuple will also include thread-safe versions of the varinfo instances.
— Function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
Return an instance similar to vi but with vns set to values from vals.
— Function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)
Test that vi[vn] corresponds to the correct value in vals for every vn in vns.
Debugging Utilities
DynamicPPL provides a few methods for checking validity of a model-definition.
— Function check_model([rng, ]model::Model; kwargs...)
Check that model is valid, warning about any potential issues.
See check_model_and_trace for more details on supported keyword arguments and details of which types of checks are performed.
Returns
- Whether the model check succeeded.
— Function check_model_and_trace([rng, ]model::Model; kwargs...)
Check that model is valid, warning about any potential issues.
This will check the model for the following issues:
- Repeated usage of the same varname in a model.
- Incorrectly treating a variable as random rather than fixed, and vice versa.
Arguments
- rng: The random number generator to use when evaluating the model.
- model::Model: The model to check.
Keyword Arguments
- The varinfo to use when evaluating the model. Default: VarInfo(model)
- The context to use when evaluating the model. Default: DefaultContext
- error_on_failure: Whether to throw an error if the model check fails. Default: false
Returns
- Whether the model check succeeded.
- trace::Vector{Stmt}: The trace of statements executed during the model check.
Correct model
julia> using StableRNGs
julia> rng = StableRNG(42);
julia> @model demo_correct() = x ~ Normal()
demo_correct (generic function with 2 methods)
julia> issuccess, trace = check_model_and_trace(rng, demo_correct());
julia> issuccess
true
julia> print(trace)
assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 (logprob = -1.14356)
julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,));
julia> issuccess
true
julia> print(trace)
observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) (logprob = -1.41894)
Incorrect model
julia> @model function demo_incorrect()
           # (×) Sampling `x` twice will lead to incorrect log-probabilities!
           x ~ Normal()
           x ~ Exponential()
       end
demo_incorrect (generic function with 2 methods)
julia> issuccess, trace = check_model_and_trace(rng, demo_incorrect(); error_on_failure=true);
ERROR: varname x used multiple times in model
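The repeated-varname check can be pictured as a single pass over the varnames visited during one evaluation. The sketch below is hypothetical: first_repeated_varname works on plain strings, whereas check_model_and_trace works on a trace of statements.

```julia
# Hypothetical minimal version of the "repeated varname" check: walk the
# varnames visited during one model run and return the first repeat, if any.
function first_repeated_varname(trace::AbstractVector{String})
    seen = Set{String}()
    for vn in trace
        vn in seen && return vn
        push!(seen, vn)
    end
    return nothing
end

# demo_correct visits x once; demo_incorrect visits x twice.
first_repeated_varname(["x"])
first_repeated_varname(["x", "x"])
```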
The following method might be useful to determine certain properties of the model based on the debug trace:
— Function has_static_constraints([rng, ]model::Model; num_evals=5, kwargs...)
Return true if the model has static constraints, false otherwise.
Note that this is a heuristic check based on sampling from the model multiple times and checking if the model is consistent across runs.
Arguments
- rng: The random number generator to use when evaluating the model.
- model::Model: The model to check.
Keyword Arguments
- num_evals: The number of evaluations to perform. Default: 5
- kwargs...: Additional keyword arguments to pass to check_model_and_trace.
For determining whether one might have type instabilities in the model, the following can be useful:
— Function model_warntype(model[, varinfo, context]; optimize=true)
Check the type stability of the model's evaluator, warning about any potential issues.
This simply calls @code_warntype on the model's evaluator, filling in internal arguments where needed.
Arguments
- model: The model to check.
- varinfo::AbstractVarInfo: The varinfo to use when evaluating the model. Default: VarInfo(model)
- context: The context to use when evaluating the model. Default: DefaultContext
Keyword Arguments
- optimize: Whether to generate optimized code. Default: false
— Function model_typed(model[, varinfo, context]; optimize=true)
Return the type inference for the model's evaluator.
This simply calls @code_typed on the model's evaluator, filling in internal arguments where needed.
Arguments
- model: The model to check.
- varinfo::AbstractVarInfo: The varinfo to use when evaluating the model. Default: VarInfo(model)
- context: The context to use when evaluating the model. Default: DefaultContext
Keyword Arguments
- optimize: Whether to generate optimized code. Default: true
Internally, the type-checking methods make use of the following method for construction of the call with the argument types:
— Function gen_evaluator_call_with_types(model[, varinfo, context])
Generate the evaluator call and the types of the arguments.
Arguments
- model: The model whose evaluator is of interest.
- varinfo::AbstractVarInfo: The varinfo to use when evaluating the model. Default: VarInfo(model)
- context: The context to use when evaluating the model. Default: DefaultContext
Returns
A 2-tuple with the following elements:
- This is either model.f or …, depending on whether the model has keyword arguments.
- argtypes::Type{<:Tuple}: The types of the arguments for the evaluator.
Variable names
Names and possibly nested indices of variables are described with AbstractPPL.VarName. They can be defined with AbstractPPL.@varname. Please see the documentation of AbstractPPL.jl for further information.
Data Structures of Variables
DynamicPPL provides different data structures used in for storing samples and accumulation of the log-probabilities, all of which are subtypes of AbstractVarInfo
— TypeAbstractVarInfo
Abstract supertype for data structures that capture random variables when executing a probabilistic model and accumulate log densities such as the log likelihood or the log joint probability of the model.
See also: VarInfo
, SimpleVarInfo
But exactly how a AbstractVarInfo
stores this information can vary.
For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see the section on varinfo design for more on this), we have the following two methods:
Function: untyped_varinfo(model[, context, metadata])
Return an untyped varinfo object for the given model and context.
Arguments
model::Model: The model for which to create the varinfo object.
context::AbstractContext: The context in which to evaluate the model. Default: SamplingContext()
metadata: The metadata to use for the varinfo object. Default: Metadata()
Function: typed_varinfo(model[, context, metadata])
Return a typed varinfo object for the given model, sampler and context.
This simply calls DynamicPPL.untyped_varinfo and converts the resulting varinfo object to a typed varinfo object.
See also: DynamicPPL.untyped_varinfo
Type: struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo
A light wrapper over one or more instances of Metadata. Let vi be an instance of VarInfo. If vi isa VarInfo{<:Metadata}, then only one Metadata instance is used for all the symbols. VarInfo{<:Metadata} is aliased as UntypedVarInfo. If vi isa VarInfo{<:NamedTuple}, then vi.metadata is a NamedTuple that maps each symbol used on the LHS of ~ in the model to its Metadata instance. The latter allows for the type specialization of vi after the first sampling iteration when all the symbols have been observed. VarInfo{<:NamedTuple} is aliased as TypedVarInfo.
Note: It is the user's responsibility to ensure that each "symbol" is visited at least once whenever the model is called, regardless of any stochastic branching. Each symbol refers to a Julia variable and can be a hierarchical array of many random variables, e.g. x[1] ~ ... and x[2] ~ ... both have the same symbol x.
Type: TypedVarInfo(vi::UntypedVarInfo)
This function finds all the unique symbols sym from the instances of VarName{sym} found in vi.metadata.vns. It then extracts the metadata associated with each symbol from the global vi.metadata field. Finally, a new VarInfo is created whose metadata is a NamedTuple mapping from symbols to type-stable Metadata instances, one for each symbol.
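As a rough, hypothetical analogy for what TypedVarInfo(vi) does, the sketch below groups a flat, Any-typed list of (symbol, value) entries by symbol into a NamedTuple of concretely typed vectors. It uses only Base Julia; real Metadata stores much more per variable than a value, so this illustrates the grouping step only.

```julia
# Hypothetical flat "untyped" storage: one Any-typed vector of
# (symbol, value) entries. x[1] and x[2] both carry the symbol :x.
entries = Any[(:s, 1.5), (:x, 0.1), (:x, 0.2)]

# The unique symbols, in order of first appearance.
syms = unique(map(first, entries))

# One concretely typed vector per symbol (the analogue of one Metadata each).
groups = map(syms) do sym
    [v for (k, v) in entries if k === sym]
end

# A NamedTuple keyed by symbol, like vi.metadata in a TypedVarInfo.
typed = NamedTuple{Tuple(syms)}(Tuple(groups))
```

Because each group is collected separately, its element type is narrowed to Float64 even though the source vector is Vector{Any}; this is the kind of specialization the typed representation is meant to enable.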
One main characteristic of VarInfo is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the main Turing documentation. The Transformations section below describes the methods used for this. In the specific case of VarInfo, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.
Function: set_flag!(vi::VarInfo, vn::VarName, flag::String)
Set vn's value for flag to true in vi.
Function: unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false)
Set vn's value for flag to false in vi.
Setting some flags for some VarInfo types is not possible, and by default attempting to do so will error. If ignorable is set to true, the attempt is silently ignored instead.
Function: is_flagged(vi::VarInfo, vn::VarName, flag::String)
Check whether vn has a true value for flag in vi.
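The flag functions above behave like a table of booleans indexed by flag name and variable. The following sketch models that with a hypothetical ToyFlags type holding one BitVector per flag; the flag names "trans" and "del" are used for illustration, and the toy_* functions only mirror the documented behaviour. They are not DynamicPPL's methods.

```julia
# Hypothetical toy store: for each flag name, one Bool per variable position.
struct ToyFlags
    position::Dict{String,Int}        # variable name => position
    flags::Dict{String,BitVector}     # flag name => one bit per variable
end

function ToyFlags(varnames::Vector{String}, flagnames::Vector{String})
    position = Dict(vn => i for (i, vn) in enumerate(varnames))
    flags = Dict(f => falses(length(varnames)) for f in flagnames)
    return ToyFlags(position, flags)
end

# Toy counterparts of set_flag!, unset_flag! and is_flagged.
toy_set_flag!(s::ToyFlags, vn, flag) = (s.flags[flag][s.position[vn]] = true; s)
toy_unset_flag!(s::ToyFlags, vn, flag) = (s.flags[flag][s.position[vn]] = false; s)
toy_is_flagged(s::ToyFlags, vn, flag) = s.flags[flag][s.position[vn]]

store = ToyFlags(["m", "x[1]"], ["trans", "del"])
toy_set_flag!(store, "m", "trans")
toy_is_flagged(store, "m", "trans")      # true
toy_is_flagged(store, "x[1]", "trans")   # false
```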
The following functions were used for sequential Monte Carlo methods.
Function: get_num_produce(vi::VarInfo)
Return the num_produce of vi.
Function: set_num_produce!(vi::VarInfo, n::Int)
Set the num_produce field of vi to n.
Function: increment_num_produce!(vi::VarInfo)
Add 1 to num_produce in vi.
Function: reset_num_produce!(vi::VarInfo)
Reset the value of num_produce in vi to 0.
Function: setorder!(vi::VarInfo, vn::VarName, index::Int)
Set the order of vn in vi to index, where order is the number of observe statements run before sampling.
Function: set_retained_vns_del!(vi::VarInfo)
Set the "del" flag of variables in vi with order > vi.num_produce[] to true.
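To make this bookkeeping concrete, here is a hypothetical plain-Julia trace: each sampled variable records how many observations preceded it (its order), each observation increments the counter, and variables whose order exceeds the current num_produce receive the "del" flag, as set_retained_vns_del! is documented to do. The ToyTrace type and the toy_* functions are illustrative only.

```julia
# Hypothetical toy trace; DynamicPPL keeps this state inside VarInfo/Metadata.
mutable struct ToyTrace
    num_produce::Int
    order::Dict{String,Int}    # variable => observations seen before it was sampled
    del::Dict{String,Bool}     # variable => marked for resampling?
end
ToyTrace() = ToyTrace(0, Dict{String,Int}(), Dict{String,Bool}())

# Sampling records the current num_produce as the variable's order (cf. setorder!).
function toy_sample!(t::ToyTrace, vn::String)
    t.order[vn] = t.num_produce
    t.del[vn] = false
    return t
end

# Each observation increments the counter (cf. increment_num_produce!).
toy_observe!(t::ToyTrace) = (t.num_produce += 1; t)

# Cf. set_retained_vns_del!: flag variables with order > num_produce.
function toy_set_retained_del!(t::ToyTrace)
    for (vn, o) in t.order
        if o > t.num_produce
            t.del[vn] = true
        end
    end
    return t
end

t = ToyTrace()
toy_sample!(t, "a"); toy_observe!(t)   # "a" has order 0
toy_sample!(t, "b"); toy_observe!(t)   # "b" has order 1
toy_sample!(t, "c")                    # "c" has order 2
t.num_produce = 1    # e.g. a particle rewound to just after its first observation
toy_set_retained_del!(t)
```

Only "c" was sampled after the point the trace was rewound to, so it is the only variable flagged for resampling.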
Function: empty!(meta::Metadata)
Empty the fields of meta.
This is useful when using a sampling algorithm that assumes an empty meta, e.g. SMC.
Type: struct SimpleVarInfo{NT, T, C<:DynamicPPL.AbstractTransformation} <: AbstractVarInfo
A simple wrapper of the parameters with a logp field for accumulation of the log density.
Currently only implemented for NT<:NamedTuple and NT<:AbstractDict.
Fields
A field of type NT holds the underlying representation of the realization.
logp: holds the accumulated log-probability.
transformation: represents whether it assumes variables to be transformed.
The major differences between this and TypedVarInfo are:
SimpleVarInfo does not require linearization.
SimpleVarInfo can use more efficient bijectors.
SimpleVarInfo is only type-stable if NT<:NamedTuple and either a) no indexing is used in tilde-statements, or b) the values have been specified with the correct shapes.
General usage
julia> using StableRNGs
julia> @model function demo()
           m ~ Normal()
           x = Vector{Float64}(undef, 2)
           for i in eachindex(x)
               x[i] ~ Normal()
           end
           return x
       end
demo (generic function with 2 methods)
julia> m = demo();
julia> rng = StableRNG(42);
julia> ### Sampling ###
ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext());
julia> # In the `NamedTuple` version we need to provide the place-holder values for
# the variables which are using "containers", e.g. `Array`.
# In this case, this means that we need to specify `x` but not `m`.
_, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo((x = ones(2), )), ctx);
julia> # (✓) Vroom, vroom! FAST!!!
julia> # We can also access arbitrary varnames pointing to `x`, e.g.
       vi[@varname(x)]
2-element Vector{Float64}:
julia> vi[@varname(x[1:2])]
2-element Vector{Float64}:
julia> # (×) If we don't provide the container...
_, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); vi
ERROR: type NamedTuple has no field x
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
_, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx);
julia> # (✓) Sort of fast, but only possible at runtime.
julia> # In addition, we can only access varnames as they appear in the model!
       vi[@varname(x)]
ERROR: KeyError: key x not found
julia> vi[@varname(x[1:2])]
ERROR: KeyError: key x[1:2] not found
Technically, it's possible to use any implementation of AbstractDict in place of OrderedDict, but OrderedDict ensures that certain operations, e.g. linearization/flattening of the values in the varinfo, are consistent between evaluations. Hence OrderedDict is the preferred implementation of AbstractDict to use here.
You can also sample in transformed space:
julia> @model demo_constrained() = x ~ Exponential()
demo_constrained (generic function with 2 methods)
julia> m = demo_constrained();
julia> _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx);
julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
julia> _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx);
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10];
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
julia> # And with `OrderedDict` of course!
_, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx);
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx))[@varname(x)] for i = 1:10];
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
Evaluation in transformed space of course also works:
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
Transformed SimpleVarInfo((x = -1.0,), 0.0)
julia> # (✓) Positive probability mass on negative numbers!
getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
julia> # While if we forget to indicate that it's transformed:
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
SimpleVarInfo((x = -1.0,), 0.0)
julia> # (✓) No probability mass on negative numbers!
getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
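The difference between the last two results comes from the change of variables. Assuming the link for Exponential() is the logarithm (the usual bijector for positive support), the log density of y = log(x) is log p(exp(y)) + y, where the extra y is the log-absolute-Jacobian of exp. A plain-Julia sketch of that arithmetic, with the Exponential(1) log density written out by hand rather than taken from Distributions.jl:

```julia
# Exponential(1) log density, -Inf outside its support x >= 0.
logpdf_exp(x) = x < 0 ? -Inf : -x

# Density of y = log(x), assuming the log link:
# log p_X(exp(y)) + log|d exp(y)/dy| = log p_X(exp(y)) + y.
logpdf_linked(y) = logpdf_exp(exp(y)) + y

lp_transformed = logpdf_linked(-1.0)   # -1.0 read as an unconstrained value, ≈ -1.368
lp_constrained = logpdf_exp(-1.0)      # -1.0 read as a constrained value, -Inf
```

Under this assumption, -1.0 gets a finite log density when it is read as an unconstrained value and none when it is read as a constrained one, which is consistent with the comments in the example above.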
Using NamedTuple as underlying storage.
julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), ));
julia> svi_nt[@varname(m)]
(a = [1.0],)
julia> svi_nt[@varname(m.a)]
1-element Vector{Float64}:
 1.0
julia> svi_nt[@varname(m.a[1])]
1.0
julia> svi_nt[@varname(m.a[2])]
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
julia> svi_nt[@varname(m.b)]
ERROR: type NamedTuple has no field b
Using OrderedDict as underlying storage.
julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], )));
julia> svi_dict[@varname(m)]
(a = [1.0],)
julia> svi_dict[@varname(m.a)]
1-element Vector{Float64}:
 1.0
julia> svi_dict[@varname(m.a[1])]
1.0
julia> svi_dict[@varname(m.a[2])]
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
julia> svi_dict[@varname(m.b)]
ERROR: type NamedTuple has no field b
Common API
Accumulation of log-probabilities
Function: getlogp(vi::AbstractVarInfo)
Return the log of the joint probability of the observed data and parameters sampled in vi.
Function: setlogp!!(vi::AbstractVarInfo, logp)
Set the log of the joint probability of the observed data and parameters sampled in vi to logp, mutating if it makes sense.
Function: acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp)
Add logp to the value of the log of the joint probability of the observed data and parameters sampled in vi, mutating if it makes sense.
Function: resetlogp!!(vi::AbstractVarInfo)
Reset the value of the log of the joint probability of the observed data and parameters sampled in vi to 0, mutating if it makes sense.
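The "!!" suffix on these functions means they may mutate their argument but are not guaranteed to, so callers must always use the returned object. A minimal sketch of that convention with two hypothetical toy varinfo types, one immutable and one mutable; the toy_* functions are illustrations, not DynamicPPL's implementations.

```julia
# Hypothetical toy varinfos holding only an accumulated log density.
struct ImmutableToyVI
    logp::Float64
end
mutable struct MutableToyVI
    logp::Float64
end

toy_getlogp(vi) = vi.logp
# Immutable: cannot mutate, so build and return a new object.
toy_setlogp!!(vi::ImmutableToyVI, logp) = ImmutableToyVI(logp)
# Mutable: update in place and return the same object.
toy_setlogp!!(vi::MutableToyVI, logp) = (vi.logp = logp; vi)
toy_acclogp!!(vi, logp) = toy_setlogp!!(vi, toy_getlogp(vi) + logp)
toy_resetlogp!!(vi) = toy_setlogp!!(vi, 0.0)

a = ImmutableToyVI(0.0)
a2 = toy_acclogp!!(a, -1.5)   # `a` is untouched; `a2` holds the new value
m = MutableToyVI(0.0)
m2 = toy_acclogp!!(m, -1.5)   # `m2 === m`, updated in place
```

Code written as vi = acclogp!!(vi, logp) works for both kinds of storage, which is why the pattern of rebinding the result is used throughout.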
Variables and their realizations
Function: keys(vi::AbstractVarInfo)
Return an iterator over all vns in vi.
Function: getindex(vi::AbstractVarInfo, vn::VarName[, dist::Distribution])
getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution])
Return the current value(s) of vn (vns) in vi in the support of its (their) distribution(s).
If dist is specified, the value(s) will be massaged into the representation expected by dist.
Function: push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
Push a new random variable vn with a sampled value r from a distribution dist to the VarInfo, mutating if it makes sense.
Function: empty!!(vi::AbstractVarInfo)
Empty the fields of vi.metadata and reset vi.logp[] and vi.num_produce[] to zeros.
This is useful when using a sampling algorithm that assumes an empty vi, e.g. SMC.
Function: isempty(vi::AbstractVarInfo)
Return true if vi is empty and false otherwise.
Function: getindex_internal(vi::AbstractVarInfo, vn::VarName)
getindex_internal(vi::AbstractVarInfo, vns::Vector{<:VarName})
Return the current value(s) of vn (vns) in vi as represented internally in vi.
See also: getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)
Function: setindex_internal!(vnv::VarNamedVector, val, i::Int)
Set the ith element of the internal storage vector, ignoring inactive entries.
setindex_internal!(vnv::VarNamedVector, val, vn::VarName[, transform])
Like setindex!, but sets the values as they are stored internally in vnv.
Optionally, the transformation can also be set, such that transform(val) is the original value of the variable. By default, the transform is the identity if creating a new entry in vnv, or the existing transform if updating an existing entry.
Function: update_internal!(vnv::VarNamedVector, vn::VarName, val::AbstractVector[, transform])
Update an existing entry for vn in vnv with the value val.
Like setindex_internal!, but errors if the key vn doesn't exist.
transform should be a function that converts val to the original representation. By default it's the same as the old transform for vn.
Function: insert_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName[, transform])
Add a variable with given value to vnv.
Like setindex_internal!, but errors if the key vn already exists.
transform should be a function that converts val to the original representation. By default it's identity.
Function: length_internal(vnv::VarNamedVector)
Return the length of the internal storage vector of vnv, ignoring inactive entries.
Function: reset!(vnv::VarNamedVector, val, vn::VarName)
Reset the value of vn in vnv to val.
This differs from setindex! in that it will always change the transform of the variable to be the default vectorisation transform. This undoes any possible linking.
julia> using DynamicPPL: VarNamedVector, @varname, reset!
julia> vnv = VarNamedVector();
julia> vnv[@varname(x)] = reshape(1:9, (3, 3));
julia> setindex!(vnv, 2.0, @varname(x))
ERROR: An error occurred while assigning the value 2.0 to variable x. If you are changing the type or size of a variable you'll need to call reset!
julia> reset!(vnv, 2.0, @varname(x));
julia> vnv[@varname(x)]
2.0
Function: update!(vnv::VarNamedVector, val, vn::VarName)
Update the value of vn in vnv to val.
Like setindex!, but errors if the key vn doesn't exist.
Function: insert!(vnv::VarNamedVector, val, vn::VarName)
Add a variable with given value to vnv.
Like setindex!, but errors if the key vn already exists.
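The contracts of setindex!, update! and insert! differ only in how existing keys are treated. A hypothetical sketch of the same contract on a plain Dict (toy_update! and toy_insert! are illustrative helpers, not DynamicPPL functions), using the same (container, value, key) argument order as above:

```julia
# update!-like: the key must already exist.
function toy_update!(d::AbstractDict, val, key)
    haskey(d, key) || throw(KeyError(key))
    d[key] = val
    return d
end

# insert!-like: the key must not exist yet.
function toy_insert!(d::AbstractDict, val, key)
    haskey(d, key) && throw(ArgumentError("key $key already exists"))
    d[key] = val
    return d
end

d = Dict{String,Any}()
toy_insert!(d, [1.0, 2.0], "x")   # new key: allowed
toy_update!(d, [3.0, 4.0], "x")   # existing key: allowed
d["y"] = 5.0                      # setindex!: allowed in either case
```

The stricter variants are useful when overwriting or creating an entry by accident would indicate a bug in the calling code.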
Function: loosen_types!!(vnv::VarNamedVector{K,V,TVN,TVal,TTrans}, ::Type{KNew}, ::Type{TransNew})
Loosen the types of vnv to allow varname type KNew and transformation type TransNew.
If KNew is a subtype of K and TransNew is a subtype of the element type of TTrans, then this is a no-op and vnv is returned as is. Otherwise a new VarNamedVector is returned with the same data but more abstract types, so that variables of type KNew and transformations of type TransNew can be pushed to it. Some of the underlying storage is shared between vnv and the return value, and thus mutating one may affect the other.
See also
julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal!
julia> vnv = VarNamedVector(@varname(x) => [1.0]);
julia> y_trans(x) = reshape(x, (2, 2));
julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans)
ERROR: MethodError: Cannot `convert` an object of type
julia> vnv_loose = DynamicPPL.loosen_types!!(vnv, typeof(@varname(y)), typeof(y_trans));
julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans)
julia> vnv_loose[@varname(y)]
2×2 Matrix{Float64}:
1.0 3.0
2.0 4.0
Function: tighten_types(vnv::VarNamedVector)
Return a copy of vnv with the most concrete types possible.
For instance, if the vector of transforms in vnv has eltype Any, but all the transforms are actually identity transformations, this function will return a new VarNamedVector whose transforms vector has eltype typeof(identity).
This is a lot like the reverse of loosen_types!!, but with two notable differences: unlike loosen_types!!, this function does not mutate vnv; and it changes not only the key and transform eltypes, but also the values eltype.
See also
julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal!
julia> vnv = VarNamedVector();
julia> setindex!(vnv, [23], @varname(x))
julia> eltype(vnv)
julia> vnv.transforms
1-element Vector{Any}:
identity (generic function with 1 method)
julia> vnv_tight = DynamicPPL.tighten_types(vnv);
julia> eltype(vnv_tight) == Int
julia> vnv_tight.transforms
1-element Vector{typeof(identity)}:
identity (generic function with 1 method)
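The same loosening and tightening of element types can be seen on plain Julia vectors. This sketch does not use VarNamedVector; it only shows the underlying mechanism that the two functions above are described as applying to the keys, values and transforms of a VarNamedVector.

```julia
# Tightening: re-collecting an Any-typed vector with map(identity, ...)
# narrows the eltype to the concrete type of its actual contents.
transforms = Any[identity, identity]
tight = map(identity, transforms)        # Vector{typeof(identity)}

# Loosening: a Vector{Float64} cannot hold other types, but a copy
# converted to Vector{Any} can.
vals = [1.0, 2.0]
loose = convert(Vector{Any}, vals)       # new array; `vals` is unchanged
push!(loose, "not a number")             # would throw a MethodError on `vals`
```

Concrete eltypes let the compiler specialize code on the stored values, which is why tightening is worthwhile once no further entries of new types are expected.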
Function: values_as(varinfo[, Type])
Return the values/realizations in varinfo as Type, if implemented.
If no Type is provided, return values as stored in varinfo.
SimpleVarInfo with NamedTuple:
julia> data = (x = 1.0, m = [2.0]);
julia> values_as(SimpleVarInfo(data))
(x = 1.0, m = [2.0])
julia> values_as(SimpleVarInfo(data), NamedTuple)
(x = 1.0, m = [2.0])
julia> values_as(SimpleVarInfo(data), OrderedDict)
OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries:
x => 1.0
m => [2.0]
julia> values_as(SimpleVarInfo(data), Vector)
2-element Vector{Float64}:
 1.0
 2.0
SimpleVarInfo with OrderedDict:
julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]);
julia> values_as(SimpleVarInfo(data))
OrderedDict{Any, Any} with 2 entries:
x => 1.0
m => [2.0]
julia> values_as(SimpleVarInfo(data), NamedTuple)
(x = 1.0, m = [2.0])
julia> values_as(SimpleVarInfo(data), OrderedDict)
OrderedDict{Any, Any} with 2 entries:
x => 1.0
m => [2.0]
julia> values_as(SimpleVarInfo(data), Vector)
2-element Vector{Float64}:
 1.0
 2.0
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe());
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
julia> # For the sake of brevity, let's just check the type.
md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector}
julia> values_as(vi, NamedTuple)
(s = 1.0, m = 2.0)
julia> values_as(vi, OrderedDict)
OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
s => 1.0
m => 2.0
julia> values_as(vi, Vector)
2-element Vector{Float64}:
 1.0
 2.0
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi);
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
julia> # For the sake of brevity, let's just check the type.
values_as(vi) isa Union{DynamicPPL.Metadata, Vector}
julia> values_as(vi, NamedTuple)
(s = 1.0, m = 2.0)
julia> values_as(vi, OrderedDict)
OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
s => 1.0
m => 2.0
julia> values_as(vi, Vector)
2-element Vector{Real}:
 1.0
 2.0
Type: abstract type AbstractTransformation
Represents a transformation to be used in link!! and invlink!!, amongst others.
A concrete implementation of this should implement the following methods:
link!!: transforms the AbstractVarInfo to the unconstrained space.
invlink!!: transforms the AbstractVarInfo to the constrained space.
And potentially:
maybe_invlink_before_eval!!: hook to decide whether to transform before evaluating the model.
See also: link!!, invlink!!, maybe_invlink_before_eval!!
Type: struct NoTransformation <: DynamicPPL.AbstractTransformation
Transformation which applies the identity function.
Type: struct DynamicTransformation <: DynamicPPL.AbstractTransformation
Transformation which transforms the variables on a per-need basis in the execution of a given Model.
This is in contrast to StaticTransformation, which transforms all variables before the execution of a given Model.
See also: StaticTransformation
Type: struct StaticTransformation{F} <: DynamicPPL.AbstractTransformation
Transformation which transforms all variables before the execution of a given Model.
This is done through the maybe_invlink_before_eval!! function.
See also: DynamicTransformation, maybe_invlink_before_eval!!
Fields
A field of type F holds the function, assumed to implement the Bijectors interface, to be applied to the variables.
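A hypothetical plain-Julia mirror of this design: the transformation is an object, and dispatch on its type decides what happens to the values. ToyTransformation and toy_link are illustrative names, not DynamicPPL's API, and only the static strategy (apply one function to every value up front) is modelled.

```julia
# Hypothetical mirror of the transformation types (not DynamicPPL's code).
abstract type ToyTransformation end

# Identity: values are left untouched.
struct ToyNoTransformation <: ToyTransformation end

# Static: a single function is applied to every value before use.
struct ToyStaticTransformation{F} <: ToyTransformation
    f::F
end

toy_link(::ToyNoTransformation, xs) = xs
toy_link(t::ToyStaticTransformation, xs) = map(t.f, xs)

xs = [1.0, exp(2.0)]
toy_link(ToyNoTransformation(), xs)          # returns xs itself
toy_link(ToyStaticTransformation(log), xs)   # ≈ [0.0, 2.0]
```

Storing the function as a type parameter F keeps the call to t.f type-stable, so each concrete transformation gets specialized code.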
Function: istrans(vnv::VarNamedVector, vn::VarName)
Return a boolean for whether vn is guaranteed to have been transformed so that its domain is all of Euclidean space.
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:VarName}}])
Return true if vi is working in unconstrained space, and false if vi is assuming realizations to be in the support of the corresponding distributions.
If vns is provided, then only check if this/these varname(s) are transformed.
Not all implementations of AbstractVarInfo support transforming only a subset of the variables.
Function: settrans!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName])
Return vi with istrans(vi, vn) evaluating to trans.
If vn is not specified, then istrans(vi) evaluates to trans for all variables.
Function: transformation(vi::AbstractVarInfo)
Return the AbstractTransformation related to vi.
Function: link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
Transform variables in vi to their linked space without mutating vi.
Either transform all variables, or only ones specified in vns.
Use the transformation t, or default_transformation(model, vi) if one is not provided.
See also: default_transformation, invlink
Function: invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
Transform variables in vi to their constrained space without mutating vi.
Either transform all variables, or only ones specified in vns.
Use the (inverse of) transformation t, or default_transformation(model, vi) if one is not provided.
See also: default_transformation, link
Function: link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
Transform variables in vi to their linked space, mutating vi if possible.
Either transform all variables, or only ones specified in vns.
Use the transformation t, or default_transformation(model, vi) if one is not provided.
See also: default_transformation, invlink!!
Function: invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model)
Transform variables in vi to their constrained space, mutating vi if possible.
Either transform all variables, or only ones specified in vns.
Use the (inverse of) transformation t, or default_transformation(model, vi) if one is not provided.
See also: default_transformation, link!!
Function: default_transformation(model::Model[, vi::AbstractVarInfo])
Return the AbstractTransformation currently related to model and, potentially, vi.
Function: link_transform(dist)
Return the constrained-to-unconstrained bijector for distribution dist.
By default, this is just Bijectors.bijector(dist).
Note that currently this is not used by Bijectors.logpdf_with_trans, hence that needs to be overloaded separately if the intention is to change behavior of an existing distribution.
Function: invlink_transform(dist)
Return the unconstrained-to-constrained bijector for distribution dist.
By default, this is just inverse(link_transform(dist)).
Note that currently this is not used by Bijectors.logpdf_with_trans, hence that needs to be overloaded separately if the intention is to change behavior of an existing distribution.
Function: maybe_invlink_before_eval!!([t::Transformation,] vi, model)
Return a possibly invlinked version of vi.
This will be called prior to model evaluation, allowing one to perform a single invlink!! before evaluation rather than lazily evaluating the transforms on an as-needed basis, as is done with DynamicTransformation.
See also: StaticTransformation, DynamicTransformation
julia> using DynamicPPL, Distributions, Bijectors
julia> @model demo() = x ~ Normal()
demo (generic function with 2 methods)
julia> # By subtyping `Transform`, we inherit the `(inv)link!!`.
struct MyBijector <: Bijectors.Transform end
julia> # Define some dummy `inverse` which will be used in the `link!!` call.
Bijectors.inverse(f::MyBijector) = identity
julia> # We need to define `with_logabsdet_jacobian` for `MyBijector`
       # (`identity` already has `with_logabsdet_jacobian` defined)
       function Bijectors.with_logabsdet_jacobian(::MyBijector, x)
           # Just using a large number for the logabsdet-jacobian term
           # for demonstration purposes.
           return (x, 1000)
       end
julia> # Change the `default_transformation` for our model to be a
       # `StaticTransformation` using `MyBijector`.
       function DynamicPPL.default_transformation(::Model{typeof(demo)})
           return DynamicPPL.StaticTransformation(MyBijector())
       end
julia> model = demo();
julia> vi = SimpleVarInfo(x=1.0)
SimpleVarInfo((x = 1.0,), 0.0)
julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity`
vi_linked = link!!(vi, model)
Transformed SimpleVarInfo((x = 1.0,), 0.0)
julia> # Now performs a single `invlink!!` before model evaluation.
logjoint(model, vi_linked)
Method: merge(varinfo, other_varinfos...)
Merge varinfos into one, giving precedence to the right-most varinfo when sensible.
This is particularly useful when combined with subset(varinfo, vns).
See the docstring of subset(varinfo, vns) for examples.
Function: subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName})
Subset a varinfo to only contain the variables vns.
The ordering of the variables in the resulting varinfo is not guaranteed to follow the ordering of the variables in varinfo. Hence care must be taken, in particular when used in conjunction with other methods which use the vector-representation of the varinfo, e.g. getindex(varinfo, sampler).
julia> @model function demo()
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           x = Vector{Float64}(undef, 2)
           x[1] ~ Normal(m, sqrt(s))
           x[2] ~ Normal(m, sqrt(s))
       end
demo (generic function with 2 methods)
julia> model = demo();
julia> varinfo = VarInfo(model);
julia> keys(varinfo)
4-element Vector{VarName}:
julia> for (i, vn) in enumerate(keys(varinfo))
           varinfo[vn] = i
       end
julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
julia> # Extract one with only `m`.
varinfo_subset1 = subset(varinfo, [@varname(m),]);
julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, typeof(identity)}}:
julia> varinfo_subset1[@varname(m)]
julia> # Extract one with both `s` and `x[2]`.
varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);
julia> keys(varinfo_subset2)
2-element Vector{VarName}:
julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
2-element Vector{Float64}:
subset is particularly useful when combined with merge(varinfo::AbstractVarInfo):
julia> # Merge the two.
varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);
julia> keys(varinfo_subset_merged)
3-element Vector{VarName}:
julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
3-element Vector{Float64}:
julia> # Merge the two with the original.
varinfo_merged = merge(varinfo, varinfo_subset_merged);
julia> keys(varinfo_merged)
4-element Vector{VarName}:
julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
This function is only type-stable when vns contains only varnames with the same symbol. For example, [@varname(m[1]), @varname(m[2])] will be type-stable, but [@varname(m[1]), @varname(x)] will not be.
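The combination of subset and merge behaves much like filtering and merging dictionaries. As a rough analogy only, with varnames written as plain strings and toy_subset a hypothetical helper, Base.merge shows the same right-most precedence that merge(varinfo, other_varinfos...) is documented to give:

```julia
vals = Dict("s" => 1.0, "m" => 2.0, "x[1]" => 3.0, "x[2]" => 4.0)

# Keep only the listed keys (analogue of subset(varinfo, vns)).
toy_subset(d, vns) = filter(p -> p.first in vns, d)

sub1 = toy_subset(vals, ["m"])
sub2 = toy_subset(vals, ["s", "x[2]"])
merged = merge(sub1, sub2)                # keys: "m", "s", "x[2]"

# On shared keys, the right-most Dict wins.
overridden = merge(vals, Dict("m" => 20.0))
```

Unlike this Dict analogy, a varinfo also carries per-variable metadata such as transformation state, which is why the docstring says precedence applies "when sensible".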
Function: unflatten(vi::AbstractVarInfo, x::AbstractVector)
Return a new instance of vi with the values of x assigned to the variables.
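A hypothetical plain-Julia sketch of flattening and unflattening, using a NamedTuple as the shape template: toy_flatten concatenates all values into one vector, and toy_unflatten reads consecutive chunks back into the template's shapes. Neither function is part of DynamicPPL; the sketch only illustrates the kind of round trip unflatten(vi, x) is documented to perform.

```julia
# Concatenate scalars and (vectorized) arrays into one flat vector.
function toy_flatten(nt::NamedTuple)
    return mapreduce(v -> v isa AbstractArray ? vec(v) : [v], vcat, values(nt))
end

# Read consecutive chunks of x back into the shapes of `template`.
function toy_unflatten(template::NamedTuple, x::AbstractVector)
    offset = 0
    vals = Any[]
    for v in values(template)
        n = length(v)                      # scalars have length 1
        chunk = x[offset+1:offset+n]
        offset += n
        push!(vals, v isa AbstractArray ? reshape(chunk, size(v)) : only(chunk))
    end
    return NamedTuple{keys(template)}(Tuple(vals))
end

template = (s = 1.0, m = [0.0, 0.0])
new_vals = toy_unflatten(template, [5.0, 6.0, 7.0])   # (s = 5.0, m = [6.0, 7.0])
```

The flat vector is what gradient-based samplers operate on, while the structured form is what the model sees; the template supplies the shapes that the flat vector lacks.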
Function: varname_leaves(vn::VarName, val)
Return an iterator over all varnames that are represented by vn on val.
julia> using DynamicPPL: varname_leaves
julia> foreach(println, varname_leaves(@varname(x), rand(2)))
x[1]
x[2]
julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
x[1:2][1]
x[1:2][2]
julia> x = (y = 1, z = [[2.0], [3.0]]);
julia> foreach(println, varname_leaves(@varname(x), x))
x.y
x.z[1][1]
x.z[2][1]
Function: varname_and_value_leaves(vn::VarName, val)
Return an iterator over all varname-value pairs that are represented by vn on val.
julia> using DynamicPPL: varname_and_value_leaves
julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
(x[1], 1)
(x[2], 2)
julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
(x[1:2][1], 1)
(x[1:2][2], 2)
julia> x = (y = 1, z = [[2.0], [3.0]]);
julia> foreach(println, varname_and_value_leaves(@varname(x), x))
(x.y, 1)
(x.z[1][1], 2.0)
(x.z[2][1], 3.0)
There is also special handling for certain types:
julia> using LinearAlgebra
julia> x = reshape(1:4, 2, 2);
julia> # `LowerTriangular`
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
(x[1, 1], 1)
(x[2, 1], 2)
(x[2, 2], 4)
julia> # `UpperTriangular`
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
(x[1, 1], 1)
(x[1, 2], 3)
(x[2, 2], 4)
julia> # `Cholesky` with lower-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
(x.L[1, 1], 1.0)
(x.L[2, 1], 0.0)
(x.L[2, 2], 1.0)
julia> # `Cholesky` with upper-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
(x.U[1, 1], 1.0)
(x.U[1, 2], 0.0)
(x.U[2, 2], 1.0)
Return an iterator over all varname-value pairs that are represented by container.
This is the same as varname_and_value_leaves(vn::VarName, x) but over a container containing multiple varnames.
See also: varname_and_value_leaves(vn::VarName, x)
julia> using DynamicPPL: varname_and_value_leaves
julia> # With an `OrderedDict`
dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]);
julia> foreach(println, varname_and_value_leaves(dict))
(y, 1)
(z[1][1], 2.0)
(z[2][1], 3.0)
julia> # With a `NamedTuple`
nt = (y = 1, z = [[2.0], [3.0]]);
julia> foreach(println, varname_and_value_leaves(nt))
(y, 1)
(z[1][1], 2.0)
(z[2][1], 3.0)
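The recursion behind these leaf iterators can be sketched in plain Julia. toy_leaves below is a hypothetical helper that walks NamedTuples and arrays and returns (path, value) pairs with string paths instead of VarNames; it does not reproduce the special handling for triangular matrices or Cholesky factors.

```julia
# Hypothetical recursive flattener in the spirit of varname_and_value_leaves.
# Paths are plain strings here, not VarNames.
toy_leaves(path::String, x::Real) = [(path, x)]

function toy_leaves(path::String, x::AbstractArray)
    # One child path per index, e.g. "x[1]" or "A[2, 1]" (column-major order).
    return reduce(vcat, [toy_leaves("$path[$(join(Tuple(I), ", "))]", x[I])
                         for I in CartesianIndices(x)])
end

function toy_leaves(path::String, x::NamedTuple)
    # One child path per field, e.g. "x.y".
    return reduce(vcat, [toy_leaves("$path.$k", x[k]) for k in keys(x)])
end

nt = (y = 1, z = [[2.0], [3.0]])
foreach(println, toy_leaves("x", nt))
```

Dispatching on the value's type keeps each case small; supporting another container type means adding one more method.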
Evaluation Contexts
Internally, both sampling and evaluation of log densities are performed with AbstractPPL.evaluate!!.
Function: evaluate!!(model::Model[, rng, varinfo, sampler, context])
Sample from the model using the sampler with random number generator rng and the context, and store the sample and log joint probability in varinfo.
Returns both the return value of the original model and the resulting varinfo.
The method resets the log joint probability of varinfo and increases the evaluation number of sampler.
The behaviour of a model execution can be changed with evaluation contexts that are passed as an additional argument to the model function. Contexts are subtypes of AbstractPPL.AbstractContext.
Type: SamplingContext(
Create a context that allows you to sample parameters with the sampler when running the model. The context determines how the returned log density is computed when running the model.
See also: DefaultContext, LikelihoodContext, PriorContext
Type: struct DefaultContext <: AbstractContext end
The DefaultContext is used by default to compute the log joint probability of the data and parameters when running the model.
Type: LikelihoodContext <: AbstractContext
A leaf context resulting in the exclusion of prior terms when running the model.
Type: PriorContext <: AbstractContext
A leaf context resulting in the exclusion of likelihood terms when running the model.
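A hypothetical plain-Julia sketch of how the leaf context decides which log-density terms are kept. The Toy* types and toy_logdensity are illustrative only; in DynamicPPL the terms are accumulated during model evaluation rather than combined afterwards.

```julia
abstract type ToyContext end
struct ToyDefault <: ToyContext end      # prior + likelihood
struct ToyPrior <: ToyContext end        # prior terms only
struct ToyLikelihood <: ToyContext end   # likelihood terms only

# logprior: terms from `~` on parameters; loglik: terms from observations.
toy_logdensity(::ToyDefault, logprior, loglik) = logprior + loglik
toy_logdensity(::ToyPrior, logprior, loglik) = logprior
toy_logdensity(::ToyLikelihood, logprior, loglik) = loglik

logprior, loglik = -1.0, -2.5
joint = toy_logdensity(ToyDefault(), logprior, loglik)   # -3.5
```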
Type: struct MiniBatchContext{Tctx, T} <: AbstractContext
The MiniBatchContext enables the computation of log(prior) + s * log(likelihood of a batch) when running the model, where s is the loglike_scalar field, typically equal to the number of data points / batch size. This is useful in batch-based stochastic gradient descent algorithms, where optimizing this quantity amounts to optimizing log(prior) + log(likelihood of all the data points) in expectation.
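The claim that this objective is correct in expectation can be checked with a small worked example. The numbers below are made up: 1000 per-datum log-likelihood terms split into 10 disjoint batches of 100, so s = 1000 / 100 = 10. Averaging the scaled objective over the batches recovers log(prior) + log(likelihood of all the data points).

```julia
# Made-up per-datum log-likelihood terms for N = 1000 data points.
loglik_terms = [-0.001 * i for i in 1:1000]
batch_size = 100
s = length(loglik_terms) / batch_size    # loglike_scalar: N / batch size = 10.0

# Ten disjoint, consecutive batches of 100 terms each.
batches = [loglik_terms[k:k+batch_size-1] for k in 1:batch_size:length(loglik_terms)]

logprior = -1.0
# log(prior) + s * log(likelihood of a batch), once per batch.
minibatch_objective = [logprior + s * sum(b) for b in batches]

# Full-data objective: log(prior) + log(likelihood of all the data points).
full_objective = logprior + sum(loglik_terms)
```

With batches drawn uniformly, the mean of minibatch_objective equals full_objective up to floating-point rounding; each individual batch value can still differ, which is the noise stochastic gradient methods tolerate.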
Type: PrefixContext{Prefix}(context)
Create a context that allows you to use the wrapped context when running the model and adds the Prefix to all parameters.
This context is useful in nested models to ensure that the names of the parameters are unique.
See also: to_submodel
Type: ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}
Model context that contains values that are to be conditioned on. The values can either be a NamedTuple mapping symbols to values, such as (a=1, b=2), or an AbstractDict mapping varnames to values (e.g. Dict(@varname(a) => 1, @varname(b) => 2)). The former is more performant, but the latter must be used when there are varnames that cannot be represented as symbols, e.g. @varname(x[1]).
In DynamicPPL two samplers are defined that are used to initialize unobserved random variables: SampleFromPrior, which samples from the prior distribution, and SampleFromUniform, which samples from a uniform distribution.
Type: SampleFromPrior
Sampling algorithm that samples unobserved random variables from their prior distribution.
Type: SampleFromUniform
Sampling algorithm that samples unobserved random variables from a uniform distribution.
Additionally, a generic sampler for inference is implemented.
Type: Sampler{T}
Generic sampler type for inference algorithms of type T in DynamicPPL.
Sampler should implement the AbstractMCMC interface, and in particular AbstractMCMC.step. A default implementation of the initial sampling step is provided that supports resuming sampling from a previous state and setting initial parameter values. It requires overloading loadstate and initialstep for loading previous states and actually performing the initial sampling step, respectively. Additionally, one might sometimes want to implement initialsampler, which specifies how the initial parameter values are sampled if they are not provided. By default, values are sampled from the prior.
The default implementation of Sampler uses the following unexported functions.
— Function initialstep(rng, model, sampler, varinfo; kwargs...)
Perform the initial sampling step of the sampler for the model.
The varinfo contains the initial samples, which can be provided by the user or sampled randomly.
— Function loadstate(data)
Load sampler state from data.
By default, data is returned.
— Function initialsampler(sampler::Sampler)
Return the sampler that is used for generating the initial parameters when sampling with sampler.
By default, it returns an instance of SampleFromPrior.
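Because these defaults are ordinary Julia methods, a sampler can change them by defining a more specific method. The toy types below (ToySampler, ToyAlgA, ToyAlgB, ToyFromPrior, ToyFromUniform) and the functions toy_loadstate and toy_initialsampler are hypothetical stand-ins. They only mirror the documented defaults.

```julia
# Hypothetical initialization strategies and sampler wrapper (not DynamicPPL types).
struct ToyFromPrior end
struct ToyFromUniform end

struct ToySampler{T}
    alg::T
end

struct ToyAlgA end   # keeps the defaults
struct ToyAlgB end   # prefers uniform initialization

# Generic defaults that mirror the documented behaviour of loadstate and initialsampler.
toy_loadstate(data) = data
toy_initialsampler(::ToySampler) = ToyFromPrior()

# A more specific method for one algorithm overrides the default through dispatch.
toy_initialsampler(::ToySampler{ToyAlgB}) = ToyFromUniform()

println(toy_initialsampler(ToySampler(ToyAlgA())))  # ToyFromPrior()
println(toy_initialsampler(ToySampler(ToyAlgB())))  # ToyFromUniform()
```

The generic method acts as the fallback. Julia picks the most specific applicable method, so only ToyAlgB sees the override.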
Finally, the varinfo type that a Sampler should use for a given Model is specified by DynamicPPL.default_varinfo, which can therefore be overloaded for each model-sampler combination. This can be useful when one knows that a particular varinfo type will be more performant for the given model and sampler.
— Function default_varinfo(rng, model, sampler[, context])
Return a default varinfo object for the given model and sampler.
Arguments
rng: Random number generator.
model::Model: Model for which we want to create a varinfo object.
sampler::AbstractSampler: Sampler which will make use of the varinfo object.
context::AbstractContext: Context in which the model is evaluated.
Returns
Default varinfo object for the given model and sampler.
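Overloading for a model-sampler combination relies on Julia's multiple dispatch. The sketch below is hypothetical and uses no DynamicPPL types. A NamedTuple plays the role of a typed varinfo, with fixed names and concrete types. A Dict{Symbol,Any} plays the role of an untyped varinfo, whose names and types may change between runs.

```julia
using Random

# Hypothetical model and sampler types (not DynamicPPL types).
abstract type ToyModel end
struct StaticModel <: ToyModel end         # same variables in every run
struct RandomSupportModel <: ToyModel end  # variables may differ between runs
struct ToyGenericSampler end

# Generic fallback: a "typed" container, i.e. a NamedTuple with fixed names and concrete types.
toy_default_varinfo(rng::AbstractRNG, ::ToyModel, ::ToyGenericSampler) = (x = randn(rng),)

# Overload for one model-sampler pair: an "untyped" container that can grow and
# hold values of different types from run to run.
toy_default_varinfo(rng::AbstractRNG, ::RandomSupportModel, ::ToyGenericSampler) =
    Dict{Symbol,Any}(:x => rand(rng, Bool))

rng = MersenneTwister(0)
println(typeof(toy_default_varinfo(rng, StaticModel(), ToyGenericSampler())))
println(typeof(toy_default_varinfo(rng, RandomSupportModel(), ToyGenericSampler())))
```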
There is also the experimental DynamicPPL.Experimental.determine_suitable_varinfo, which uses static checking via JET.jl to determine whether one should use DynamicPPL.typed_varinfo or DynamicPPL.untyped_varinfo, depending on which one supports the model:
— Function determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true)
Return a suitable varinfo for the given model.
See also: DynamicPPL.Experimental.is_suitable_varinfo.
For full functionality, this requires JET.jl to be loaded. If JET.jl is not loaded, this function assumes that the model is compatible with typed varinfo.
Arguments
model: The model for which to determine the varinfo.
context: The context to use for the model evaluation. Default: SamplingContext().
Keyword Arguments
only_ddpl: If true, only consider error reports within DynamicPPL.jl.
Examples
julia> using DynamicPPL.Experimental: determine_suitable_varinfo

julia> using JET: JET # needs to be loaded for full functionality

julia> @model function model_with_random_support()
           x ~ Bernoulli()
           if x
               y ~ Normal()
           else
               z ~ Normal()
           end
       end
model_with_random_support (generic function with 2 methods)

julia> model = model_with_random_support();

julia> # Typed varinfo cannot handle this random support model properly,
       # since a single execution of the model will not see all random variables.
       # Hence, this model requires untyped varinfo.
       vi = determine_suitable_varinfo(model);
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo.
└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48

julia> vi isa typeof(DynamicPPL.untyped_varinfo(model))
true

julia> # In contrast, a simple model with no random support can be handled by typed varinfo.
       @model model_with_static_support() = x ~ Normal()
model_with_static_support (generic function with 2 methods)

julia> vi = determine_suitable_varinfo(model_with_static_support());

julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
true
— Function is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)
Check whether the model supports evaluation using the provided context and varinfo.
Loading JET.jl is required before calling this function.
Arguments
model: The model to verify the support for.
context: The context to use for the model evaluation.
varinfo: The varinfo to verify the support for.
Keyword Arguments
A Boolean keyword argument that, if true, restricts the check to error reports occurring in the tilde pipeline. Default: true.
Returns
true if the model supports the varinfo, otherwise false.
The result of report_call from JET.jl.
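A checker with this return shape can drive a "try typed, fall back to untyped" decision, similar to the behaviour shown in the determine_suitable_varinfo example. The sketch is hypothetical. toy_check replaces the JET.jl-based analysis with a trivial name test, and toy_determine_varinfo only returns a Symbol instead of constructing a varinfo.

```julia
# Hypothetical checker returning (ok, report). A real implementation would run a
# static analysis; this one only inspects the model's name.
toy_check(model::Symbol, kind::Symbol) =
    occursin("random", String(model)) ? (false, "toy report for $kind: unstable types") : (true, nothing)

# Try the typed representation first and fall back to the untyped one if the check fails.
function toy_determine_varinfo(check, model)
    ok, report = check(model, :typed)
    ok && return :typed
    @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." report
    return :untyped
end

println(toy_determine_varinfo(toy_check, :model_with_static_support))  # typed
println(toy_determine_varinfo(toy_check, :model_with_random_support))  # untyped (after a warning)
```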
Model-Internal Functions
— Function tilde_assume(context::SamplingContext, right, vn, vi)
Handle assumed (unobserved) variables, e.g., x ~ Normal(), in a context associated with a sampler: accumulate the log probability and return the sampled value.
Falls back to
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
— Function tilde_observe(context::SamplingContext, right, left, vi)
Handle observed constants in a context associated with a sampler.
Falls back to tilde_observe(context.context, context.sampler, right, left, vi).
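The "falls back to" pattern means that the sampling context does no work of its own. It unwraps itself and forwards its random number generator, inner context and sampler to a more specific method. The sketch below is a simplified, hypothetical version of that delegation. The Toy* types and toy_tilde_* functions are not DynamicPPL's. The distribution argument and the varinfo are omitted, and a standard normal is hard-coded in their place.

```julia
using Random

# Hypothetical contexts and sampler (not DynamicPPL types).
struct ToyDefaultContext end
struct ToyPriorSampler end
struct ToySamplingContext{R<:AbstractRNG,S,C}
    rng::R
    sampler::S
    context::C
end

# Innermost implementation: draw from a standard normal "prior" and return the value
# together with its log density, log N(x; 0, 1) = -(x^2 + log(2π)) / 2.
function toy_tilde_assume(rng::AbstractRNG, ::ToyDefaultContext, ::ToyPriorSampler, vn::Symbol)
    x = randn(rng)
    return x, -(x^2 + log(2π)) / 2
end

# The sampling context only unwraps itself and forwards rng, inner context and sampler.
toy_tilde_assume(ctx::ToySamplingContext, vn::Symbol) =
    toy_tilde_assume(ctx.rng, ctx.context, ctx.sampler, vn)

# Observations: the log density of a fixed value; no random number generator is needed.
toy_tilde_observe(::ToyDefaultContext, ::ToyPriorSampler, left::Real) = -(left^2 + log(2π)) / 2
toy_tilde_observe(ctx::ToySamplingContext, left::Real) =
    toy_tilde_observe(ctx.context, ctx.sampler, left)

ctx = ToySamplingContext(MersenneTwister(3), ToyPriorSampler(), ToyDefaultContext())
x, logp = toy_tilde_assume(ctx, :x)     # sampled value and its log density
logp_obs = toy_tilde_observe(ctx, 0.0)  # log density of the observed value 0.0
```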