Conditioning and fixing

DynamicPPL allows you to first define a model and then modify it, either by conditioning on observed data or by fixing variables to specific values. This lets you define a model once and then use it in different ways.

As an example, one could define a linear regression model as follows:

using DynamicPPL, Distributions

@model function linear_regression(x)
    m ~ Normal(0, 1)
    c ~ Normal(0, 1)

    y = Vector{Float64}(undef, length(x))
    for i in eachindex(x)
        y[i] ~ Normal(m * x[i] + c, 1.0)
    end
end
linear_regression (generic function with 2 methods)

As it stands, this model does not have any observed data: the variable y is neither a model argument nor conditioned on any values, so all the y[i]'s are treated as latent variables.

Why do we need to define `y` in the model?

The definition of y in the model is needed so that there is somewhere to assign each y[i] to when its tilde-statement runs. If we did not define y, we would get an error when trying to call setindex! on an undefined variable.

The fact that we defined y within the model does not change this: all variables on the left-hand side of a tilde-statement are treated as latent variables unless explicitly conditioned on or provided as an argument to the model function.
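
For comparison, here is a sketch of the more conventional alternative, where the observations are passed in as a model argument (the name linear_regression_observed is purely illustrative). This makes y observed from the outset, at the cost of committing to that choice when the model is defined:

@model function linear_regression_observed(x, y)
    m ~ Normal(0, 1)
    c ~ Normal(0, 1)
    # `y` is a model argument here, so each y[i] is treated as observed
    # whenever concrete (non-missing) data is passed in.
    for i in eachindex(x)
        y[i] ~ Normal(m * x[i] + c, 1.0)
    end
end

Conditioning, by contrast, lets us keep a single model definition and decide later which variables are observed.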

Let's create some synthetic data to work with:

true_m, true_c = 5.0, 3.0

x = 0:0.1:0.5
y_data = true_m .* x .+ true_c .+ randn(length(x))
6-element Vector{Float64}:
 2.799745365266664
 4.271308654014712
 4.406464790937631
 3.7741411970694068
 5.873222504989637
 5.247132230509862

If we run the model before conditioning on y, we will find that all of m, c, and y are drawn from the prior distribution.

model = linear_regression(x)

# Here, `rand(model)` samples from the prior distribution and returns a
# VarNamedTuple of latent variables.
rand(model)
VarNamedTuple
├─ m => -1.3848762961429908
├─ c => -0.12351031975920958
└─ y => PartialArray size=(6,) data::Vector{Float64}
        ├─ (1,) => 2.7006456100976397
        ├─ (2,) => 0.3537584679840871
        ├─ (3,) => 0.1107142638133623
        ├─ (4,) => -0.7403054969703036
        ├─ (5,) => -0.9196711088027121
        └─ (6,) => -1.7143063070187305

We could, for example, do this many times and compute the prior mean of y. This is analogous to using Turing's Prior() sampler.

vnts = [rand(model) for _ in 1:1000]
mean(vnt[@varname(y)] for vnt in vnts)
6-element Vector{Float64}:
  0.008110359419693567
  0.024032324965422645
 -0.05334204269371158
 -0.08267518096447084
 -0.017637504411774488
 -0.016478102816500032

This is useful for prior predictive checks, for example.
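
As a minimal sketch of such a check, assuming that the indexed y draws behave like ordinary vectors (as the mean computation above suggests), one could ask whether the observed data falls within the bulk of the prior predictive distribution; the 5% and 95% quantile levels below are an arbitrary choice:

using Statistics

# Collect the prior draws of y into a matrix with one column per draw.
y_draws = reduce(hcat, collect(vnt[@varname(y)]) for vnt in vnts)

# Elementwise prior predictive quantiles (arbitrary levels, for illustration).
lower = [quantile(y_draws[i, :], 0.05) for i in 1:length(x)]
upper = [quantile(y_draws[i, :], 0.95) for i in 1:length(x)]

# Does each observation fall within the bulk of the prior predictive?
all(lower .<= y_data .<= upper)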

Conditioning

To condition the model on observed data, we can use the condition function, or its alias |. The most robust way of conditioning is to provide a VarNamedTuple that holds the values to condition on.

# Construct a `VarNamedTuple` that holds the conditioning values.
observations = @vnt begin
    y := y_data
end

# Equivalently: cond_model = condition(model, observations).
cond_model = model | observations
Model{typeof(Main.linear_regression), (:x,), (), (), Tuple{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, Tuple{}, DynamicPPL.CondFixContext{DynamicPPL.Condition, VarNamedTuple{(:y,), Tuple{Vector{Float64}}}, DefaultContext}, false}(Main.linear_regression, (x = 0.0:0.1:0.5,), NamedTuple(), CondFixContext{DynamicPPL.Condition}(VarNamedTuple(y = [2.799745365266664, 4.271308654014712, 4.406464790937631, 3.7741411970694068, 5.873222504989637, 5.247132230509862],), DefaultContext()))

We can inspect the values that have been conditioned on using the conditioned function:

conditioned(cond_model)
VarNamedTuple
└─ y => [2.799745365266664, 4.271308654014712, 4.406464790937631, 3.7741411970694068, 5.873222504989637, 5.247132230509862]

If we were to run this model, we would now find that y is an observed variable, and thus it is not sampled:

parameters = rand(cond_model)
VarNamedTuple
├─ m => 0.47345905968658164
└─ c => 0.07070001767720674

We can't directly draw from the posterior using DynamicPPL (rand still draws from the prior). However, since y is now an observed variable, the log-likelihood associated with the newly provided data will be computed:

loglikelihood(cond_model, parameters)
-61.15141359522337

and this quantity can be used by MCMC algorithms to draw samples from the posterior distribution.
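
DynamicPPL also exports logprior and logjoint; assuming they accept the same arguments as loglikelihood above, the quantity that MCMC actually targets decomposes as follows:

# The (unnormalised) log-posterior density that MCMC samplers work with is
# the log-joint: the log-prior over m and c plus the log-likelihood of the
# conditioned y. Assuming logprior/logjoint accept the same arguments as
# loglikelihood above:
logjoint(cond_model, parameters) ≈
    logprior(cond_model, parameters) + loglikelihood(cond_model, parameters)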

Fixing

Fixing works in exactly the same way as conditioning, except that fixed variables make no log-probability contribution at all, instead of adding to the log-likelihood.

In essence, fixing a variable x ~ dist to a value x_val is equivalent to replacing the statement with x = x_val, which removes it from the model entirely.

We can illustrate this by fixing the intercept c to its true value:

# Construct a `VarNamedTuple` that holds the fixed values.
fix_values = VarNamedTuple(; c=true_c)

fixed_model = fix(model, fix_values)
Model{typeof(Main.linear_regression), (:x,), (), (), Tuple{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, Tuple{}, DynamicPPL.CondFixContext{DynamicPPL.Fix, VarNamedTuple{(:c,), Tuple{Float64}}, DefaultContext}, false}(Main.linear_regression, (x = 0.0:0.1:0.5,), NamedTuple(), CondFixContext{DynamicPPL.Fix}(VarNamedTuple(c = 3.0,), DefaultContext()))

and sampling from the prior again:

parameters_fixed = rand(fixed_model)
VarNamedTuple
├─ m => 0.05418681215394027
└─ y => PartialArray size=(6,) data::Vector{Float64}
        ├─ (1,) => 4.380417413028404
        ├─ (2,) => 3.868542841384008
        ├─ (3,) => 3.6692336140942214
        ├─ (4,) => 3.4877977680023005
        ├─ (5,) => 3.828091498163319
        └─ (6,) => -0.28593017865438464

If we were to repeat this many times, we would find that y is still drawn from its prior, but with c fixed at its true value of 3.0. Since m has a prior mean of zero, the prior mean of each y[i] is now approximately 3:

mean(vnt[@varname(y)] for vnt in [rand(fixed_model) for _ in 1:1000])
6-element Vector{Float64}:
 3.0134881171679293
 2.9734661512712335
 2.9994429091287476
 2.9565604386060533
 3.0099240058497565
 3.0005707249571727
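
To make the contrast with conditioning concrete, here is a hedged sketch: if fixing removes c from the model entirely while conditioning keeps its density contribution, then the two log-joints, evaluated at the same point, should differ by exactly c's prior log-density (again assuming logjoint accepts the same arguments as loglikelihood above):

cond_c = model | VarNamedTuple(; c=true_c)
fixed_c = fix(model, VarNamedTuple(; c=true_c))

# Both models have the same free variables (m and the y[i]'s), so we can
# evaluate both log-joints at the same point. Only the conditioned model
# should include c's prior density.
params = rand(fixed_c)
logjoint(cond_c, params) - logjoint(fixed_c, params) ≈ logpdf(Normal(0, 1), true_c)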

Supplying parameters to condition or fix on

In the above examples we have provided the conditioning and fixing values as VarNamedTuples. Internally, DynamicPPL stores the values as VarNamedTuples, and it is strongly recommended that you construct them this way.

For convenience, both condition and fix also accept a variety of different input formats:

# NamedTuple
model | (; y=y_data)

# AbstractDict{VarName}
model | Dict(@varname(y) => y_data)

# Pair
model | (@varname(y) => y_data)

Note, however, that these alternative input formats are not necessarily rich enough to capture all the necessary information. We recommend using VarNamedTuples directly in all cases.

For example, if you wanted to condition y[1] but not the other y[i]'s, you could not specify this via a NamedTuple, since NamedTuple keys must be Symbols and cannot represent an indexed variable like y[1].

You can easily specify this via VarNamedTuple and its helper macro @vnt:

vnt = @vnt begin
    y[1] := y_data[1]
end
VarNamedTuple
└─ y => PartialArray size=(1,) data::DynamicPPL.VarNamedTuples.GrowableArray{Float64, 1}
        └─ (1,) => 2.799745365266664

Note that in this case, since the VarNamedTuple has no knowledge of the length or shape of y, DynamicPPL will assume that y is a Base.Vector of unknown length (hence the GrowableArray above).

This will work fine as long as y is indeed a Base.Vector. To avoid relying on this assumption, you should provide the full shape of y when defining the VarNamedTuple:

vnt = @vnt begin
    @template y = y_data
    y[1] := y_data[1]
end
VarNamedTuple
└─ y => PartialArray size=(6,) data::Vector{Float64}
        └─ (1,) => 2.799745365266664

Now, the variable y is known to have the same shape and type as y_data.

Warning

If you use custom array types in DynamicPPL that have different indexing semantics from Base.Array, then the templating shown here becomes mandatory. For example, OffsetArrays may behave incorrectly if templates are not supplied.

If we run the model again, we should find that y[1] is no longer sampled:

cond_model_partial = model | vnt
rand(cond_model_partial)
VarNamedTuple
├─ m => -0.6083996321361124
├─ c => -2.1433249513738137
└─ y => PartialArray size=(6,) data::Vector{Float64}
        ├─ (2,) => -0.6477296062010258
        ├─ (3,) => -1.5613161006872565
        ├─ (4,) => -3.1757357597743843
        ├─ (5,) => -2.0819186126729123
        └─ (6,) => -2.5751273181840912

Missing data

Warning

The behaviour described in this section is closely tied to DynamicPPL internals, and we recommend that you use the methods above for conditioning on subsets of data. It is documented merely for completeness, and to avoid confusion, since it has been discussed in previous issues and Discourse threads.

Alternatively, to condition on only part of y, you can condition on a full-length vector in which some of the entries are missing.

For this to work, it is mandatory that each y[i] is individually on the left-hand side of a tilde-statement, as in the linear regression example above. That means that you can write

for i in eachindex(x)
    y[i] ~ Normal(m * x[i] + c, 1.0)
end

but not

y ~ MvNormal(m .* x .+ c, I)

The reason this works is that when DynamicPPL finds a conditioned value of missing, it treats the variable as not actually being conditioned. When each y[i] appears individually on the left-hand side of a tilde-statement, DynamicPPL can therefore identify the individual y[i]'s that are missing and treat them as latent variables.

vnt = @vnt begin
    y := [missing, missing, 1.0, missing, 2.0, missing]
end
cond_model_missing = model | vnt

rand(cond_model_missing)
VarNamedTuple
├─ m => 1.4798133273751382
├─ c => -0.8106314183626978
└─ y => PartialArray size=(6,) data::Vector{Float64}
        ├─ (1,) => 0.05978774118367214
        ├─ (2,) => 0.8185321297058805
        ├─ (4,) => -1.2504662757191132
        └─ (6,) => -0.29220217716683616

On the other hand, if the entire y vector is on the left-hand side of a single tilde-statement, DynamicPPL cannot separate it into its missing and non-missing parts, and this approach will not work; use the partial conditioning shown in the previous section instead.