# Turing Compiler Design

This section describes the current design of Turing's model "compiler", which enables Turing to perform various types of Bayesian inference without changing the model definition. The "compiler" is essentially just a macro that rewrites the user's model definition into a function that generates a `Model` struct that Julia's dispatch can operate on and that Julia's compiler can successfully do type inference on for efficient machine code generation.

# Overview

The following terminology will be used in this section:

- `D`: observed data variables conditioned upon in the posterior,
- `P`: parameter variables distributed according to the prior distributions, these will also be referred to as random variables,
- `Model`: a fully defined probabilistic model with input data

`Turing`'s `@model` macro rewrites the user-provided function definition such that it can be used to instantiate a `Model` by passing in the observed data `D`.

The following are the main jobs of the `@model` macro:

- Parse `~` and `.~` lines, e.g. `y .~ Normal.(c*x, 1.0)`
- Figure out if a variable belongs to the data `D` and/or to the parameters `P`
- Enable the handling of missing data variables in `D` when defining a `Model` and treating them as parameter variables in `P` instead
- Enable the tracking of random variables using the data structures `VarName` and `VarInfo`
- Change `~`/`.~` lines with a variable in `P` on the LHS to a call to `tilde_assume` or `dot_tilde_assume`
- Change `~`/`.~` lines with a variable in `D` on the LHS to a call to `tilde_observe` or `dot_tilde_observe`
- Enable type stable automatic differentiation of the model using type parameters

## The model

A `model::Model` is a callable struct that one can sample from by calling

```
(model::Model)([rng, varinfo, sampler, context])
```

where `rng` is a random number generator (default: `Random.default_rng()`), `varinfo` is a data structure that stores information about the random variables (default: `DynamicPPL.VarInfo()`), `sampler` is a sampling algorithm (default: `DynamicPPL.SampleFromPrior()`), and `context` is a sampling context that can, e.g., modify how the log probability is accumulated (default: `DynamicPPL.DefaultContext()`).

Sampling resets the log joint probability of `varinfo` and increases the evaluation counter of `sampler`. If `context` is a `LikelihoodContext`, only the log likelihood will be accumulated. With the `DefaultContext` the log joint probability of `P` and `D` is accumulated.

The `Model` struct contains the three internal fields `f`, `args` and `defaults`. When `model::Model` is called, the internal function `model.f` is called as `model.f(rng, varinfo, sampler, context, model.args...)` (for multithreaded sampling, a threadsafe wrapper is passed to `model.f` instead of `varinfo`). The positional and keyword arguments that were passed to the user-defined model function when the model was created are saved as a `NamedTuple` in `model.args`. The default values of the positional and keyword arguments of the user-defined model function, if any, are saved as a `NamedTuple` in `model.defaults`. They are used for constructing model instances with different arguments by the `logprob` and `prob` string macros.
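To make the calling convention concrete, here is a minimal sketch of a callable struct with the same three fields. `ToyModel` and `toy_f` are hypothetical names invented for this illustration and are not part of DynamicPPL; the real `Model` type carries more machinery.

```julia
# Hypothetical stand-in for `Model`; only the field names mirror the text above.
struct ToyModel{F,A<:NamedTuple,D<:NamedTuple}
    f::F          # internal function (generated by the macro in the real system)
    args::A       # arguments the model was created with
    defaults::D   # default values of the model arguments
end

# Calling the struct forwards to the internal function, splatting the stored
# arguments as positional arguments (a NamedTuple splats its values).
(m::ToyModel)(rng, varinfo, sampler, context) =
    m.f(rng, varinfo, sampler, context, m.args...)

# A toy internal function that only reports what it received.
toy_f(rng, varinfo, sampler, context, x, y) = (x = x, y = y, context = context)

m = ToyModel(toy_f, (x = missing, y = 1.0), (x = missing, y = 1.0))
result = m(nothing, nothing, nothing, :default)
```

Because the arguments are stored in a `NamedTuple`, their names and concrete types are part of the struct's type, which is what lets dispatch and type inference see them.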

# Example

Let's take the following model as an example:

```
@model function gauss(
    x=missing, y=1.0, ::Type{TV}=Vector{Float64}
) where {TV<:AbstractVector}
    if x === missing
        x = TV(undef, 3)
    end
    p = TV(undef, 2)
    p[1] ~ InverseGamma(2, 3)
    p[2] ~ Normal(0, 1.0)
    @. x[1:2] ~ Normal(p[2], sqrt(p[1]))
    x[3] ~ Normal()
    return y ~ Normal(p[2], sqrt(p[1]))
end
```

The above call of the `@model` macro defines the function `gauss` with positional arguments `x`, `y`, and `::Type{TV}`, rewritten in such a way that every call of it returns a `model::Model`. Note that only the function body is modified by the `@model` macro, and the function signature is left untouched. It is also possible to implement models with keyword arguments such as

```
@model function gauss(
    ::Type{TV}=Vector{Float64}; x=missing, y=1.0
) where {TV<:AbstractVector}
    return ...
end
```

This would allow us to generate a model by calling `gauss(; x = rand(3))`.

If an argument has a default value `missing`, it is treated as a random variable. For variables which require an initialization because we need to loop or broadcast over their elements, such as `x` above, the following needs to be done:

```
if x === missing
    x = ...
end
```

Note that since `gauss` behaves like a regular function it is possible to define additional dispatches in a second step as well. For instance, we could achieve the same behaviour by

```
@model function gauss(x, y=1.0, ::Type{TV}=Vector{Float64}) where {TV<:AbstractVector}
    p = TV(undef, 2)
    return ...
end
function gauss(::Missing, y=1.0, ::Type{TV}=Vector{Float64}) where {TV<:AbstractVector}
    return gauss(TV(undef, 3), y, TV)
end
```

If `x` is sampled as a whole from a distribution and not indexed, e.g., `x ~ Normal(...)` or `x ~ MvNormal(...)`, there is no need to initialize it in an `if`-block.

## Step 1: Break up the model definition

First, the `@model` macro breaks up the user-provided function definition using `DynamicPPL.build_model_info`. This function returns a dictionary consisting of:

- `allargs_exprs`: The expressions of the positional and keyword arguments, without default values.
- `allargs_syms`: The names of the positional and keyword arguments, e.g., `[:x, :y, :TV]` above.
- `allargs_namedtuple`: An expression that constructs a `NamedTuple` of the positional and keyword arguments, e.g., `:((x = x, y = y, TV = TV))` above.
- `defaults_namedtuple`: An expression that constructs a `NamedTuple` of the default positional and keyword arguments, if any, e.g., `:((x = missing, y = 1, TV = Vector{Float64}))` above.
- `modeldef`: A dictionary with the name, arguments, and function body of the model definition, as returned by `MacroTools.splitdef`.

## Step 2: Generate the body of the internal model function

In a second step, `DynamicPPL.generate_mainbody` generates the main part of the transformed function body. It uses the user-provided function body and the provided function arguments, without default values, to figure out whether a variable denotes an observation or a random variable. In this step, the function `DynamicPPL.generate_tilde` replaces the `L ~ R` lines in the model and the function `DynamicPPL.generate_dot_tilde` replaces the `@. L ~ R` and `L .~ R` lines in the model.

In the above example, `p[1] ~ InverseGamma(2, 3)` is replaced with something similar to

```
#= REPL[25]:6 =#
begin
    var"##tmpright#323" = InverseGamma(2, 3)
    var"##tmpright#323" isa Union{Distribution,AbstractVector{<:Distribution}} || throw(
        ArgumentError(
            "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions.",
        ),
    )
    var"##vn#325" = (DynamicPPL.VarName)(:p, ((1,),))
    var"##inds#326" = ((1,),)
    p[1] = (DynamicPPL.tilde_assume)(
        _rng,
        _context,
        _sampler,
        var"##tmpright#323",
        var"##vn#325",
        var"##inds#326",
        _varinfo,
    )
end
```

Here the first line is a so-called line number node that enables more helpful error messages by providing users with the exact location of the error in their model definition. Then the right hand side (RHS) of the `~` is assigned to a variable (with an automatically generated name). We check that the RHS is a distribution or an array of distributions, otherwise an error is thrown. Next we extract a compact representation of the variable with its name and index (or indices). Finally, the `~` expression is replaced with a call to `DynamicPPL.tilde_assume` since the compiler figured out that `p[1]` is a random variable using the following heuristic:

- If the symbol on the LHS of `~`, `:p` in this case, is not among the arguments to the model, `(:x, :y, :TV)` in this case, it is a random variable.
- If the symbol on the LHS of `~`, `:p` in this case, is among the arguments to the model but has a value of `missing`, it is a random variable.
- If the value of the LHS of `~`, `p[1]` in this case, is `missing`, then it is a random variable.
- Otherwise, it is treated as an observation.
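The heuristic above can be sketched as a small plain-Julia function. `toy_isassumption` is a hypothetical helper written for this illustration; it is not how DynamicPPL implements the check, and it takes the LHS value as a zero-argument function so that the value is only looked up when the first two rules do not already decide the outcome.

```julia
# Hypothetical helper mirroring the heuristic; not DynamicPPL's implementation.
# `sym` is the symbol on the LHS, `args` the NamedTuple of model arguments,
# `lhs_value` a zero-argument function returning the current value of the LHS.
function toy_isassumption(sym::Symbol, args::NamedTuple, lhs_value)
    haskey(args, sym) || return true        # rule 1: not a model argument
    args[sym] === missing && return true    # rule 2: argument is `missing`
    return lhs_value() === missing          # rule 3: the LHS value is `missing`
end

args = (x = [1.0, missing], y = 1.0)

p_is_random  = toy_isassumption(:p, args, () -> 0.0)        # `p` is not an argument
x2_is_random = toy_isassumption(:x, args, () -> args.x[2])  # `x[2]` is missing
y_is_random  = toy_isassumption(:y, args, () -> args.y)     # `y` has an observed value
```

With these inputs, `p` and `x[2]` are classified as random variables while `y` is classified as an observation.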

The `DynamicPPL.tilde_assume` function takes care of sampling the random variable, if needed, and updating its value and the accumulated log joint probability in the `_varinfo` object. If `L ~ R` is an observation, `DynamicPPL.tilde_observe` is called with the same arguments except the random number generator `_rng` (since observations are never sampled).

A similar transformation is performed for expressions of the form `@. L ~ R` and `L .~ R`. For instance, `@. x[1:2] ~ Normal(p[2], sqrt(p[1]))` is replaced with

```
#= REPL[25]:8 =#
begin
    var"##tmpright#331" = Normal.(p[2], sqrt.(p[1]))
    var"##tmpright#331" isa Union{Distribution,AbstractVector{<:Distribution}} || throw(
        ArgumentError(
            "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions.",
        ),
    )
    var"##vn#333" = (DynamicPPL.VarName)(:x, ((1:2,),))
    var"##inds#334" = ((1:2,),)
    var"##isassumption#335" = begin
        let var"##vn#336" = (DynamicPPL.VarName)(:x, ((1:2,),))
            if !((DynamicPPL.inargnames)(var"##vn#336", _model)) ||
                (DynamicPPL.inmissings)(var"##vn#336", _model)
                true
            else
                x[1:2] === missing
            end
        end
    end
    if var"##isassumption#335"
        x[1:2] .= (DynamicPPL.dot_tilde_assume)(
            _rng,
            _context,
            _sampler,
            var"##tmpright#331",
            x[1:2],
            var"##vn#333",
            var"##inds#334",
            _varinfo,
        )
    else
        (DynamicPPL.dot_tilde_observe)(
            _context,
            _sampler,
            var"##tmpright#331",
            x[1:2],
            var"##vn#333",
            var"##inds#334",
            _varinfo,
        )
    end
end
```

The main difference in the expanded code between `L ~ R` and `@. L ~ R` is that the former doesn't assume `L` to be defined (it can be a new Julia variable in the scope), while the latter assumes `L` already exists. Moreover, `DynamicPPL.dot_tilde_assume` and `DynamicPPL.dot_tilde_observe` are called instead of `DynamicPPL.tilde_assume` and `DynamicPPL.tilde_observe`.
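The general idea of rewriting `L ~ R` expressions can be illustrated with a few lines of plain Julia. This is only a toy sketch: `toy_rewrite` and `toy_tilde` are hypothetical names, the rewriter handles just the plain `~` form (not `.~` or `@.`), and it emits none of the checks or assume/observe branching shown in the expansions above.

```julia
# Hypothetical toy rewriter; DynamicPPL's generated code is far more elaborate.
toy_rewrite(x) = x  # symbols, literals and line number nodes pass through unchanged

function toy_rewrite(ex::Expr)
    # `L ~ R` parses as a call to `~` with two arguments.
    if ex.head === :call && length(ex.args) == 3 && ex.args[1] === :~
        lhs, rhs = ex.args[2], ex.args[3]
        # Keep the quoted LHS so the callee can identify the variable by name and index.
        return :($lhs = toy_tilde($(QuoteNode(lhs)), $rhs))
    end
    # Otherwise recurse into the sub-expressions.
    return Expr(ex.head, map(toy_rewrite, ex.args)...)
end

out = toy_rewrite(:(p[1] ~ InverseGamma(2, 3)))
```

Here `out` is an assignment expression whose right-hand side is a call to the (undefined, purely illustrative) `toy_tilde`, receiving the quoted LHS and the original RHS expression.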

## Step 3: Replace the user-provided function body

Finally, we replace the user-provided function body using `DynamicPPL.build_output`. This function uses `MacroTools.combinedef` to reassemble the user-provided function with a new function body. In the modified function body an anonymous function is created whose function body was generated in step 2 above and whose arguments are

- a random number generator `_rng`,
- a model `_model`,
- a data structure `_varinfo`,
- a sampler `_sampler`,
- a sampling context `_context`,
- and all positional and keyword arguments of the user-provided model function as positional arguments without any default values.

Finally, in the new function body a `model::Model` with this anonymous function as internal function is returned.

# `VarName`

In order to track random variables in the sampling process, `Turing` uses the `VarName` struct which acts as a random variable identifier generated at runtime. The `VarName` of a random variable is generated from the expression on the LHS of a `~` statement when the symbol on the LHS is in the set `P` of unobserved random variables. Every `VarName` instance has a type parameter `sym` which is the symbol of the Julia variable in the model that the random variable belongs to. For example, `x[1] ~ Normal()` will generate an instance of `VarName{:x}` assuming `x` is an unobserved random variable. Every `VarName` also has a field `indexing`, which stores the indices required to access the random variable from the Julia variable indicated by `sym` as a tuple of tuples. Each element of the tuple thereby contains the indices of one indexing operation (`VarName` also supports hierarchical arrays and range indexing). Some examples:

- `x ~ Normal()` will generate a `VarName(:x, ())`.
- `x[1] ~ Normal()` will generate a `VarName(:x, ((1,),))`.
- `x[:,1] ~ MvNormal(zeros(2), I)` will generate a `VarName(:x, ((Colon(), 1),))`.
- `x[:,1][1+1] ~ Normal()` will generate a `VarName(:x, ((Colon(), 1), (2,)))`.

The easiest way to manually construct a `VarName` is to use the `@varname` macro on an indexing expression, which will take the `sym` value from the actual variable name, and put the index values appropriately into the constructor.
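As an illustration of the "symbol in the type, indices in a field" layout described above, here is a minimal sketch. `ToyVarName`, `toy_getsym` and `toy_lookup` are hypothetical names for this example only; the actual `VarName` type and its helpers may differ.

```julia
# Hypothetical re-creation of the idea; not the real `VarName` definition.
struct ToyVarName{sym,T<:Tuple}
    indexing::T   # tuple of tuples, one inner tuple per indexing operation
end
ToyVarName(sym::Symbol, indexing::Tuple=()) = ToyVarName{sym,typeof(indexing)}(indexing)

# The symbol lives in the type parameter, so it can be recovered by dispatch.
toy_getsym(::ToyVarName{sym}) where {sym} = sym

# Apply each indexing operation in turn to reach the addressed element.
toy_lookup(vn::ToyVarName, value) = foldl((v, idx) -> v[idx...], vn.indexing; init=value)

vn = ToyVarName(:x, ((Colon(), 1), (2,)))   # corresponds to `x[:,1][2]`
x = [10 20; 30 40]
element = toy_lookup(vn, x)                 # same as x[:, 1][2]
```

Storing the symbol as a type parameter means two variables with different symbols have different types, which is what later allows metadata to be grouped per symbol.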

# `VarInfo`

## Overview

`VarInfo` is the data structure in `Turing` that facilitates tracking random variables and certain metadata about them that are required for sampling. For instance, the distribution of every random variable is stored in `VarInfo` because we need to know the support of every random variable when sampling using HMC for example. Random variables whose distributions have a constrained support are transformed using a bijector from Bijectors.jl so that the sampling happens in the unconstrained space. Different samplers require different metadata about the random variables.

The definition of `VarInfo` in `Turing` is:

```
struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo
    metadata::Tmeta
    logp::Base.RefValue{Tlogp}
    num_produce::Base.RefValue{Int}
end
```

Based on the type of `metadata`, the `VarInfo` is either aliased `UntypedVarInfo` or `TypedVarInfo`. `metadata` can be either a subtype of the union type `Metadata` or a `NamedTuple` of multiple such subtypes. Let `vi` be an instance of `VarInfo`. If `vi isa VarInfo{<:Metadata}`, then it is called an `UntypedVarInfo`. If `vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` would be a `NamedTuple` mapping each symbol in `P` to an instance of `Metadata`. `vi` would then be called a `TypedVarInfo`. The other fields of `VarInfo` include `logp`, which is used to accumulate the log probability or log probability density of the variables in `P` and `D`, and `num_produce`, which keeps track of how many observations have been made in the model so far. The latter is incremented when running a `~` statement whose LHS symbol is in `D`.

## `Metadata`

The `Metadata` struct stores some metadata about the random variables sampled. This helps query certain information about a variable, such as its distribution, which samplers sample this variable, its value, and whether this value is transformed to real space or not. Let `md` be an instance of `Metadata`:

- `md.vns` is the vector of all `VarName` instances. Let `vn` be an arbitrary element of `md.vns`
- `md.idcs` is the dictionary that maps each `VarName` instance to its index in `md.vns`, `md.ranges`, `md.dists`, `md.orders` and `md.flags`.
- `md.vns[md.idcs[vn]] == vn`.
- `md.dists[md.idcs[vn]]` is the distribution of `vn`.
- `md.gids[md.idcs[vn]]` is the set of algorithms used to sample `vn`. This is used in the Gibbs sampling process.
- `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled.
- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`.
- `md.vals[md.ranges[md.idcs[vn]]]` is the linearized vector of values corresponding to `vn`.
- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the value of `flag` corresponding to `vn`.
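The index bookkeeping described in this list can be sketched with plain containers. The `NamedTuple` `md` below is a hypothetical stand-in that uses symbols in place of `VarName` instances and includes only four of the fields; `toy_getval` is likewise an invented helper.

```julia
# Hypothetical flattened storage; symbols stand in for `VarName` instances.
md = (
    vns    = [:a, :b],                  # variable identifiers
    idcs   = Dict(:a => 1, :b => 2),    # identifier -> position in `vns`, `ranges`, ...
    ranges = [1:1, 2:4],                # where each variable lives in `vals`
    vals   = [0.5, 1.0, 2.0, 3.0],      # linearized values of all variables
)

# Two-step lookup: identifier -> position -> range -> values.
toy_getval(md, vn) = md.vals[md.ranges[md.idcs[vn]]]

a_vals = toy_getval(md, :a)   # the single value stored for `:a`
b_vals = toy_getval(md, :b)   # the three values stored for `:b`
```

Keeping all values in one flat vector and addressing them through ranges lets scalar and array-valued variables share the same storage.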

Note that in order to make `md::Metadata` type stable, all the `md.vns` must have the same symbol and distribution type. However, one can have a single Julia variable, e.g. `x`, that is a matrix or a hierarchical array sampled in partitions, e.g. `x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I)`. The symbol `x` can still be managed by a single `md::Metadata` without hurting the type stability since all the distributions on the RHS of `~` are of the same type.

However, `Turing` models in general cannot be held to this restriction, so we must use a type unstable `Metadata` if we want to use one `Metadata` instance for the whole model. This is what `UntypedVarInfo` does. A type unstable `Metadata` will still work but will have inferior performance.

To strike a balance between flexibility and performance when constructing the `spl::Sampler` instance, the model is first run by sampling the parameters in `P` from their priors using an `UntypedVarInfo`, i.e. a type unstable `Metadata` is used for all the variables. Then once all the symbols and distribution types have been identified, a `vi::TypedVarInfo` is constructed where `vi.metadata` is a `NamedTuple` mapping each symbol in `P` to a specialized instance of `Metadata`. So as long as each symbol in `P` is sampled from only one type of distribution, `vi::TypedVarInfo` will have fully concretely typed fields, which brings out the peak performance of Julia.
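The difference between the two layouts can be seen with ordinary Julia containers. This is a sketch of the typing argument only, under the assumption that a single heterogeneous container stands in for the untyped case and a `NamedTuple` keyed by symbol stands in for the typed case; `untyped` and `typed` are hypothetical names and neither is a real `VarInfo`.

```julia
# Untyped layout: one container for every variable, so the element type is abstract.
untyped = Any[0.5, [1, 2]]

# Typed layout: one entry per symbol, each entry with its own concrete type.
typed = (s = [0.5], m = [1, 2])

untyped_is_concrete = isconcretetype(eltype(untyped))   # element type is `Any`
typed_is_concrete   = isconcretetype(typeof(typed))     # NamedTuple of concrete vectors
```

In the typed layout the compiler knows the exact type of `typed.s` and `typed.m` from the type of `typed` alone, whereas every access into `untyped` has to be resolved at runtime.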