Expression


abstract class Expression<Value> < Delay

Abstract interface for evaluating and differentiating expressions.

  • Value: Result type.
Class diagram: Random (random argument), BoxedValue (constant value), and BoxedForm (expression) each derive from Expression.

Delayed expressions (alternatively: lazy expressions, compute graphs, expression templates) encode mathematical expressions that can be evaluated, differentiated, and moved (using Markov kernels). They are assembled using mathematical operators and functions much like ordinary expressions, but where one or more operands or arguments are Random objects. Where an ordinary expression is evaluated immediately into a result, delayed expressions evaluate to further Expression objects.

Simple delayed expressions are trees of subexpressions with Random or Boxed objects at the leaves. In general, however, a delayed expression can be a directed acyclic graph, as subexpressions may be reused during assembly.

Simple use

Tip

Call value() on an Expression to evaluate it.

The simplest use case of a delayed expression is to assemble it and then evaluate it by calling value(). Evaluations are memoized, so further calls to value() do not require re-evaluation; they simply return the memoized value.

Once value() is called on an Expression, it and all subexpressions that constitute it are considered constant. This particularly affects any Random objects in the expression, the values of which can no longer be altered.
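The behavior described above can be illustrated with a small toy model. This is a sketch in Python, not Birch, and not the library's implementation: the classes ToyRandom and ToyAdd, the set() method, and the evaluations counter are hypothetical names introduced only for illustration.

```python
class ToyRandom:
    """Toy stand-in for a Random leaf (hypothetical, not the Birch class)."""

    def __init__(self, x):
        self.x = x
        self.constant = False

    def set(self, x):
        # Assumption: altering a constant value is rejected with an error.
        if self.constant:
            raise RuntimeError("value is constant and can no longer be altered")
        self.x = x

    def value(self):
        self.constant = True      # value() renders the leaf constant
        return self.x


class ToyAdd:
    """Toy delayed sum of two subexpressions."""

    def __init__(self, left, right):
        self.left = left
        self.right = right
        self.memo = None
        self.evaluations = 0      # counts real evaluations, for illustration

    def value(self):
        if self.memo is None:
            # First call: evaluate, which also renders the operands constant.
            self.memo = self.left.value() + self.right.value()
            self.evaluations += 1
        return self.memo


a = ToyRandom(1.5)
b = ToyRandom(2.0)
expr = ToyAdd(a, b)       # assembled, not yet evaluated
first = expr.value()      # evaluates and memoizes
second = expr.value()     # returns the memoized value
```

In this sketch the second call does no arithmetic, and a later call to a.set() raises an error because a has been rendered constant.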

Advanced use

More elaborate use cases include computing gradients and applying Markov kernels. Call eval() to evaluate the expression in the same way as for value(), but without rendering it constant. Any Random objects in the expression that have not previously been rendered constant by a call to value() are then considered arguments eligible for moving.

After updating the value of arguments, use reval() to re-evaluate the expression with those new values.

At any time, use grad() to compute the gradient of an expression with respect to its arguments. The gradient is accumulated into those arguments (the Random objects).
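The interplay of eval(), grad(), and reval() can be sketched with a toy model. The following is Python, not Birch, and only mimics the documented behavior under simplifying assumptions: the classes ToyRandom and ToyMul and the fields x and d are hypothetical.

```python
class ToyRandom:
    """Toy argument: holds a value and an accumulated gradient."""

    def __init__(self, x):
        self.x = x
        self.d = 0.0              # accumulated gradient

    def eval(self):
        return self.x             # does not render the value constant

    def reval(self):
        return self.x

    def grad(self, d):
        self.d += d               # gradients accumulate into the argument


class ToyMul:
    """Toy delayed product of two subexpressions."""

    def __init__(self, left, right):
        self.left, self.right = left, right
        self.l = self.r = None    # operand values saved by the forward pass

    def eval(self):
        self.l, self.r = self.left.eval(), self.right.eval()
        return self.l * self.r

    def reval(self):
        self.l, self.r = self.left.reval(), self.right.reval()
        return self.l * self.r

    def grad(self, d):
        # Product rule, scaled by the upstream gradient d.
        self.left.grad(d * self.r)
        self.right.grad(d * self.l)


x, y = ToyRandom(3.0), ToyRandom(4.0)
z = ToyMul(x, y)
z0 = z.eval()      # forward pass; x and y remain movable arguments
z.grad(1.0)        # backward pass; gradients land in x.d and y.d
x.x = 5.0          # move an argument, as a Markov kernel might
z1 = z.reval()     # re-evaluate with the new value
```

Here the gradient of x*y with respect to x is the value of y, and vice versa, so x.d and y.d hold 4.0 and 3.0 after the backward pass.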

Use value(), not eval(), unless you are taking responsibility for correctness (e.g. moving arguments in a manner invariant to some target distribution, using a Markov kernel). Otherwise, program behavior may lack self-consistency. Consider, for example:

if x.value() >= 0.0 {
  doThis();
} else {
  doThat();
}

This is correct usage. Using eval() instead of value() here would allow some other part of the code to later change the value of the random variable x to a negative value. The program would then lack self-consistency: it executed doThis() instead of doThat() based on a previous value of x that no longer holds.
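The hazard can be reproduced in a toy model. The sketch below is Python, not Birch; ToyRandom, move(), and branch() are hypothetical names, and rejecting a move on a constant value with an error is an assumption of the sketch rather than documented behavior.

```python
class ToyRandom:
    """Toy random variable with value() and eval() (hypothetical)."""

    def __init__(self, x):
        self.x = x
        self.constant = False

    def value(self):
        self.constant = True      # render constant
        return self.x

    def eval(self):
        return self.x             # leave movable

    def move(self, x):
        # Assumption: a constant value cannot be moved.
        if self.constant:
            raise RuntimeError("x is constant")
        self.x = x


def branch(v):
    return "doThis" if v >= 0.0 else "doThat"


# With eval(): the branch is taken, then x is moved to a negative value.
x = ToyRandom(1.0)
taken = branch(x.eval())
x.move(-1.0)
still_consistent = taken == branch(x.eval())

# With value(): the later move is rejected, so the branch stays valid.
y = ToyRandom(1.0)
taken_y = branch(y.value())
try:
    y.move(-1.0)
    move_rejected = False
except RuntimeError:
    move_rejected = True
```

After the move, the branch taken for x no longer matches its current value, whereas the branch taken for y cannot be invalidated.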

Attention

Correctness is the programmer's responsibility when using the advanced interface.

Member Functions

Name        Description
isRandom    Is this a Random expression?
isConstant  Is this a constant expression?
rows        Number of rows in result.
columns     Number of columns in result.
length      Length of result.
size        Size of result.
value       Get result and render constant.
peek        Get result.
eval        Evaluate.
reval       Re-evaluate.
grad        Evaluate gradient of the expression with respect to its arguments.
label       Label generations.
constant    Make older generations constant.
constant    Render the entire expression constant.

Member Function Details

columns

abstract function columns() -> Integer

Number of columns in result.

constant

abstract function constant(gen:Integer)

Make older generations constant.

  • gen: Generation limit.

Removes subexpressions that are labeled with a generation less than gen.

abstract function constant()

Render the entire expression constant.

eval

abstract function eval() -> Value

Evaluate, in the same way as value() but without rendering the expression constant.

grad

abstract function grad(d:Value)

Evaluate gradient of the expression with respect to its arguments.

  • d: Upstream gradient.

eval() must have been called before calling grad().

The expression is treated as a function, and the arguments defined as those Random objects in the expression that are not constant.

If the expression encodes

x_n = f(x_0) = (f_n \circ \cdots \circ f_1)(x_0),

and this particular object encodes one of those functions x_i = f_i(x_{i-1}), the upstream gradient d is

\frac{\partial (f_n \circ \cdots \circ f_{i+1})} {\partial x_i}\left(x_i\right).

grad() then computes:

\frac{\partial (f_n \circ \cdots \circ f_{i})} {\partial x_{i-1}}\left(x_{i-1}\right),

and passes the result to the next step in the chain, which encodes f_{i-1}. The argument that encodes x_0, which is a Random object, keeps the final result.
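As a worked illustration of the recursion above, take n = 2 with the hypothetical choices f_1(x) = x^2 and f_2(x) = \sin(x), so that

x_1 = x_0^2, \quad x_2 = \sin(x_1).

For the object that encodes f_2, no functions remain downstream, so the upstream gradient is d = 1, and grad() computes

\frac{\partial f_2}{\partial x_1}\left(x_1\right) = \cos(x_1).

This becomes the upstream gradient d for the object that encodes f_1, which computes

\frac{\partial (f_2 \circ f_1)}{\partial x_0}\left(x_0\right) = \cos(x_1) \cdot 2 x_0 = 2 x_0 \cos(x_0^2),

and this final result is accumulated into the Random object that encodes x_0.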

Reverse-mode automatic differentiation is used. The previous call to eval() constitutes the forward pass, and the call to grad() the backward pass.

Because expressions are, in general, directed acyclic graphs, a counting mechanism is used to accumulate upstream gradients into any shared subexpressions before visiting them. This ensures that each subexpression is visited only once, not as many times as it is used. Mathematically, this is equivalent to factoring out the subexpression as a common factor in the application of the chain rule. This is particularly important when expressions include posterior parameters after multiple Bayesian updates applied by automatic conditioning. Such expressions can have many common subexpressions, and the counting mechanism results in automatic differentiation of complexity O(N) in the number of updates, as opposed to O(N^2) otherwise.
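One way such a counting mechanism could work is sketched below in Python, not Birch. It is a toy under stated assumptions, not the library's implementation: the class names and the uses, pending, and visits fields are hypothetical, and memoization of the forward pass is omitted for brevity.

```python
class Node:
    def __init__(self, *args):
        self.args = args
        self.uses = 0        # number of parents that use this node
        self.pending = 0     # upstream gradients still to arrive
        self.d = 0.0         # accumulated upstream gradient
        self.visits = 0      # times the backward pass went through this node
        for a in args:
            a.uses += 1

    def eval(self):
        # Forward pass: compute the value and reset the backward-pass state.
        self.x = self.forward(*[a.eval() for a in self.args])
        self.pending = self.uses
        self.d = 0.0
        return self.x

    def grad(self, d):
        self.d += d
        self.pending -= 1
        if self.pending <= 0:    # all uses have reported (the root has none)
            self.visits += 1
            self.backward(self.d)


class Leaf(Node):
    def __init__(self, x):
        super().__init__()
        self.x0 = x

    def forward(self):
        return self.x0

    def backward(self, d):
        pass                     # the leaf keeps the final result in self.d


class Square(Node):
    def forward(self, a):
        return a * a

    def backward(self, d):
        self.args[0].grad(d * 2.0 * self.args[0].x)


class Add(Node):
    def forward(self, a, b):
        return a + b

    def backward(self, d):
        self.args[0].grad(d)
        self.args[1].grad(d)


x = Leaf(3.0)
s = Square(x)        # shared subexpression
z = Add(s, s)        # uses s twice, so the graph is a DAG, not a tree
result = z.eval()    # forward pass
z.grad(1.0)          # backward pass
```

The shared node s receives two upstream gradients but propagates through its own backward step only once, with their sum, which is the common-factor view of the chain rule described above.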

isConstant

abstract function isConstant() -> Boolean

Is this a constant expression?

isRandom

abstract function isRandom() -> Boolean

Is this a Random expression?

label

abstract function label(gen:Integer)

Label generations.

  • gen: Generation.

Recursively labels all subexpressions with the given generation, terminating on subexpressions that are already labeled.
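The recursion with early termination, together with the generation limit used by constant(gen), can be sketched in a toy model. This is Python, not Birch, and all names (ToyNode, make_constant, the gen and constant fields) are hypothetical. The sketch only flags older generations as constant; it does not model any removal of subexpressions.

```python
class ToyNode:
    """Toy subexpression carrying a generation label (hypothetical)."""

    def __init__(self, *args):
        self.args = args
        self.gen = None           # None means not yet labeled
        self.constant = False

    def label(self, gen):
        if self.gen is not None:
            return                # already labeled: stop recursing here
        self.gen = gen
        for a in self.args:
            a.label(gen)

    def make_constant(self, gen):
        # Toy analogue of constant(gen): flag generations older than gen.
        if self.gen is not None and self.gen < gen:
            self.constant = True
        for a in self.args:
            a.make_constant(gen)


a = ToyNode()
e1 = ToyNode(a)
e1.label(1)               # labels e1 and a with generation 1

b = ToyNode()
e2 = ToyNode(e1, b)
e2.label(2)               # labels e2 and b; stops at e1, already labeled

e2.make_constant(2)       # generation 1 becomes constant, generation 2 does not
```

Because label() stops at subexpressions that already carry a label, each generation marks only what was added since the previous call.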

length

final function length() -> Integer

Length of result. This is synonymous with rows().

peek

abstract function peek() -> Value

Get result.

Returns: The result.

reval

abstract function reval() -> Value

Re-evaluate, using the current values of the arguments.

rows

abstract function rows() -> Integer

Number of rows in result.

size

final function size() -> Integer

Size of result. This is equal to rows()*columns().

value

abstract function value() -> Value

Get result and render constant.

Returns: The result.