TME in JAX

Taylor moment expansion (TME) in JAX.

For math details, please see the docstring of tme.base_sympy.

Functions

generator()

Infinitesimal generator. This is a helper function around generator_power().

generator_power()

Iterations/power of infinitesimal generators.

mean_and_cov()

TME approximation for mean and covariance. If you only need the mean, call expectation() with phi set to the identity function.

expectation()

TME approximation for any expectation of the form \(\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)]\).

References

See the docstring of tme.base_sympy.

Authors

Adrien Corenflos and Zheng Zhao, 2021

tme.base_jax.expectation(phi, x, dt, drift, dispersion, order=3)

TME approximation of expectation on any target function \(\phi\).

For math details, see the docstring of tme.base_sympy.expectation().

Parameters
phi : Callable (d,) -> (…)

Target function (must be sufficiently smooth depending on the order).

x : jnp.ndarray (d,)

The state at which the generator is evaluated (i.e., the \(x\) in \(\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)=x]\)).

dt : float

Time interval.

drift : Callable (d,) -> (d,)

SDE drift coefficient.

dispersion : Callable (d,) -> (d, w)

SDE dispersion coefficient, where w stands for the dimension of the Wiener process.

order : int, default=3

Order of TME. Must be >= 0. For the relationship between the expansion order and the smoothness of the SDE coefficients, see Zhao (2021).

Returns
jnp.ndarray (…)

TME approximation of \(\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)]\). The output shape matches the output shape of phi.

Return type

ndarray
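For illustration, the sketch below assumes the usual TME form \(\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)=x] \approx \sum_{r=0}^{M} \frac{\Delta t^r}{r!}\,(\mathcal{A}^r\phi)(x)\), where \(M\) is the order; see tme.base_sympy.expectation() for the authoritative statement. The helpers `naive_generator` and `naive_expectation` are hypothetical and this is not the package's implementation. Setting phi to the identity, as suggested in the overview, yields the TME mean.

```python
import math

import jax
import jax.numpy as jnp


def naive_generator(phi, drift, dispersion):
    """Hypothetical helper: return the callable x -> (A phi)(x)."""

    def a_phi(x):
        jac = jax.jacfwd(phi)(x)    # (..., d): first derivatives of phi
        hess = jax.hessian(phi)(x)  # (..., d, d): second derivatives of phi
        b = dispersion(x)           # (d, w)
        return jac @ drift(x) + 0.5 * jnp.sum(hess * (b @ b.T), axis=(-2, -1))

    return a_phi


def naive_expectation(phi, x, dt, drift, dispersion, order=3):
    """Hypothetical sketch: sum_{r=0}^{order} dt^r / r! * (A^r phi)(x)."""
    result = phi(x)
    a_r_phi = phi
    for r in range(1, order + 1):
        # Apply the generator once more to get A^r phi from A^{r-1} phi.
        a_r_phi = naive_generator(a_r_phi, drift, dispersion)
        result = result + dt ** r / math.factorial(r) * a_r_phi(x)
    return result


# Ornstein-Uhlenbeck process dX = -X dt + dW (d = w = 1), phi = identity.
drift = lambda x: -x
dispersion = lambda x: jnp.ones((1, 1))
x0 = jnp.array([1.0])
mean = naive_expectation(lambda z: z, x0, 0.1, drift, dispersion, order=3)
# TME order 3 gives 1 - dt + dt^2/2 - dt^3/6 (about 0.904833);
# the exact mean is exp(-0.1) (about 0.904837).
```

For this linear drift, \((\mathcal{A}^r\phi)(x) = (-1)^r x\), so the TME mean is the order-3 Taylor polynomial of the exact mean \(x\,e^{-\Delta t}\).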

tme.base_jax.generator(phi, drift, dispersion)

Infinitesimal generator for diffusion processes defined by Itô SDEs.

\[(\mathcal{A}\phi)(x) = \sum^d_{i=1} a_i(x)\,\frac{\partial \phi}{\partial x_i}(x) + \frac{1}{2}\, \sum^d_{i,j=1} \Gamma_{ij}(x) \, \frac{\partial^2 \phi}{\partial x_i \, \partial x_j}(x),\]

where \(a\) and \(b\) denote the drift and dispersion coefficients, \(\Gamma(x) = b(x) \, b(x)^\top\), and \(\phi\colon \mathbb{R}^d \to \mathbb{R}\) must be a sufficiently smooth function (how smooth depends on the expansion order).
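A minimal JAX sketch of this formula, with the derivatives computed by automatic differentiation, is shown below. The name `naive_generator` is hypothetical and this is not the package's implementation. Using `jax.jacfwd` and `jax.hessian` lets it accept a phi with any output shape, in which case the formula is applied to each output component.

```python
import jax
import jax.numpy as jnp


def naive_generator(phi, drift, dispersion):
    """Hypothetical helper: return the callable x -> (A phi)(x)."""

    def a_phi(x):
        jac = jax.jacfwd(phi)(x)    # shape (..., d): first derivatives of phi
        hess = jax.hessian(phi)(x)  # shape (..., d, d): second derivatives of phi
        b = dispersion(x)           # shape (d, w)
        gamma = b @ b.T             # Gamma(x) = b(x) b(x)^T, shape (d, d)
        # Drift term sum_i a_i dphi/dx_i plus half the Gamma-weighted Hessian.
        return jac @ drift(x) + 0.5 * jnp.sum(hess * gamma, axis=(-2, -1))

    return a_phi


# Example: a(x) = -x, b(x) = I, phi(x) = |x|^2, so (A phi)(x) = -2|x|^2 + d.
drift = lambda x: -x
dispersion = lambda x: jnp.eye(x.shape[0])
phi = lambda x: jnp.sum(x ** 2)
value = naive_generator(phi, drift, dispersion)(jnp.array([1.0, 2.0]))
# value is -8.0 at x = (1, 2)
```

With phi set to the identity, the same sketch returns the drift \(a(x)\), since the second-derivative term vanishes.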

This is a helper function around generator_power().

Parameters
phi : Callable (d,) -> (…)

Target function.

drift : Callable (d,) -> (d,)

SDE drift coefficient.

dispersion : Callable (d,) -> (d, w)

SDE dispersion coefficient, where w stands for the dimension of the Wiener process.

Returns
Callable (…)

A callable that computes \(x \mapsto (\mathcal{A}\phi)(x)\). Its output shape is the same as that of phi.

Return type

Callable

tme.base_jax.generator_power(phi, drift, dispersion, order=1)

Iterations/power of infinitesimal generator.

For math details, see the docstring of tme.base_sympy.generator_power().

Parameters
phi : Callable (d,) -> (…)

Target function.

drift : Callable (d,) -> (d,)

SDE drift coefficient.

dispersion : Callable (d,) -> (d, w)

SDE dispersion coefficient, where w stands for the dimension of the Wiener process.

order : int, optional

Number of generator iterations. Must be >=0. Default is 1, which corresponds to the standard infinitesimal generator.

Returns
List[Callable]

List of generator functions in ascending power order. Formally, this function returns \([\phi, \mathcal{A}\phi, \ldots, \mathcal{A}^p\phi]\), where p is the order. Each callable function in this list has exactly the same input-output shape signature as phi: (d,) -> (…).
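The returned list can be sketched by applying a generator repeatedly, each time to the previous entry. The helpers `naive_generator` and `naive_generator_power` below are hypothetical and are not the package's code. Because they nest automatic differentiation one level deeper per iteration, their tracing cost grows quickly with the order.

```python
import jax
import jax.numpy as jnp


def naive_generator(phi, drift, dispersion):
    """Hypothetical helper: return the callable x -> (A phi)(x)."""

    def a_phi(x):
        jac = jax.jacfwd(phi)(x)    # (..., d)
        hess = jax.hessian(phi)(x)  # (..., d, d)
        b = dispersion(x)           # (d, w)
        return jac @ drift(x) + 0.5 * jnp.sum(hess * (b @ b.T), axis=(-2, -1))

    return a_phi


def naive_generator_power(phi, drift, dispersion, order=1):
    """Hypothetical helper: return [phi, A phi, ..., A^order phi]."""
    funcs = [phi]
    for _ in range(order):
        # Each new entry applies the generator to the previous one.
        funcs.append(naive_generator(funcs[-1], drift, dispersion))
    return funcs


# Ornstein-Uhlenbeck example: a(x) = -x, b(x) = 1, phi(x) = x^2 (d = w = 1).
drift = lambda x: -x
dispersion = lambda x: jnp.ones((1, 1))
phi = lambda x: jnp.sum(x ** 2)
funcs = naive_generator_power(phi, drift, dispersion, order=2)
values = [float(f(jnp.array([1.0]))) for f in funcs]
# Analytically: [x^2, -2x^2 + 1, 4x^2 - 2], i.e. [1.0, -1.0, 2.0] at x = 1.
```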

Notes

The implementation is due to Adrien Corenflos. Thank you for contributing this.

You may also find a naive implementation of infinitesimal generators and their iterations in the test file ./test/test_tme_jax.py.

Return type

List[Callable]

tme.base_jax.mean_and_cov(x, dt, drift, dispersion, order=3)

TME approximation for mean and covariance.

For math details, see the docstring of tme.base_sympy.mean_and_cov().

Parameters
x : jnp.ndarray (d,)

The state at which the generator is evaluated (i.e., the \(x\) in \(\mathbb{E}[X(t + \Delta t) \mid X(t)=x]\) and \(\mathrm{Cov}[X(t + \Delta t) \mid X(t)=x]\)).

dt : float

Time interval.

drift : Callable (d,) -> (d,)

SDE drift coefficient.

dispersion : Callable (d,) -> (d, w)

SDE dispersion coefficient, where w stands for the dimension of the Wiener process.

order : int, default=3

Order of TME. Must be >= 1.

Returns
m : jnp.ndarray (d,)

TME approximation of mean \(\mathbb{E}[X(t + \Delta t) \mid X(t)=x]\).

cov : jnp.ndarray (d, d)

TME approximation of covariance \(\mathrm{Cov}[X(t + \Delta t) \mid X(t)=x]\).
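One way to sketch this in JAX is shown below. It assumes the covariance is obtained by expanding \(\mathbb{E}[X X^\top]\) and \(m\,m^\top\) in powers of \(\Delta t\) and keeping terms up to the order \(M\), which gives \(\mathrm{Cov} \approx \sum_{r=1}^{M} \frac{\Delta t^r}{r!}\big[(\mathcal{A}^r\phi^{\mathrm{II}})(x) - \sum_{s=0}^{r}\binom{r}{s}(\mathcal{A}^s\phi^{\mathrm{I}})(x)\,(\mathcal{A}^{r-s}\phi^{\mathrm{I}})(x)^\top\big]\) with \(\phi^{\mathrm{I}}(x) = x\) and \(\phi^{\mathrm{II}}(x) = x\,x^\top\). The helper names are hypothetical, and the package may organize the computation differently.

```python
import math

import jax
import jax.numpy as jnp


def naive_generator(phi, drift, dispersion):
    """Hypothetical helper: return the callable x -> (A phi)(x)."""

    def a_phi(x):
        jac = jax.jacfwd(phi)(x)    # (..., d)
        hess = jax.hessian(phi)(x)  # (..., d, d)
        b = dispersion(x)           # (d, w)
        return jac @ drift(x) + 0.5 * jnp.sum(hess * (b @ b.T), axis=(-2, -1))

    return a_phi


def naive_mean_and_cov(x, dt, drift, dispersion, order=3):
    """Hypothetical sketch of a TME mean and covariance, order >= 1."""
    # A^r applied to phi_I(z) = z and phi_II(z) = z z^T, for r = 0..order.
    a_id, a_outer = [lambda z: z], [lambda z: jnp.outer(z, z)]
    for _ in range(order):
        a_id.append(naive_generator(a_id[-1], drift, dispersion))
        a_outer.append(naive_generator(a_outer[-1], drift, dispersion))
    a_id_x = [f(x) for f in a_id]  # each of shape (d,)

    m = sum(dt ** r / math.factorial(r) * a_id_x[r] for r in range(order + 1))
    cov = jnp.zeros((x.shape[0], x.shape[0]))
    for r in range(1, order + 1):
        # Terms of m m^T that carry the same power dt^r.
        cross = sum(math.comb(r, s) * jnp.outer(a_id_x[s], a_id_x[r - s])
                    for s in range(r + 1))
        cov = cov + dt ** r / math.factorial(r) * (a_outer[r](x) - cross)
    return m, cov


# Ornstein-Uhlenbeck process dX = -X dt + dW (d = w = 1).
drift = lambda x: -x
dispersion = lambda x: jnp.ones((1, 1))
m, cov = naive_mean_and_cov(jnp.array([1.0]), 0.1, drift, dispersion, order=3)
# Exact values: mean exp(-0.1), variance (1 - exp(-0.2)) / 2 (about 0.090635).
```

For this example the sketch gives a covariance of \(\Delta t - \Delta t^2 + \tfrac{2}{3}\Delta t^3\), which matches the Taylor expansion of the exact variance up to third order.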

Notes

When order = 1, the TME mean and covariance approximations coincide with those of the Euler–Maruyama scheme.
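This can be checked by hand. Assuming the covariance is formed by expanding \(\mathbb{E}[X(t+\Delta t)\,X(t+\Delta t)^\top \mid X(t)=x]\) and \(m\,m^\top\) in powers of \(\Delta t\) and keeping terms up to the TME order, and writing \(a\) for the drift and \(b\) for the dispersion, the order-1 terms are:

\[
\begin{aligned}
(\mathcal{A}\phi^{\mathrm{I}})(x) &= a(x), && \phi^{\mathrm{I}}(x) = x,\\
(\mathcal{A}\phi^{\mathrm{II}})(x) &= a(x)\,x^\top + x\,a(x)^\top + \Gamma(x), && \phi^{\mathrm{II}}(x) = x\,x^\top,\\
m &\approx x + a(x)\,\Delta t,\\
\mathrm{Cov} &\approx \Delta t\,\big[(\mathcal{A}\phi^{\mathrm{II}})(x) - x\,a(x)^\top - a(x)\,x^\top\big] = \Gamma(x)\,\Delta t.
\end{aligned}
\]

These are the mean and covariance of the Euler–Maruyama step \(x + a(x)\,\Delta t + b(x)\,\Delta W\) with \(\Delta W \sim \mathcal{N}(0, \Delta t\, I_w)\), since \(\mathrm{Cov}[b(x)\,\Delta W] = b(x)\,b(x)^\top \Delta t = \Gamma(x)\,\Delta t\).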

Return type

Tuple[ndarray, ndarray]