Transform strategies

Often it is useful to evaluate the log-probability of a model in a different space from the one in which it is defined.

Note

The main Turing documentation site has a more detailed introduction to variable transformations in MCMC sampling. This page only describes the implementation in DynamicPPL.

Consider the following model:

using DynamicPPL, Distributions

@model function f()
    x ~ LogNormal()
    y ~ LogNormal()
    return (x, y)
end
f (generic function with 2 methods)

There are several ways in which we might want to evaluate this model:

  • In the original ('untransformed') space: we provide values for x and y which are both positive, and we evaluate the log-probability directly. This corresponds directly to

    x, y = 1.5, 2.0
    logp = logpdf(LogNormal(), x) + logpdf(LogNormal(), y)
    -3.2589168389831387
  • In unconstrained ('transformed') space: we provide values for x and y in unconstrained space (i.e. they are real numbers), and evaluate the log-probability in transformed space. This corresponds to something like:

    x, y = 1.5, 2.0
    
    # Here x and y are the original-space values; their unconstrained
    # counterparts are trf_x = log(x) and trf_y = log(y). To obtain the
    # log-probability in transformed space, we subtract the
    # log-absolute-determinant of the link transformation's Jacobian.
    trf_x, trf_y = log(x), log(y)
    logdetJx = -trf_x
    logdetJy = -trf_y
    
    logp = logpdf(LogNormal(), x) + logpdf(LogNormal(), y) - logdetJx - logdetJy
    -2.160304550315029
  • We might also want to have a mix of transformed and untransformed variables, for example, x in untransformed space and y in transformed space. This could be useful for example when using Gibbs sampling, or Metropolis–Hastings with different proposal distributions for different variables.
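The mixed case can be sketched with the same arithmetic as the two bullet points above. To keep the snippet self-contained in Base Julia, we write out the LogNormal(0, 1) log-density by hand; the helper `lognormal_logpdf` is ours, standing in for `logpdf(LogNormal(), ·)`:

```julia
# Hand-written LogNormal(0, 1) log-density (stands in for logpdf(LogNormal(), z)).
lognormal_logpdf(z) = -log(z) - 0.5 * log(2π) - log(z)^2 / 2

x, y = 1.5, 2.0

# x stays in untransformed space; only y's link (log) Jacobian term is
# subtracted, following the same convention as the fully transformed example.
logdetJy = -log(y)
logp_mixed = lognormal_logpdf(x) + lognormal_logpdf(y) - logdetJy
```

The untransformed and fully transformed log-probabilities above are recovered by subtracting neither or both Jacobian terms, respectively.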

AbstractTransformStrategy

DynamicPPL allows you to specify which variables you want to evaluate in transformed space using transform strategies. All transform strategies are subtypes of AbstractTransformStrategy. Currently, DynamicPPL provides the transform strategies LinkAll, UnlinkAll, LinkSome, and UnlinkSome. Their meanings should be fairly self-explanatory; here is a brief demonstration:

params = @vnt begin
    x := 1.5
    y := 2.0
end
_, vi_unlinked = init!!(f(), OnlyAccsVarInfo(), InitFromParams(params), UnlinkAll())
vi_unlinked.accs
AccumulatorTuple with 3 accumulators
├─ LogPrior => LogPriorAccumulator(-3.2589168389831387)
├─ LogJacobian => LogJacobianAccumulator(0.0)
└─ LogLikelihood => LogLikelihoodAccumulator(0.0)
_, vi_linked = init!!(f(), OnlyAccsVarInfo(), InitFromParams(params), LinkAll())
vi_linked.accs
AccumulatorTuple with 3 accumulators
├─ LogPrior => LogPriorAccumulator(-3.2589168389831387)
├─ LogJacobian => LogJacobianAccumulator(-1.0986122886681096)
└─ LogLikelihood => LogLikelihoodAccumulator(0.0)
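As a sanity check on the LinkAll() output above, the LogJacobian value is just the sum of the per-variable log-Jacobian terms of the log link:

```julia
# For LogNormal the link is log, and d(log v)/dv = 1/v, so each variable
# contributes -log(value) to the log-Jacobian accumulator.
x, y = 1.5, 2.0
logdetJ = -(log(x) + log(y))   # ≈ -1.0986, matching LogJacobianAccumulator
```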
Initialisation strategy does not determine log-Jacobian

In the above examples, we used InitFromParams to provide variable values. InitFromParams is an initialisation strategy; when given a VarNamedTuple of values, as we did above, it always interprets those values as being in untransformed space.

This does not, however, mean that the log-Jacobian is disregarded! As the second example above shows, with LinkAll() the log-Jacobian is still applied even though the values were provided in untransformed space. It is the transform strategy that determines whether the log-Jacobian is applied when evaluating the log-probability. One can think of the transform strategy as a re-interpretation of the value provided by the initialisation strategy.

This frees up the initialisation strategy to return whatever kind of AbstractTransformedValue is most convenient for it.

Making your own transform strategy

The only requirement for a subtype of AbstractTransformStrategy is that it must implement target_transform(::AbstractTransformStrategy, vn::VarName), where vn is the variable on the left-hand side of a tilde-statement.

target_transform must in turn return an AbstractTransform specifying whether the variable should be transformed or not.

For example, the following would cause x to be transformed but not y.

struct LookupTransformsInVNT{V<:VarNamedTuple} <: AbstractTransformStrategy
    transforms::V
end

function DynamicPPL.target_transform(l::LookupTransformsInVNT, vn::VarName)
    return l.transforms[vn]
end

link_x_only = LookupTransformsInVNT(@vnt begin
    x := DynamicLink()
    y := Unlink()
end)

_, vi_link_x_only = init!!(f(), OnlyAccsVarInfo(), InitFromParams(params), link_x_only)
vi_link_x_only.accs
AccumulatorTuple with 3 accumulators
├─ LogPrior => LogPriorAccumulator(-3.2589168389831387)
├─ LogJacobian => LogJacobianAccumulator(-0.4054651081081644)
└─ LogLikelihood => LogLikelihoodAccumulator(0.0)
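The LogJacobian value here confirms that only x was linked: it contains x's term and nothing from y:

```julia
# With x linked and y unlinked, only x contributes a log-Jacobian term.
x = 1.5
logdetJ = -log(x)   # ≈ -0.4055, matching the LogJacobianAccumulator above
```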

Looking at the two transform types used above, Unlink() is probably more intuitive: it just means 'do not interpret this variable as being in transformed space'. However, DynamicLink() is a bit more subtle.

In particular, for DynamicLink(), the actual transformation used is obtained at runtime from the distribution on the right-hand side of the tilde-statement, using:

DynamicPPL.from_linked_vec_transform(LogNormal())
DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ()))

(which ultimately calls functions from Bijectors.jl). This means that the transformation is recalculated on every model evaluation.

One might question why this is necessary: for example, in this simple model, we know that x and y are always LogNormal, so why not just use the log-transform directly?

The answer is to do with distributions whose support (and hence transformation) can vary. For example, consider:

@model function g()
    x ~ Normal()
    return y ~ truncated(Normal(); lower=x)
end
g (generic function with 2 methods)

In this example, the support of y depends on what value x takes, and that can vary from one evaluation to the next. Consequently, the transformation used for y must be determined at runtime; if we cache a fixed transformation, it is possible that this transformation will be invalid (e.g. by mapping unconstrained values to values outside the support of y).
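To make the failure mode concrete, here is a Base-Julia sketch. The helpers `link` and `invlink` are hypothetical stand-ins (not DynamicPPL or Bijectors.jl API), assuming a typical lower-bound link of the form u = log(y - lower):

```julia
# Hypothetical link for a support (lower, ∞): u = log(y - lower).
link(y, lower) = log(y - lower)
invlink(u, lower) = exp(u) + lower

# Evaluation 1: x = 0.5, so y = 1.0 lies in the support (0.5, ∞).
u = link(1.0, 0.5)

# Evaluation 2: x has changed to 2.0. A transform cached from evaluation 1
# still uses lower = 0.5 and maps u to a value outside the new support:
y_stale = invlink(u, 0.5)   # 1.0, but the support is now (2.0, ∞)
y_fresh = invlink(u, 2.0)   # determined at runtime: always lands in (2.0, ∞)
```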

For correctness, DynamicPPL therefore always prefers to determine the transformation at runtime when using DynamicLink(). This behaviour is encoded in (for example) Turing's HMC samplers, which use LinkAll() as the default transform strategy, and hence every VarName will have a target_transform of DynamicLink().

Fixed transformations

For some models, it may be known that the support of a variable never changes, so its transformation can be fixed in advance. This would avoid the overhead of recomputing the transformation at every model evaluation.

This is currently not implemented, but there is a plan for it; see this DynamicPPL issue for details.

Why not let init() determine the transform?

Warning

This section is mainly for developers and advanced users interested in the design decisions behind DynamicPPL; it has no real implications for everyday usage.

An alternative to having an explicit link strategy would be to simply allow the initialisation strategy to determine whether variables are transformed or not. In this world, if init(rng, vn, dist, strategy) returned a transformed value, then we could treat it as being in transformed space, and vice versa.

The reason we instead use a separate, explicit transform strategy is to allow more flexibility in how models are evaluated, which in turn can save us from having to rerun the model multiple times.

Consider, for example, DynamicPPL.InitFromUniform. This is an initialisation strategy which samples uniformly from [-2, 2] in transformed space, and is used by Turing's HMC samplers to generate initial values. This used to be a standard workflow in Turing:

using Random
_, vi = init!!(Xoshiro(468), f(), VarInfo(), InitFromUniform())

# We'll run this later.
# vi_linked = link!!(vi, f())
((0.1584005226486094, 0.18462758115248692), VarInfo{false, VarNamedTuple{(:x, :y), Tuple{VectorValue{Vector{Float64}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}, Tuple{}}, VectorValue{Vector{Float64}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}, Tuple{}}}}, DynamicPPL.AccumulatorTuple{3, @NamedTuple{LogPrior::LogPriorAccumulator{Float64}, LogJacobian::LogJacobianAccumulator{Float64}, LogLikelihood::LogLikelihoodAccumulator{Float64}}}}(VarNamedTuple(x = VectorValue{Vector{Float64}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}, Tuple{}}([0.1584005226486094], DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}((1,)), ()), y = VectorValue{Vector{Float64}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}, Tuple{}}([0.18462758115248692], DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}((1,)), ())), DynamicPPL.AccumulatorTuple{3, @NamedTuple{LogPrior::LogPriorAccumulator{Float64}, LogJacobian::LogJacobianAccumulator{Float64}, LogLikelihood::LogLikelihoodAccumulator{Float64}}}((LogPrior = LogPriorAccumulator(-1.4305346771581844), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))))

In the first line, we populate an empty VarInfo with values. This initialisation always leaves the VarInfo unlinked, because init!! on a VarInfo always retains the VarInfo's existing transform state (which is empty at this point, and hence unlinked).

If we did not have a separate transform strategy, we would have to ensure that init() always returned untransformed values when using InitFromUniform (otherwise we would fill the VarInfo with transformed values, which is not what we want here). This forces an extra step: InitFromUniform has to generate transformed values, untransform them, and then pass them on to the VarInfo so that it can store untransformed values. (This was indeed how InitFromUniform was implemented in DynamicPPL <= v0.39.)

If we then wanted to link the VarInfo again

vi_linked = link!!(vi, f())
VarInfo {linked=true}
 ├─ values
 │  VarNamedTuple
 │  ├─ x => LinkedVectorValue{Vector{Float64}, ComposedFunction{DynamicPPL.UnwrapSingletonTransform{Tuple{}}, ComposedFunction{Base.Fix1{typeof(broadcast), typeof(exp)}, DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}}}, Tuple{}}([-1.8426285000579532], DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ())), ())
 │  └─ y => LinkedVectorValue{Vector{Float64}, ComposedFunction{DynamicPPL.UnwrapSingletonTransform{Tuple{}}, ComposedFunction{Base.Fix1{typeof(broadcast), typeof(exp)}, DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}}}, Tuple{}}([-1.689414557713834], DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ())), ())
 └─ accs
    AccumulatorTuple with 3 accumulators
    ├─ LogPrior => LogPriorAccumulator(-1.4305346771581844)
    ├─ LogJacobian => LogJacobianAccumulator(3.532043057771787)
    └─ LogLikelihood => LogLikelihoodAccumulator(0.0)

then we would have to recompute the forward transform, which, frustratingly, is exactly the transform that InitFromUniform had just undone. So we calculate the same transformation twice, and evaluate the model twice, just to get our VarInfo into the right state.

Instead, now with a separate transform strategy, we can immediately do:

_, vi_linked = init!!(Xoshiro(468), f(), VarInfo(), InitFromUniform(), LinkAll())
vi_linked
VarInfo {linked=true}
 ├─ values
 │  VarNamedTuple
 │  ├─ x => LinkedVectorValue{Vector{Float64}, ComposedFunction{DynamicPPL.UnwrapSingletonTransform{Tuple{}}, ComposedFunction{Base.Fix1{typeof(broadcast), typeof(exp)}, DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}}}, Tuple{}}([-1.8426285000579532], DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ())), ())
 │  └─ y => LinkedVectorValue{Vector{Float64}, ComposedFunction{DynamicPPL.UnwrapSingletonTransform{Tuple{}}, ComposedFunction{Base.Fix1{typeof(broadcast), typeof(exp)}, DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}}}, Tuple{}}([-1.689414557713834], DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ())), ())
 └─ accs
    AccumulatorTuple with 3 accumulators
    ├─ LogPrior => LogPriorAccumulator(-1.4305346771581844)
    ├─ LogJacobian => LogJacobianAccumulator(3.532043057771787)
    └─ LogLikelihood => LogLikelihoodAccumulator(0.0)

We can see that we get exactly the same result, but running the model only once and calculating the transformation only once. Furthermore, this allows us to remove the inverse-transform step inside InitFromUniform: it can simply return a LinkedVectorValue directly, and the transform strategy is then responsible for performing the inverse transform a single time.

Essentially, having a separate transform strategy allows us to:

  1. Free up the initialisation strategy to return whatever kind of AbstractTransformedValue is most convenient for it, without worrying about whether it needs to perform some transform.

  2. Consolidate all the actual transformation logic in a single function (DynamicPPL.apply_transform_strategy), which lets us ensure that each tilde-statement involves at most one transformation.