Design of VarInfo
VarInfo is a fairly simple structure.
DynamicPPL.VarInfo — Type

struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo
    metadata::Tmeta
    logp::Base.RefValue{Tlogp}
    num_produce::Base.RefValue{Int}
end
A light wrapper over one or more instances of Metadata. Let vi be an instance of VarInfo. If vi isa VarInfo{<:Metadata}, then a single Metadata instance is used for all the symbols; VarInfo{<:Metadata} is aliased UntypedVarInfo. If vi isa VarInfo{<:NamedTuple}, then vi.metadata is a NamedTuple that maps each symbol used on the LHS of ~ in the model to its own Metadata instance. The latter allows for type specialization of vi after the first sampling iteration, when all the symbols have been observed. VarInfo{<:NamedTuple} is aliased TypedVarInfo.
Note: It is the user's responsibility to ensure that each "symbol" is visited at least once whenever the model is called, regardless of any stochastic branching. Each symbol refers to a Julia variable and can be a hierarchical array of many random variables, e.g. x[1] ~ ... and x[2] ~ ... both have the same symbol x.
It contains

- a logp field for accumulation of the log-density evaluation, and
- a metadata field for storing information about the realizations of the different variables.
Representing logp is fairly straightforward: we'll just use a Real or an array of Real, depending on the context.
Representing metadata is a bit trickier. This is supposed to contain all the necessary information for each VarName to enable both the different executions of the model and the extraction of different properties of interest after execution, e.g. the realization/value corresponding to a variable @varname(x).
We want to work with VarName rather than something like Symbol or String because VarName contains additional structural information: e.g. a Symbol("x[1]") can be the result of either var"x[1]" ~ Normal() or x[1] ~ Normal(); these scenarios are disambiguated by VarName.
To ensure that VarInfo is simple and intuitive to work with, we want VarInfo, and hence the underlying metadata, to replicate the following functionality of Dict (a short usage sketch follows the list):

- keys(::Dict): return all the VarNames present in metadata.
- haskey(::Dict): check if a particular VarName is present in metadata.
- getindex(::Dict, ::VarName): return the realization corresponding to a particular VarName.
- setindex!(::Dict, val, ::VarName): set the realization corresponding to a particular VarName.
- push!(::Dict, ::Pair): add a new key-value pair to the container.
- delete!(::Dict, ::VarName): delete the realization corresponding to a particular VarName.
- empty!(::Dict): delete all realizations in metadata.
- merge(::Dict, ::Dict): merge two metadata structures according to similar rules as Dict.
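For instance, with a minimal hypothetical model (the name demo_dict_interface is made up for illustration; the calls themselves are the documented Dict-like interface):

using DynamicPPL, Distributions

@model function demo_dict_interface()
    x ~ Normal()
    return nothing
end

vi = VarInfo(demo_dict_interface())

haskey(vi, @varname(x))  # true
keys(vi)                 # collection of `VarName`s, here just `x`
vi[@varname(x)]          # the current realization of `x`
vi[@varname(x)] = 0.5    # overwrite the realization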
But for general-purpose samplers, we often want to work with a simple flattened structure, typically a Vector{<:Real}. One can access a vectorised version of a variable's value with the following vector-like functions (a short sketch follows the list):

- getindex_internal(::VarInfo, ::VarName): get the flattened value of a single variable.
- getindex_internal(::VarInfo, ::Colon): get the flattened values of all variables.
- getindex_internal(::VarInfo, i::Int): get the i-th value of the flattened vector of all values.
- setindex_internal!(::VarInfo, ::AbstractVector, ::VarName): set the flattened value of a variable.
- setindex_internal!(::VarInfo, val, i::Int): set the i-th value of the flattened vector of all values.
- length_internal(::VarInfo): return the length of the flat representation of metadata.
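Continuing the sketch above (these functions are qualified with DynamicPPL. here since they may not be exported):

DynamicPPL.getindex_internal(vi, @varname(x))  # flattened value of `x` as a vector
DynamicPPL.getindex_internal(vi, :)            # all values flattened into one vector
DynamicPPL.length_internal(vi)                 # length of the flat representation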
These functions have _internal in their name because, internally, VarInfo always stores values in vectorised form.
Moreover, a link transformation can be applied to a VarInfo with link!! (and reversed with invlink!!). This applies a reversible transformation to the internal storage format of a variable such that the range of the random variable covers all of Euclidean space. getindex_internal and setindex_internal! give direct access to the vectorised value after such a transformation, which is what samplers often need in order to sample in unconstrained space. One can also manually set a transformation by giving setindex_internal! a fourth, optional argument: a function that maps the internally stored value to the actual value of the variable.
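As a rough sketch of what linking does, assuming the link!!(vi, model) method and a hypothetical model with a positively constrained variable:

using DynamicPPL, Distributions

@model function demo_constrained()
    s ~ Exponential(1)  # `s` is constrained to be positive
    return nothing
end

model = demo_constrained()
vi = VarInfo(model)

vi_linked = DynamicPPL.link!!(vi, model)
# The internal representation of `s` now covers all of ℝ
# (for `Exponential`, the link is essentially `log`):
DynamicPPL.getindex_internal(vi_linked, @varname(s))

# `invlink!!` reverses the transformation:
vi_invlinked = DynamicPPL.invlink!!(vi_linked, model)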
Finally, we want the underlying representation used in metadata to have a few performance-related properties:

- Type-stable when possible, but functional when not.
- Efficient storage and iteration when possible, but functional when not.

The "but functional when not" is important, as we want to support arbitrary models, which means that we can't always have these performance properties.

In the following sections, we'll outline how we achieve this in VarInfo.
Type-stability
Ensuring type-stability is somewhat non-trivial to address, since we want this to be the case even when models mix continuous (typically Float64) and discrete (typically Int) variables.
Suppose we have an implementation of metadata which implements the functionality outlined in the previous section. The way we approach this in VarInfo is to use a NamedTuple with a separate metadata for each distinct Symbol used. For example, if we have a model of the form
using DynamicPPL, Distributions, FillArrays
@model function demo()
x ~ product_distribution(Fill(Bernoulli(0.5), 2))
y ~ Normal(0, 1)
return nothing
end
demo (generic function with 2 methods)
then we construct a type-stable representation by using a NamedTuple{(:x, :y), Tuple{Vx, Vy}} where

- Vx is a container with eltype Bool, and
- Vy is a container with eltype Float64.
Since VarName contains the Symbol used in its type, something like getindex(varinfo, @varname(x)) can be resolved to getindex(varinfo.metadata.x, @varname(x)) at compile-time.
For example, with the model above we have
# Type-unstable `VarInfo`
varinfo_untyped = DynamicPPL.untyped_varinfo(
demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()
)
typeof(varinfo_untyped.metadata)
DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}
# Type-stable `VarInfo`
varinfo_typed = DynamicPPL.typed_varinfo(demo())
typeof(varinfo_typed.metadata)
@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, typeof(identity)}, Int64}, Vector{Product{Discrete, Bernoulli{Float64}, FillArrays.Fill{Bernoulli{Float64}, 1, Tuple{Base.OneTo{Int64}}}}}, Vector{VarName{:x, typeof(identity)}}, BitVector, Vector{Set{DynamicPPL.Selector}}}, y::DynamicPPL.Metadata{Dict{VarName{:y, typeof(identity)}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:y, typeof(identity)}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}
They both work as expected but one results in concrete typing and the other does not:
varinfo_untyped[@varname(x)], varinfo_untyped[@varname(y)]
(Real[true, true], 0.06631845415545398)
varinfo_typed[@varname(x)], varinfo_typed[@varname(y)]
(Bool[1, 1], 0.4541584437699302)
Notice that the untyped VarInfo uses Vector{Real} to store the boolean entries while the typed one uses Vector{Bool}. This is because the untyped version needs the underlying container to be able to handle both the Bool for x and the Float64 for y, while the typed version can use a Vector{Bool} for x and a Vector{Float64} for y thanks to its usage of NamedTuple.
Of course, this NamedTuple approach is not necessarily going to help us in scenarios where the Symbol does not correspond to a unique type, e.g.
x[1] ~ Bernoulli(0.5)
x[2] ~ Normal(0, 1)
In this case we'll end up with a NamedTuple{(:x,), Tuple{Vx}} where Vx is a container with eltype Union{Bool, Float64} or something worse. This is not type-stable but will still be functional.
In practice, we rarely observe such mixing of types; therefore, in DynamicPPL, and more widely in Turing.jl, we use the NamedTuple approach for type-stability with great success.
Another downside of such a NamedTuple approach is that a model with many tilde-statements, e.g. a ~ Normal(), b ~ Normal(), ..., z ~ Normal(), will result in a NamedTuple with 26 entries, potentially leading to long compilation times. For these scenarios it can be useful to fall back to "untyped" representations.
Hence we obtain a "type-stable when possible" representation by wrapping the metadata in a NamedTuple and partially resolving the getindex, setindex!, etc. methods at compile-time. When type-stability is not desired, we can simply use a single metadata for all VarNames instead of a NamedTuple wrapping a collection of metadatas.
Efficient storage and iteration
Efficient storage and iteration are achieved through the implementation of the metadata. In particular, we do so with DynamicPPL.VarNamedVector:
DynamicPPL.VarNamedVector — Type

VarNamedVector

A container that stores values in a vectorised form, but indexable by variable names.
A VarNamedVector can be thought of as an ordered mapping from VarNames to pairs of (internal_value, transform). Here internal_value is a vectorised value for the variable and transform is a function such that transform(internal_value) is the "original" value of the variable, the one that the user sees. For instance, if the variable has a matrix value, internal_value could be a flattened Vector of its elements, and transform would be a reshape call.
transform may implement nothing more than vectorisation, but it may also do more. Most importantly, it may implement linking, where the internal storage of a random variable is in a form where all values in Euclidean space are valid. This is useful for sampling, because the sampler can then make changes to internal_value without worrying about constraints on the space of the random variable.
The way to access this storage format directly is through the functions getindex_internal and setindex_internal!. The transform argument for setindex_internal! is optional; by default it is either the identity, or the existing transform if a value already exists for this VarName.
VarNamedVector also provides a Dict-like interface that hides away the internal vectorisation. This can be accessed with getindex and setindex!. setindex! only takes the value; the transform is automatically set to be a simple vectorisation. The only notable deviation from the behavior of a Dict is that setindex! will throw an error if one tries to set a new value for a variable that lives in a different "space" than the old one (e.g. is of a different type or size). This is because setindex! does not change the transform of a variable (e.g. it preserves linking), and thus the new value must be compatible with the old transform.
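For illustration, a small sketch of this size-compatibility rule (the values are arbitrary):

using DynamicPPL: VarNamedVector, @varname

vnv = VarNamedVector(@varname(x) => [1.0, 2.0])
vnv[@varname(x)] = [3.0, 4.0]          # fine: same size and type
# vnv[@varname(x)] = [3.0, 4.0, 5.0]   # would throw: the existing transform
#                                      # expects a value of length 2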
For now, a third value is in fact stored for each VarName: a boolean indicating whether the variable has been transformed to unconstrained Euclidean space or not. This is only in place temporarily, due to the needs of our old Gibbs sampler.
Internally, VarNamedVector stores the values of all variables in a single contiguous vector. This makes some operations more efficient, and means that one can access the entire contents of the internal storage quickly with getindex_internal(vnv, :). The other fields of VarNamedVector are mostly used to keep track of which part of the internal storage belongs to which VarName.
Fields

- varname_to_index: mapping from a VarName to its integer index in varnames, ranges and transforms
- varnames: vector of VarNames for the variables, where varnames[varname_to_index[vn]] == vn
- ranges: vector of index ranges in vals corresponding to varnames; each VarName vn has a single index or a set of contiguous indices, such that the values of vn can be found at vals[ranges[varname_to_index[vn]]]
- vals: vector of values of all variables; the value(s) of vn is/are vals[ranges[varname_to_index[vn]]]
- transforms: vector of transformations, so that transforms[varname_to_index[vn]] is a callable that transforms the value of vn back to its original space, undoing any linking and vectorisation
- is_unconstrained: vector of booleans indicating whether a variable has been transformed to unconstrained Euclidean space or not, i.e. whether its domain is all of ℝⁿ. Having is_unconstrained[varname_to_index[vn]] == false does not necessarily mean that a variable is constrained, but rather that it's not guaranteed to not be.
- num_inactive: mapping from a variable index to the number of inactive entries for that variable. Inactive entries are elements in vals that are not part of the value of any variable. They arise when a variable is set to a new value with a different dimension, in-place. Inactive entries always come after the last active entry for the given variable. See the extended help with ??VarNamedVector for more details.
Extended help
The values for different variables are internally all stored in a single vector. For instance,
julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!, update!, getindex_internal
julia> vnv = VarNamedVector();
julia> setindex!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x));
julia> setindex!(vnv, reshape(1:6, (2,3)), @varname(y));
julia> vnv.vals
10-element Vector{Real}:
0.0
0.0
0.0
0.0
1
2
3
4
5
6
The varnames, ranges, and varname_to_index fields keep track of which value belongs to which variable. The transforms field stores the transformations needed to turn the vectorised internal storage back into its original form:
julia> vnv.transforms[vnv.varname_to_index[@varname(y)]] == DynamicPPL.ReshapeTransform((6,), (2,3))
true
If a variable is updated with a new value that is of a smaller dimension than the old value, then rather than resizing vnv.vals, some elements in vnv.vals are marked as inactive.
julia> update!(vnv, [46.0, 48.0], @varname(x))
julia> vnv.vals
10-element Vector{Real}:
46.0
48.0
0.0
0.0
1
2
3
4
5
6
julia> println(vnv.num_inactive);
OrderedDict(1 => 2)
This helps avoid unnecessary memory allocations for values that repeatedly change dimension. The user does not have to worry about the inactive entries as long as they use functions like setindex! and getindex rather than directly accessing vnv.vals.
julia> vnv[@varname(x)]
2-element Vector{Float64}:
46.0
48.0
julia> getindex_internal(vnv, :)
8-element Vector{Real}:
46.0
48.0
1
2
3
4
5
6
In a DynamicPPL.VarNamedVector{<:VarName,T}, we achieve the desiderata by storing the values for the different VarNames contiguously in a Vector{T} and keeping track of which ranges correspond to which VarNames.
This does require a bit of book-keeping, in particular when it comes to insertions and deletions. Internally, this is handled by assigning each VarName a unique Int index in the varname_to_index field, which is then used to index into the following fields (a small illustration follows the list):
- varnames::Vector{<:VarName}: the VarNames in the order they appear in the Vector{T}.
- ranges::Vector{UnitRange{Int}}: the ranges of indices in the Vector{T} that correspond to each VarName.
- transforms::Vector: the transforms associated with each VarName.
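As a small illustration of this book-keeping (the values in the comments assume a freshly constructed container):

using DynamicPPL: VarNamedVector, @varname

vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [4.0])

vnv.varname_to_index[@varname(y)]  # 2
vnv.varnames                       # [x, y]
vnv.ranges                         # [1:3, 4:4]
vnv.vals                           # [1.0, 2.0, 3.0, 4.0]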
Mutating functions, e.g. setindex_internal!(vnv::VarNamedVector, val, vn::VarName), are then treated according to the following rules (sketched in code after the list):

- If vn is not already present: add it to the end of vnv.varnames, add the val to the underlying vnv.vals, etc.
- If vn is already present in vnv:
  - If val has the same length as the existing value for vn: replace the existing value.
  - If val has a smaller length than the existing value for vn: replace the existing value and mark the remaining indices as "inactive" by increasing the entry in the vnv.num_inactive field.
  - If val has a larger length than the existing value for vn: expand the underlying vnv.vals to accommodate the new value, update all VarNames occurring after vn, and update the vnv.ranges to point to the new range for vn.
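A sketch of the three update cases, using the update! helper from the docstring examples above:

using DynamicPPL: VarNamedVector, update!, @varname

vnv = VarNamedVector(@varname(x) => [1.0, 2.0])
update!(vnv, [3.0, 4.0], @varname(x))       # same length: replaced in place
update!(vnv, [5.0], @varname(x))            # smaller: trailing entry marked inactive
update!(vnv, [6.0, 7.0, 8.0], @varname(x))  # larger: underlying storage is grown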
This means that VarNamedVector is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for the use-cases that we are generally interested in.
For example, we want to optimize code-paths which effectively boil down to the inner loop in the following example:
# Construct a `VarInfo` with types inferred from `model`.
varinfo = VarInfo(model)
# Repeatedly sample from `model`.
for _ in 1:num_samples
rand!(rng, model, varinfo)
# Do something with `varinfo`.
# ...
end
There are typically a few scenarios where we encounter changing representation sizes of a random variable x:

1. We're working with a transformed version of x which is represented in a lower-dimensional space, e.g. transforming x ~ LKJ(2, 1) to an unconstrained y = f(x) takes us from a 2-by-2 Matrix{Float64} to a length-1 Vector{Float64}.
2. x has a random size, e.g. in a mixture model with a prior on the number of components. Here the size of x can vary wildly between realizations of the Model.
In scenario (1), we're usually shrinking the representation of x, and so we end up not making any allocations for the underlying Vector{T} but instead just marking the redundant part as "inactive".
In scenario (2), we end up increasing the allocated memory for the randomly sized x, eventually leading to a vector that is large enough to hold realizations without needing to reallocate. But this can still lead to unnecessary memory usage, which might be undesirable. Hence one has to make a decision regarding the trade-off between memory usage and performance for the use-case at hand.
To help with this, we have the following functions:
DynamicPPL.has_inactive — Function

has_inactive(vnv::VarNamedVector)

Returns true if vnv has any inactive entries.

DynamicPPL.num_inactive — Function

num_inactive(vnv::VarNamedVector)

Return the number of inactive entries in vnv.

See also: has_inactive, num_allocated

num_inactive(vnv::VarNamedVector, vn::VarName)

Returns the number of inactive entries for vn in vnv.
DynamicPPL.num_allocated — Function

num_allocated(vnv::VarNamedVector)
num_allocated(vnv::VarNamedVector, vn::VarName)
num_allocated(vnv::VarNamedVector, idx::Int)

Return the number of allocated entries in vnv, both active and inactive.

If either a VarName or an Int index is specified, only count entries allocated for that variable.

Allocated entries take up memory in vnv.vals, but, if inactive, may not currently hold any meaningful data. One can remove them with contiguify!, but doing so may cause more memory allocations in the future if variables change dimension.
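For instance, a quick check of allocated versus inactive entries after a shrinking update (a sketch using the functions documented above):

using DynamicPPL: VarNamedVector, update!, num_allocated, num_inactive, @varname

vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0])
update!(vnv, [4.0], @varname(x))

num_allocated(vnv, @varname(x))  # 3: the storage is kept around
num_inactive(vnv, @varname(x))   # 2: the trailing entries are inactive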
DynamicPPL.is_contiguous — Function

is_contiguous(vnv::VarNamedVector)

Returns true if the underlying data of vnv is stored in a contiguous array.

This is equivalent to negating has_inactive(vnv).
DynamicPPL.contiguify! — Function

contiguify!(vnv::VarNamedVector)

Re-contiguify the underlying vector and shrink if possible.
Examples
julia> using DynamicPPL: VarNamedVector, @varname, contiguify!, update!, has_inactive
julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [3.0]);
julia> update!(vnv, [23.0, 24.0], @varname(x));
julia> has_inactive(vnv)
true
julia> length(vnv.vals)
4
julia> contiguify!(vnv);
julia> has_inactive(vnv)
false
julia> length(vnv.vals)
3
julia> vnv[@varname(x)] # All the values are still there.
2-element Vector{Float64}:
23.0
24.0
For example, one might encounter the following scenario:
vnv = DynamicPPL.VarNamedVector(@varname(x) => [true])
println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))")
for i in 1:5
x = fill(true, rand(1:100))
DynamicPPL.update!(vnv, x, @varname(x))
println(
"After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))",
)
end
Before insertion: number of allocated entries 1
After insertion #1 of length 37: number of allocated entries 37
After insertion #2 of length 16: number of allocated entries 37
After insertion #3 of length 86: number of allocated entries 86
After insertion #4 of length 21: number of allocated entries 86
After insertion #5 of length 8: number of allocated entries 86
We can then insert a call to DynamicPPL.contiguify! after every insertion, whenever the allocation grows too large, to reduce overall memory usage:
vnv = DynamicPPL.VarNamedVector(@varname(x) => [true])
println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))")
for i in 1:5
x = fill(true, rand(1:100))
DynamicPPL.update!(vnv, x, @varname(x))
if DynamicPPL.num_allocated(vnv) > 10
DynamicPPL.contiguify!(vnv)
end
println(
"After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))",
)
end
Before insertion: number of allocated entries 1
After insertion #1 of length 48: number of allocated entries 48
After insertion #2 of length 86: number of allocated entries 86
After insertion #3 of length 22: number of allocated entries 22
After insertion #4 of length 40: number of allocated entries 40
After insertion #5 of length 91: number of allocated entries 91
This does incur a runtime cost, as it requires re-allocation of the ranges in addition to a resize! of the underlying Vector{T}. However, it also ensures that the underlying Vector{T} is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the VarNamedVector without insertions, etc., it can be worth it to do a sweep to ensure that the underlying Vector{T} is contiguous.
Higher-dimensional arrays, e.g. Matrix, are handled by simply vectorizing them before storing them in the Vector{T}, and composing the VarName's transformation with a DynamicPPL.ReshapeTransform.
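A sketch of this behavior, mirroring the setindex! example from the docstring above:

using DynamicPPL: VarNamedVector, @varname

vnv = VarNamedVector()
setindex!(vnv, [1.0 2.0; 3.0 4.0], @varname(M))

vnv.vals          # the four elements stored as a flat, column-major vector
vnv[@varname(M)]  # reshaped back to the original 2×2 Matrix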
Continuing from the example in the previous section, we can use a VarInfo with a VarNamedVector as the metadata field:
# Type-unstable
varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped)
varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)]
(Real[true, true], 0.06631845415545398)
# Type-stable
varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed)
varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)]
(Bool[1, 1], 0.4541584437699302)
If we now try to delete! @varname(x):
haskey(varinfo_untyped_vnv, @varname(x))
true
DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata)
false
# `delete!`
DynamicPPL.delete!(varinfo_untyped_vnv.metadata, @varname(x))
DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata)
false
haskey(varinfo_untyped_vnv, @varname(x))
false
Or insert a differently-sized value for @varname(x):
DynamicPPL.insert!(varinfo_untyped_vnv.metadata, fill(true, 1), @varname(x))
varinfo_untyped_vnv[@varname(x)]
1-element Vector{Real}:
true
DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x))
1
DynamicPPL.update!(varinfo_untyped_vnv.metadata, fill(true, 4), @varname(x))
varinfo_untyped_vnv[@varname(x)]
4-element Vector{Real}:
true
true
true
true
DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x))
4
Performance summary
In the end, we have the following "rough" performance characteristics for VarNamedVector:
| Method | Is blazingly fast? |
|---|---|
| getindex | ${\color{green} \checkmark}$ |
| setindex! on a new VarName | ${\color{green} \checkmark}$ |
| delete! | ${\color{red} \times}$ |
| update! on an existing VarName | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size |
| values_as(::VarNamedVector, Vector{T}) | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise |
Other methods
DynamicPPL.replace_raw_storage — Method

replace_raw_storage(vnv::VarNamedVector, vals::AbstractVector)

Replace the values in vnv with vals, as they are stored internally.

This is useful when we want to update the entire underlying vector of values in one go, or if we want to change how the values are stored, e.g. alter the eltype.

This replaces the raw underlying values, and so care should be taken when using this function. For example, if vnv has any inactive entries, then the provided vals should also contain the inactive entries to avoid unexpected behavior.
Examples
julia> using DynamicPPL: VarNamedVector, replace_raw_storage, @varname
julia> vnv = VarNamedVector(@varname(x) => [1.0]);
julia> replace_raw_storage(vnv, [2.0])[@varname(x)] == [2.0]
true
This is also useful when we want to differentiate wrt. the values using automatic differentiation, e.g. ForwardDiff.jl.
julia> using ForwardDiff: ForwardDiff
julia> f(x) = sum(abs2, replace_raw_storage(vnv, x)[@varname(x)])
f (generic function with 1 method)
julia> ForwardDiff.gradient(f, [1.0])
1-element Vector{Float64}:
2.0
DynamicPPL.values_as — Method

values_as(vnv::VarNamedVector[, T])

Return the values/realizations in vnv as type T, if implemented.

If no type T is provided, return values as stored in vnv.
Examples
julia> using DynamicPPL: VarNamedVector, values_as, @varname

julia> using OrderedCollections: OrderedDict
julia> vnv = VarNamedVector(@varname(x) => 1, @varname(y) => [2.0]);
julia> values_as(vnv) == [1.0, 2.0]
true
julia> values_as(vnv, Vector{Float32}) == Vector{Float32}([1.0, 2.0])
true
julia> values_as(vnv, OrderedDict) == OrderedDict(@varname(x) => 1.0, @varname(y) => [2.0])
true
julia> values_as(vnv, NamedTuple) == (x = 1.0, y = [2.0])
true