How PrefixContext and ConditionContext interact

PrefixContext

PrefixContext is a context that, as the name suggests, prefixes all variables inside a model with a given symbol. Thus, for example:

using DynamicPPL, Distributions

@model function f()
    x ~ Normal()
    return y ~ Normal()
end

@model function g()
    return a ~ to_submodel(f())
end
g (generic function with 2 methods)

Inside the submodel f, the variables x and y become a.x and a.y, respectively. This is easiest to observe by running the model:

vi = VarInfo(g())
keys(vi)
2-element Vector{VarName{:a}}:
 a.x
 a.y
Note

In this case, where to_submodel is called without any other arguments, the prefix to be used is automatically inferred from the name of the variable on the left-hand side of the tilde. We will return to the 'manual prefixing' case later.

The phrase 'becoming' a different variable is a little underspecified: it is useful to pinpoint the exact location where the prefixing occurs, which is tilde_assume. The method responsible for it is tilde_assume(::PrefixContext, right, vn, vi): this attaches the prefix in the context to the VarName argument, before recursively calling tilde_assume with the new prefixed VarName. This means that even though a statement x ~ dist still enters the tilde pipeline at the top level as x, if the model evaluation context contains a PrefixContext, any function from tilde_assume onwards will see a.x instead.
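To make the location of the prefixing concrete, here is a toy model of this step in plain Julia. It is a sketch, not DynamicPPL code: ToyContext, ToyDefault, ToyPrefix and toy_tilde_assume are hypothetical stand-ins, and a VarName is modelled as a vector of symbols (so a.x is [:a, :x]). Each statement enters the pipeline unprefixed, and the leaf only ever sees the prefixed name.

```julia
# Hypothetical toy types, not DynamicPPL's: a context is either a leaf (ToyDefault)
# or a ToyPrefix wrapping a child context.
abstract type ToyContext end
struct ToyDefault <: ToyContext end
struct ToyPrefix <: ToyContext
    prefix::Vector{Symbol}
    child::ToyContext
end

# Leaf: this is where the variable would actually be sampled. Here we only
# record the name that reaches this point.
function toy_tilde_assume(::ToyDefault, vn::Vector{Symbol}, seen::Vector{Vector{Symbol}})
    push!(seen, vn)
    return nothing
end

# Prefix: attach this context's prefix to the name, then recurse into the child.
function toy_tilde_assume(ctx::ToyPrefix, vn::Vector{Symbol}, seen::Vector{Vector{Symbol}})
    return toy_tilde_assume(ctx.child, vcat(ctx.prefix, vn), seen)
end

seen = Vector{Symbol}[]
ctx = ToyPrefix([:a], ToyDefault())
# Both statements enter the pipeline unprefixed, as `x` and `y`.
toy_tilde_assume(ctx, [:x], seen)
toy_tilde_assume(ctx, [:y], seen)
[join(vn, ".") for vn in seen]  # ["a.x", "a.y"]
```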

ConditionContext

ConditionContext is a context which stores values of variables that are to be conditioned on. These values may be stored as a Dict which maps VarNames to values, or alternatively as a NamedTuple. The latter only works correctly if all VarNames are 'basic', in that they have an identity optic (i.e., something like a.x or a[1] is forbidden). Because of this limitation, we will only use Dict in this example.

Note

If a ConditionContext with a NamedTuple encounters anything to do with a prefix, its internal NamedTuple is converted to a Dict anyway, so it is quite reasonable to ignore the NamedTuple case in this exposition.

One can inspect the conditioning values with, for example:

@model function d()
    x ~ Normal()
    return y ~ Normal()
end

cond_model = d() | (@varname(x) => 1.0)
cond_ctx = cond_model.context
ConditionContext(Dict(x => 1.0), DefaultContext())

There are several internal functions that are used to determine whether a variable is conditioned, and if so, what its value is.

DynamicPPL.hasconditioned_nested(cond_ctx, @varname(x))
true
DynamicPPL.getconditioned_nested(cond_ctx, @varname(x))
1.0

These functions are in turn used by the function DynamicPPL.contextual_isassumption, which is built on hasconditioned_nested but additionally checks whether the conditioned value is missing (in which case the variable isn't really conditioned).

DynamicPPL.contextual_isassumption(cond_ctx, @varname(x))
false
Note

Notice that (neglecting missing values) the return value of contextual_isassumption is the opposite of hasconditioned_nested, i.e. for a variable that is conditioned on, contextual_isassumption returns false.
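The relationship between these query functions can be sketched with a plain Dict standing in for the ConditionContext. The names toy_hasconditioned, toy_getconditioned and toy_isassumption are hypothetical; the sketch only mirrors the logic described above, including the special treatment of missing.

```julia
# Hypothetical stand-ins for the conditioning queries, using a plain Dict
# from variable names to conditioned values.
toy_hasconditioned(cond::AbstractDict, vn) = haskey(cond, vn)
toy_getconditioned(cond::AbstractDict, vn) = cond[vn]

# A variable is an assumption (a random variable) unless it is conditioned
# on a value that is not `missing`.
function toy_isassumption(cond::AbstractDict, vn)
    if toy_hasconditioned(cond, vn)
        # Conditioned on `missing` means "not really conditioned".
        return ismissing(toy_getconditioned(cond, vn))
    end
    return true
end

cond = Dict{Symbol,Any}(:x => 1.0, :z => missing)
toy_isassumption(cond, :x)  # false: x is conditioned
toy_isassumption(cond, :y)  # true: y was never conditioned
toy_isassumption(cond, :z)  # true: z is conditioned on `missing`
```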

If a variable x is conditioned on, then the effect of this is to set the value of x to the given value (while still including its contribution to the log probability density). Since x is no longer a random variable, if we were to evaluate the model, we would find only one key in the VarInfo:

keys(VarInfo(cond_model))
1-element Vector{VarName{:y, typeof(identity)}}:
 y

Joint behaviour: desiderata at the model level

When paired together, these two contexts have the potential to cause substantial confusion: PrefixContext modifies the variable names that are seen, which may cause them to be out of sync with the values contained inside the ConditionContext.

We begin by mentioning some high-level desiderata for their joint behaviour. Take these models, for example:

# We define a helper function to unwrap a layer of SamplingContext, to
# avoid cluttering the print statements.
unwrap_sampling_context(ctx::DynamicPPL.SamplingContext) = ctx.context
unwrap_sampling_context(ctx::DynamicPPL.AbstractContext) = ctx

@model function inner()
    println("inner context: $(unwrap_sampling_context(__context__))")
    x ~ Normal()
    return y ~ Normal()
end

@model function outer()
    println("outer context: $(unwrap_sampling_context(__context__))")
    return a ~ to_submodel(inner())
end

# 'Outer conditioning'
with_outer_cond = outer() | (@varname(a.x) => 1.0)

# 'Inner conditioning'
inner_cond = inner() | (@varname(x) => 1.0)
@model function outer2()
    println("outer context: $(unwrap_sampling_context(__context__))")
    return a ~ to_submodel(inner_cond)
end
with_inner_cond = outer2()
Model{typeof(Main.outer2), (), (), (), Tuple{}, Tuple{}, DefaultContext}(Main.outer2, NamedTuple(), NamedTuple(), DefaultContext())

We want that:

  1. keys(VarInfo(outer())) should return [a.x, a.y];
  2. keys(VarInfo(with_outer_cond)) should return [a.y];
  3. keys(VarInfo(with_inner_cond)) should return [a.y].

In other words, we can condition submodels either from the outside (point (2)) or from the inside (point (3)), and the variable name we use to specify the conditioning should match the level at which we perform the conditioning.

This is an important point because it means that submodels can be treated as individual, opaque objects: we can condition them without needing to know what they will be prefixed with, or the context in which they are being used. For example, we can reuse inner_cond in another model with a different prefix, and its inner x will still be conditioned, despite the prefix differing.

Info

In the current version of DynamicPPL, these criteria are all fulfilled. However, this was not the case in the past: in particular, point (3) was not fulfilled, and users had to condition the internal submodel with the prefixes that were used outside. (See this GitHub issue for more information; this issue was the direct motivation for this documentation page.)

Desiderata at the context level

The above section describes how we expect conditioning and prefixing to behave from a user's perspective. We now turn to the question of how we implement this in terms of DynamicPPL contexts. We do not specify the implementation details here, but we will sketch out something resembling an API that will allow us to achieve the target behaviour.

Point (1) does not involve any conditioning, only prefixing; it is therefore already satisfied by virtue of the tilde_assume method shown above.

Points (2) and (3) are trickier. As the reader may surmise, the difference between them is the order in which the contexts are stacked.

For the outer conditioning case (point (2)), the ConditionContext will contain a VarName that is already prefixed. When we enter the inner submodel, this ConditionContext has to be passed down and somehow combined with the PrefixContext that is created when we enter the submodel. We make the claim here that the best way to do this is to nest the PrefixContext inside the ConditionContext. This is indeed what happens, as can be demonstrated by running the model.

with_outer_cond();
nothing;
outer context: ConditionContext(Dict(a.x => 1.0), DefaultContext())
inner context: ConditionContext(Dict(a.x => 1.0), PrefixContext{VarName{:a, typeof(identity)}, DefaultContext}(a, DefaultContext()))
Info

The ; nothing at the end is purely to circumvent a Documenter.jl quirk where stdout is only shown if the return value of the final statement is nothing. If these documentation pages are moved to Quarto, it will be possible to remove this.

For the inner conditioning case (point (3)), the outer model is not run with any special context. The inner model will itself contain a ConditionContext, which contains a VarName that is not prefixed. When we run the model, this ConditionContext should then be nested inside a PrefixContext to form the final evaluation context. Again, we can run the model to see this in action:

with_inner_cond();
nothing;
outer context: DefaultContext()
inner context: PrefixContext{VarName{:a, typeof(identity)}, ConditionContext{Dict{VarName{:x, typeof(identity)}, Float64}, DefaultContext}}(a, ConditionContext(Dict(x => 1.0), DefaultContext()))

Putting all of this together: if we have these two inner contexts (taken from above):

using DynamicPPL: PrefixContext, ConditionContext, DefaultContext

inner_ctx_with_outer_cond = ConditionContext(
    Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a))
)
inner_ctx_with_inner_cond = PrefixContext(
    @varname(a), ConditionContext(Dict(@varname(x) => 1.0))
)
PrefixContext{VarName{:a, typeof(identity)}, ConditionContext{Dict{VarName{:x, typeof(identity)}, Float64}, DefaultContext}}(a, ConditionContext(Dict(x => 1.0), DefaultContext()))

then we want both of these to be true (and thankfully, they are!):

DynamicPPL.hasconditioned_nested(inner_ctx_with_outer_cond, @varname(a.x))
true
DynamicPPL.hasconditioned_nested(inner_ctx_with_inner_cond, @varname(a.x))
true

This allows us to finally specify our task as follows:

(1) Given the correct arguments, we need to make sure that hasconditioned_nested and getconditioned_nested behave correctly.

(2) We need to make sure that the correct arguments are supplied. In order to do so:

  • (2a) We need to make sure that when evaluating a submodel, the context stack is arranged such that PrefixContext is applied inside the parent model's context, but outside the submodel's own context.

  • (2b) We also need to make sure that the VarName passed to it is prefixed correctly.

How do we do it?

(1) hasconditioned_nested and getconditioned_nested accomplish this by first 'collapsing' the context stack, i.e. they go through the context stack, remove all PrefixContexts, and apply those prefixes to any conditioned variables below it in the stack. Once the PrefixContexts have been removed, one can then iterate through the context stack and check if any of the ConditionContexts contain the variable, or get the value itself. For more details the reader is encouraged to read the source code.
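A rough sketch of this collapsing strategy follows, in plain Julia with hypothetical toy types (ToyDefault, ToyPrefix, ToyCondition) rather than the real DynamicPPL contexts, and with VarNames modelled as vectors of symbols. Here collapse returns every conditioned Dict in the stack with its keys rewritten to include all the prefixes above it. Which match wins when two ConditionContexts disagree is a choice made by this sketch (the outermost one), not a statement about DynamicPPL.

```julia
# Hypothetical toy contexts (not DynamicPPL's). A VarName such as a.x is modelled
# as the vector [:a, :x].
abstract type ToyContext end
struct ToyDefault <: ToyContext end
struct ToyPrefix <: ToyContext
    prefix::Vector{Symbol}
    child::ToyContext
end
struct ToyCondition <: ToyContext
    values::Dict{Vector{Symbol},Any}
    child::ToyContext
end

# Collapse the stack: return the values of every ToyCondition, outermost first,
# with each key rewritten to include all the prefixes that sit above it.
collapse(::ToyDefault) = Dict{Vector{Symbol},Any}[]
collapse(ctx::ToyCondition) = vcat([ctx.values], collapse(ctx.child))
function collapse(ctx::ToyPrefix)
    # Collapse the child first, then prepend this prefix to every key, so that
    # outer prefixes end up leftmost.
    prefixed = Dict{Vector{Symbol},Any}[]
    for d in collapse(ctx.child)
        push!(prefixed, Dict{Vector{Symbol},Any}(vcat(ctx.prefix, k) => v for (k, v) in d))
    end
    return prefixed
end

toy_hasconditioned_nested(ctx::ToyContext, vn) = any(d -> haskey(d, vn), collapse(ctx))

function toy_getconditioned_nested(ctx::ToyContext, vn)
    # In this sketch the outermost matching ToyCondition wins.
    for d in collapse(ctx)
        haskey(d, vn) && return d[vn]
    end
    throw(KeyError(vn))
end

# Toy versions of the two inner contexts above, plus the nested big_ctx example.
toy_outer = ToyCondition(Dict([:a, :x] => 1.0), ToyPrefix([:a], ToyDefault()))
toy_inner = ToyPrefix([:a], ToyCondition(Dict([:x] => 1.0), ToyDefault()))
toy_big = ToyPrefix([:a], ToyCondition(Dict([:b, :y] => 1.0),
    ToyPrefix([:b], ToyCondition(Dict([:x] => 1.0), ToyDefault()))))

toy_hasconditioned_nested(toy_outer, [:a, :x])    # true
toy_hasconditioned_nested(toy_inner, [:a, :x])    # true
toy_hasconditioned_nested(toy_big, [:a, :b, :x])  # true
toy_hasconditioned_nested(toy_big, [:a, :b, :z])  # false
```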

(2a) We ensure that the context stack is correctly arranged by relying on the behaviour of make_evaluate_args_and_kwargs. This function is called whenever a model (which itself contains a context) is evaluated with a separate ('external') context, and makes sure to arrange both of these contexts such that the model's context is nested inside the external context. Thus, as long as prefixing is implemented by applying a PrefixContext on the outermost layer of the inner model context, this will be correctly combined with an external context to give the behaviour seen above.
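The nesting behaviour can be mimicked with a toy function nest (hypothetical, not part of DynamicPPL) that replaces the leaf of the external context with the model's own context. Applied to the two submodel situations from earlier, it reproduces the shape of the two context stacks printed above. The toy types are hypothetical stand-ins for DynamicPPL's contexts, and the submodel's own context is assumed to already carry the PrefixContext on its outermost layer.

```julia
# Hypothetical toy contexts (not DynamicPPL's).
abstract type ToyContext end
struct ToyDefault <: ToyContext end
struct ToyPrefix <: ToyContext
    prefix::Vector{Symbol}
    child::ToyContext
end
struct ToyCondition <: ToyContext
    values::Dict{Vector{Symbol},Any}
    child::ToyContext
end

# nest(external, model_ctx): replace the leaf of the external context with the
# model's own context, so that the model's context ends up innermost.
nest(::ToyDefault, model_ctx::ToyContext) = model_ctx
nest(ctx::ToyPrefix, model_ctx::ToyContext) = ToyPrefix(ctx.prefix, nest(ctx.child, model_ctx))
nest(ctx::ToyCondition, model_ctx::ToyContext) = ToyCondition(ctx.values, nest(ctx.child, model_ctx))

# Outer conditioning: the parent passes down a condition on a.x, and the
# submodel's own context is just the prefix for `a`.
ctx_outer = nest(ToyCondition(Dict([:a, :x] => 1.0), ToyDefault()), ToyPrefix([:a], ToyDefault()))
# nesting, outer to inner: ToyCondition(a.x) -> ToyPrefix(a) -> ToyDefault

# Inner conditioning: nothing is passed down, and the submodel's context is the
# prefix wrapped around its own condition on x.
ctx_inner = nest(ToyDefault(), ToyPrefix([:a], ToyCondition(Dict([:x] => 1.0), ToyDefault())))
# nesting, outer to inner: ToyPrefix(a) -> ToyCondition(x) -> ToyDefault
```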

(2b) At first glance, it seems like tilde_assume can take care of the VarName prefixing for us (as described in the first section). However, this is not actually the case: contextual_isassumption, which is the function that calls hasconditioned_nested, is much higher in the call stack than tilde_assume is. So, we need to explicitly prefix the VarName before passing it to contextual_isassumption. This is done inside the @model macro, or, more precisely, its subsidiary function isassumption.
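A small illustration of why the prefixing has to happen this early, assuming the context stack has already been collapsed into a single Dict keyed by fully prefixed names (modelled as vectors of symbols). The variable active_prefixes is a hypothetical stand-in for the prefixes that the PrefixContexts in the stack would otherwise only apply later, in tilde_assume.

```julia
# Conditioned values after collapsing the context stack, keyed by fully
# prefixed names (a.x is written as [:a, :x]).
collapsed = Dict([:a, :x] => 1.0)

# Hypothetical: the prefixes active for this submodel, outermost first.
active_prefixes = [[:a]]

# The name as it appears in the submodel's source, `x ~ Normal()`.
vn = [:x]

# Checking the unprefixed name misses the conditioned value, so x would
# wrongly be treated as a random variable.
found_unprefixed = haskey(collapsed, vn)           # false

# Prefixing first (outermost prefix ends up leftmost) finds it.
full_vn = foldr(vcat, active_prefixes; init = vn)  # [:a, :x]
found_prefixed = haskey(collapsed, full_vn)        # true
```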

Nested submodels

Just in case the above wasn't complicated enough, we need to also be very careful when dealing with nested submodels, which have multiple layers of PrefixContexts which may be interspersed with ConditionContexts. For example, in this series of nested submodels,

@model function charlie()
    x ~ Normal()
    y ~ Normal()
    return z ~ Normal()
end
@model function bravo()
    return b ~ to_submodel(charlie() | (@varname(x) => 1.0))
end
@model function alpha()
    return a ~ to_submodel(bravo() | (@varname(b.y) => 1.0))
end
alpha (generic function with 2 methods)

we expect that the only variable to be sampled should be z inside charlie, or rather, a.b.z once it has been through the prefixes.

keys(VarInfo(alpha()))
1-element Vector{VarName{:a, ComposedFunction{Accessors.PropertyLens{:z}, Accessors.PropertyLens{:b}}}}:
 a.b.z

The general strategy that we adopt is similar to above. Following the principle that PrefixContext should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside charlie should be:

big_ctx = PrefixContext(
    @varname(a),
    ConditionContext(
        Dict(@varname(b.y) => 1.0),
        PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))),
    ),
)
PrefixContext{VarName{:a, typeof(identity)}, ConditionContext{Dict{VarName{:b, Accessors.PropertyLens{:y}}, Float64}, PrefixContext{VarName{:b, typeof(identity)}, ConditionContext{Dict{VarName{:x, typeof(identity)}, Float64}, DefaultContext}}}}(a, ConditionContext(Dict(b.y => 1.0), PrefixContext{VarName{:b, typeof(identity)}, ConditionContext{Dict{VarName{:x, typeof(identity)}, Float64}, DefaultContext}}(b, ConditionContext(Dict(x => 1.0), DefaultContext()))))

We need several things to work correctly here: we need the VarName prefixing to behave correctly, and then we need to implement hasconditioned_nested and getconditioned_nested on the resulting prefixed VarName. It turns out that the prefixing itself is enough to illustrate the most important point in this section, namely, the need to traverse the context stack in a different direction from the one most of DynamicPPL uses.

Let's work with a function called myprefix(::AbstractContext, ::VarName) (to avoid confusion with any existing DynamicPPL function). We would like myprefix(big_ctx, @varname(x)) to return @varname(a.b.x). Consider the following naive implementation, which mirrors a lot of code in the tilde-pipeline:

using DynamicPPL: NodeTrait, IsLeaf, IsParent, childcontext, AbstractContext
using AbstractPPL: AbstractPPL

function myprefix(ctx::DynamicPPL.AbstractContext, vn::VarName)
    return myprefix(NodeTrait(ctx), ctx, vn)
end
function myprefix(::IsLeaf, ::AbstractContext, vn::VarName)
    return vn
end
function myprefix(::IsParent, ctx::AbstractContext, vn::VarName)
    return myprefix(childcontext(ctx), vn)
end
function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
    # The functionality to actually manipulate the VarNames is in AbstractPPL
    new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix)
    # Then pass to the child context
    return myprefix(childcontext(ctx), new_vn)
end

myprefix(big_ctx, @varname(x))
b.a.x

This implementation is clearly not correct: it applies the outer prefix (a) first and the inner prefix (b) last, so b ends up outermost in the resulting VarName.

The right way to implement myprefix is, essentially, to reverse the order of the two lines above:

function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName)
    # Pass to the child context first
    new_vn = myprefix(childcontext(ctx), vn)
    # Then apply this context's prefix
    return AbstractPPL.prefix(new_vn, ctx.vn_prefix)
end

myprefix(big_ctx, @varname(x))
a.b.x

This is a much better result! The implementations of related functions such as hasconditioned_nested and getconditioned_nested use a similar recursion scheme under the hood, so you will find that this is a common pattern when reading the source code of various prefixing-related functions. When editing this code, it is worth keeping in mind that getting the recursion order wrong is a likely source of bugs.

Info

If you have encountered left and right folds, the above discussion illustrates the difference between them: the wrong implementation of myprefix uses a left fold (which collects prefixes in the opposite order from which they are encountered), while the correct implementation uses a right fold.
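The two folds can be compared directly with Julia's foldl and foldr, using strings in place of VarNames. The helper prefix_with is hypothetical; it simply glues a prefix onto a name with a dot.

```julia
# Hypothetical helper: prefix_with("a", "x") == "a.x".
prefix_with(p::AbstractString, vn::AbstractString) = string(p, ".", vn)

# Prefixes in the order they are encountered when walking down the context
# stack (outermost first), as in big_ctx.
prefixes = ["a", "b"]

# Left fold: applies "a" first, then "b", so the inner prefix ends up outermost.
wrong = foldl((vn, p) -> prefix_with(p, vn), prefixes; init = "x")  # "b.a.x"

# Right fold: applies "b" first, then "a", so the outer prefix ends up outermost.
correct = foldr(prefix_with, prefixes; init = "x")  # "a.b.x"
```

Each foldr step wraps the result of the steps for the inner prefixes, which is the same shape as the corrected myprefix recursion.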

Loose ends 1: Manual prefixing

Sometimes users may want to manually prefix a model, for example:

@model function inner_manual()
    x ~ Normal()
    return y ~ Normal()
end

@model function outer_manual()
    return _unused ~ to_submodel(prefix(inner_manual(), :a), false)
end
outer_manual (generic function with 2 methods)

In this case, the VarName on the left-hand side of the tilde is not used, and the prefix is instead specified using the prefix function.

The way to deal with this follows on from the previous discussion. Specifically, we said that:

[...] as long as prefixing is implemented by applying a PrefixContext on the outermost layer of the inner model context, this will be correctly combined [...]

When automatic prefixing is used, this application of PrefixContext occurs inside the tilde_assume!! method. In the manual prefixing case, we need to make sure that prefix(submodel::Model, ::Symbol) does the same thing, i.e. it inserts a PrefixContext at the outermost layer of submodel's context. We can see that this is precisely what happens:

@model f() = x ~ Normal()

model = f()
prefixed_model = prefix(model, :a)

(model.context, prefixed_model.context)
(DefaultContext(), PrefixContext{VarName{:a, typeof(identity)}, DefaultContext}(a, DefaultContext()))
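A toy version of this is sketched below, where ToyModel, toy_prefix and toy_condition are hypothetical stand-ins that only track a model's context. It assumes that conditioning, like prefixing, wraps the model's existing context, and it shows that manually prefixing an already-conditioned model gives the same layout as the inner conditioning case (prefix outside, condition inside).

```julia
# Hypothetical toy contexts (not DynamicPPL's).
abstract type ToyContext end
struct ToyDefault <: ToyContext end
struct ToyPrefix <: ToyContext
    prefix::Vector{Symbol}
    child::ToyContext
end
struct ToyCondition <: ToyContext
    values::Dict{Vector{Symbol},Any}
    child::ToyContext
end

# Hypothetical stand-in for a model: just a name and the context it carries.
struct ToyModel
    name::Symbol
    context::ToyContext
end

# Manual prefixing: wrap the model's existing context, so the prefix sits at
# the outermost layer of the model's context.
toy_prefix(m::ToyModel, p::Symbol) = ToyModel(m.name, ToyPrefix([p], m.context))

# Conditioning (assumed here) wraps the existing context in the same way.
toy_condition(m::ToyModel, vals::Dict) = ToyModel(m.name, ToyCondition(vals, m.context))

plain = ToyModel(:inner_manual, ToyDefault())
prefixed = toy_prefix(plain, :a)
# context, outer to inner: ToyPrefix(a) -> ToyDefault
cond_then_prefixed = toy_prefix(toy_condition(plain, Dict([:x] => 1.0)), :a)
# context, outer to inner: ToyPrefix(a) -> ToyCondition(x) -> ToyDefault
```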

Loose ends 2: FixedContext

Finally, note that all of the above also applies to the interaction between PrefixContext and FixedContext, except that the functions have different names. (FixedContext behaves the same way as ConditionContext, except that unlike conditioned variables, fixed variables do not contribute to the log probability density.) This generally results in a large amount of code duplication, but the concepts that underlie both contexts are exactly the same.
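The difference between the two can be shown with a toy computation of the log density contributed by x ~ Normal() when x is given the value 1.0. The helper names are hypothetical, and the standard normal log density is written out by hand rather than taken from Distributions.jl.

```julia
# Standard normal log density, written out by hand.
std_normal_logpdf(x::Real) = -0.5 * x^2 - 0.5 * log(2π)

# Hypothetical toy: the log-density contribution of `x ~ Normal()` when x has
# been given a value either by conditioning or by fixing.
function toy_logp_contribution(x_value::Real, how::Symbol)
    if how === :condition
        # Conditioned: x takes the given value and still contributes its log density.
        return std_normal_logpdf(x_value)
    elseif how === :fix
        # Fixed: x takes the given value but contributes nothing.
        return 0.0
    else
        throw(ArgumentError("how must be :condition or :fix"))
    end
end

toy_logp_contribution(1.0, :condition)  # about -1.4189
toy_logp_contribution(1.0, :fix)        # 0.0
```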