TME in JAX
Taylor moment expansion (TME) in JAX.
For math details, please see the docstring of tme.base_sympy.
Functions
generator()
Infinitesimal generator. This is a helper function around generator_power().
generator_power()
Iterations/power of infinitesimal generators.
mean_and_cov()
TME approximation for mean and covariance. If you only need the mean, use expectation() with argument phi set to the identity function.
expectation()
TME approximation for any expectation of the form \(\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)]\).
References
See the docstring of tme.base_sympy.
- tme.base_jax.expectation(phi, x, dt, drift, dispersion, order=3)
TME approximation of the expectation of any target function \(\phi\).
For math details, see the docstring of tme.base_sympy.expectation().
- Parameters
- phi : Callable (d,) -> (…)
Target function (must be sufficiently smooth depending on the order).
- x : jnp.ndarray (d,)
The state at which the generator is evaluated (i.e., the \(x\) in \(\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)=x]\)).
- dt : float
Time interval.
- drift : Callable (d,) -> (d,)
SDE drift coefficient.
- dispersion : Callable (d,) -> (d, w)
SDE dispersion coefficient, where w stands for the dimension of the Wiener process.
- order : int, default=3
Order of TME. Must be >= 0. For the relationship between the expansion order and the smoothness of the SDE coefficients, see Zhao (2021).
- Returns
- jnp.ndarray (…)
TME approximation of \(\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)=x]\). The output shape is the same as the output shape of phi.
- Return type
ndarray
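For intuition, the order-\(M\) TME approximation is the truncated series \(\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)=x] \approx \sum_{r=0}^{M} \frac{\Delta t^r}{r!}\,(\mathcal{A}^r\phi)(x)\). The following is a minimal, naive sketch of that sum written directly with jax.jacfwd and jax.hessian. It is not the library's implementation; the names generator_sketch and tme_expectation_sketch are hypothetical, and phi, drift, and dispersion are assumed to be differentiable JAX functions with the shapes listed above.

```python
import math

import jax
import jax.numpy as jnp


def generator_sketch(phi, drift, dispersion):
    """Hypothetical helper: return x -> (A phi)(x) for phi: (d,) -> (...)."""
    def a_phi(x):
        jac = jax.jacfwd(phi)(x)     # shape (..., d)
        hess = jax.hessian(phi)(x)   # shape (..., d, d)
        b = dispersion(x)            # shape (d, w)
        gamma = b @ b.T              # Gamma(x) = b(x) b(x)^T, shape (d, d)
        # sum_i a_i dphi/dx_i + 1/2 sum_ij Gamma_ij d^2 phi / dx_i dx_j
        return (jnp.tensordot(jac, drift(x), axes=1)
                + 0.5 * jnp.tensordot(hess, gamma, axes=2))
    return a_phi


def tme_expectation_sketch(phi, x, dt, drift, dispersion, order=3):
    """Hypothetical sketch: sum_{r=0}^{order} dt^r / r! * (A^r phi)(x)."""
    if order < 0:
        raise ValueError("order must be >= 0")
    result = phi(x)
    power = phi  # holds A^r phi for the current r
    for r in range(1, order + 1):
        power = generator_sketch(power, drift, dispersion)
        result = result + dt ** r / math.factorial(r) * power(x)
    return result


# Usage: scalar Ornstein-Uhlenbeck SDE dX = -0.5 X dt + 0.3 dW with phi = identity (the mean).
x0 = jnp.array([1.5])
mean = tme_expectation_sketch(lambda x: x, x0, 0.1,
                              drift=lambda x: -0.5 * x,
                              dispersion=lambda x: jnp.array([[0.3]]))
print(mean)  # close to the exact mean 1.5 * exp(-0.05)
```

For this linear drift, \(\mathcal{A}^r x = (-0.5)^r x\), so the sketch reproduces the first \(M + 1\) Taylor terms of the exact mean \(x\,e^{-0.5\,\Delta t}\).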
- tme.base_jax.generator(phi, drift, dispersion)
Infinitesimal generator for diffusion processes in Itô's SDE construction:
\[(\mathcal{A}\phi)(x) = \sum^d_{i=1} a_i(x)\,\frac{\partial \phi}{\partial x_i}(x) + \frac{1}{2}\, \sum^d_{i,j=1} \Gamma_{ij}(x) \, \frac{\partial^2 \phi}{\partial x_i \, \partial x_j}(x),\]
where \(a\) and \(b\) are the SDE drift and dispersion coefficients, \(\phi\colon \mathbb{R}^d \to \mathbb{R}\) must be a sufficiently smooth function depending on the expansion order, and \(\Gamma(x) = b(x) \, b(x)^\top\).
This is a helper function around generator_power().
- Parameters
- phi : Callable (d,) -> (…)
Target function.
- drift : Callable (d,) -> (d,)
SDE drift coefficient.
- dispersion : Callable (d,) -> (d, w)
SDE dispersion coefficient, where w stands for the dimension of the Wiener process.
- Returns
- Callable (d,) -> (…)
A callable that evaluates \(x \mapsto (\mathcal{A}\phi)(x)\). Its output shape is the same as that of phi.
- Return type
Callable
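To make the generator formula concrete, here is a minimal sketch that evaluates \((\mathcal{A}\phi)(x)\) directly with jax.jacfwd and jax.hessian. It is not the library's generator() (which wraps generator_power()); the name generator_sketch is hypothetical, and it assumes phi, drift, and dispersion are differentiable JAX functions with the shapes listed above. Contracting over the trailing axes with jnp.tensordot lets phi return any shape (…), with the formula applied elementwise.

```python
import jax
import jax.numpy as jnp


def generator_sketch(phi, drift, dispersion):
    """Hypothetical helper: return x -> (A phi)(x) for phi: (d,) -> (...)."""
    def a_phi(x):
        jac = jax.jacfwd(phi)(x)     # d phi / d x_i, shape (..., d)
        hess = jax.hessian(phi)(x)   # d^2 phi / d x_i d x_j, shape (..., d, d)
        b = dispersion(x)            # b(x), shape (d, w)
        gamma = b @ b.T              # Gamma(x) = b(x) b(x)^T, shape (d, d)
        drift_term = jnp.tensordot(jac, drift(x), axes=1)          # sum_i a_i(x) dphi/dx_i
        diffusion_term = 0.5 * jnp.tensordot(hess, gamma, axes=2)  # 1/2 sum_ij Gamma_ij d^2 phi
        return drift_term + diffusion_term
    return a_phi


# Usage: dX = -X dt + dW with phi(x) = x^2, so (A phi)(x) = -2 x^2 + 1 analytically.
a_phi = generator_sketch(lambda x: x[0] ** 2, lambda x: -x, lambda x: jnp.eye(1))
print(a_phi(jnp.array([2.0])))  # analytically -2 * 2^2 + 1 = -7
```

With phi set to the identity, the second-derivative term vanishes and the sketch returns the drift itself, which is a quick sanity check.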
- tme.base_jax.generator_power(phi, drift, dispersion, order=1)
Iterations/power of infinitesimal generator.
For math details, see the docstring of tme.base_sympy.generator_power().
- Parameters
- phi : Callable (d,) -> (…)
Target function.
- drift : Callable (d,) -> (d,)
SDE drift coefficient.
- dispersion : Callable (d,) -> (d, w)
SDE dispersion coefficient, where w stands for the dimension of the Wiener process.
- order : int, optional
Number of generator iterations. Must be >=0. Default is 1, which corresponds to the standard infinitesimal generator.
- Returns
- List[Callable]
List of generator functions in ascending power order. Formally, this function returns \([\phi, \mathcal{A}\phi, \ldots, \mathcal{A}^p\phi]\), where \(p\) is the order. Each callable in this list has exactly the same input-output shape signature as phi: (d,) -> (…).
Notes
The implementation is due to Adrien Corenflos. Thank you for contributing this.
You may also find a naive implementation of infinitesimal generators and their iterations in the test file ./test/test_tme_jax.py.
- Return type
List[Callable]
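The list \([\phi, \mathcal{A}\phi, \ldots, \mathcal{A}^p\phi]\) can be sketched naively by applying a generator repeatedly to the previous element, so each step nests one more level of automatic differentiation. This is only an illustration, not the contributed implementation mentioned in the Notes; generator_sketch and generator_power_sketch are hypothetical names, and nested jax.jacfwd/jax.hessian calls are assumed to be acceptable for small d and p.

```python
import jax
import jax.numpy as jnp


def generator_sketch(phi, drift, dispersion):
    """Hypothetical helper: return x -> (A phi)(x) for phi: (d,) -> (...)."""
    def a_phi(x):
        jac = jax.jacfwd(phi)(x)     # shape (..., d)
        hess = jax.hessian(phi)(x)   # shape (..., d, d)
        b = dispersion(x)
        gamma = b @ b.T              # Gamma(x) = b(x) b(x)^T
        return (jnp.tensordot(jac, drift(x), axes=1)
                + 0.5 * jnp.tensordot(hess, gamma, axes=2))
    return a_phi


def generator_power_sketch(phi, drift, dispersion, order=1):
    """Hypothetical sketch: return [phi, A phi, ..., A^order phi]."""
    if order < 0:
        raise ValueError("order must be >= 0")
    powers = [phi]
    for _ in range(order):
        # Apply the generator to the most recent power: A^(r+1) phi = A(A^r phi).
        powers.append(generator_sketch(powers[-1], drift, dispersion))
    return powers


# Usage: dX = -X dt + dW, phi(x) = x^2; analytically x^2, -2x^2 + 1 and 4x^2 - 2.
fs = generator_power_sketch(lambda x: x[0] ** 2, lambda x: -x, lambda x: jnp.eye(1), order=2)
print([float(f(jnp.array([2.0]))) for f in fs])  # analytically 4, -7 and 14
```

Nesting differentiation this way recomputes lower-order derivatives at every level, so its cost grows quickly with the order; it suits checking small cases rather than production use.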
- tme.base_jax.mean_and_cov(x, dt, drift, dispersion, order=3)
TME approximation for mean and covariance.
For math details, see the docstring of tme.base_sympy.mean_and_cov().
- Parameters
- x : jnp.ndarray (d,)
The state at which the generator is evaluated (i.e., the \(x\) in \(\mathbb{E}[X(t + \Delta t) \mid X(t)=x]\) and \(\mathrm{Cov}[X(t + \Delta t) \mid X(t)=x]\)).
- dt : float
Time interval.
- drift : Callable (d,) -> (d,)
SDE drift coefficient.
- dispersion : Callable (d,) -> (d, w)
SDE dispersion coefficient, where w stands for the dimension of the Wiener process.
- order : int, default=3
Order of TME. Must be >= 1.
- Returns
- m : jnp.ndarray (d,)
TME approximation of mean \(\mathbb{E}[X(t + \Delta t) \mid X(t)=x]\).
- cov : jnp.ndarray (d, d)
TME approximation of covariance \(\mathrm{Cov}[X(t + \Delta t) \mid X(t)=x]\).
Notes
When order = 1, the TME mean and covariance approximations are exactly the same as those of the Euler–Maruyama method.
- Return type
Tuple[ndarray, ndarray]
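A naive sketch of one way to assemble both moments, assuming the order-\(M\) expansion \(\sum_{r=0}^{M} \frac{\Delta t^r}{r!}(\mathcal{A}^r\phi)(x)\) is applied to \(\phi^{I}(x) = x\) for the mean and to \(\phi^{II}(x) = x\,x^\top\) for the second moment, with the product of the mean expansions truncated at the same order:
\[\mathrm{Cov} \approx \sum_{r=1}^{M} \frac{\Delta t^r}{r!} \Big[(\mathcal{A}^r\phi^{II})(x) - \sum_{s=0}^{r} \binom{r}{s} (\mathcal{A}^s\phi^{I})(x)\,\big((\mathcal{A}^{r-s}\phi^{I})(x)\big)^\top\Big].\]
For \(M = 1\) the bracket reduces to \(\Gamma(x) = b(x)\,b(x)^\top\), consistent with the Euler–Maruyama remark in the Notes. The function names and the pendulum SDE below are hypothetical illustrations, not the library's mean_and_cov(), which may organize the computation differently.

```python
import math

import jax
import jax.numpy as jnp


def generator_sketch(phi, drift, dispersion):
    """Hypothetical helper: return x -> (A phi)(x) for phi: (d,) -> (...)."""
    def a_phi(x):
        jac = jax.jacfwd(phi)(x)     # shape (..., d)
        hess = jax.hessian(phi)(x)   # shape (..., d, d)
        b = dispersion(x)
        gamma = b @ b.T              # Gamma(x) = b(x) b(x)^T
        return (jnp.tensordot(jac, drift(x), axes=1)
                + 0.5 * jnp.tensordot(hess, gamma, axes=2))
    return a_phi


def tme_mean_and_cov_sketch(x, dt, drift, dispersion, order=3):
    """Hypothetical sketch of TME mean and covariance (not the library's code)."""
    if order < 1:
        raise ValueError("order must be >= 1")

    def phi_1(z):  # first moment: z -> z, shape (d,)
        return z

    def phi_2(z):  # second moment: z -> z z^T, shape (d, d)
        return jnp.outer(z, z)

    # Build [A^0 phi, ..., A^order phi] for both target functions.
    powers_1, powers_2 = [phi_1], [phi_2]
    for _ in range(order):
        powers_1.append(generator_sketch(powers_1[-1], drift, dispersion))
        powers_2.append(generator_sketch(powers_2[-1], drift, dispersion))

    a1 = [f(x) for f in powers_1]  # (A^r phi_1)(x), each of shape (d,)
    m = sum(dt ** r / math.factorial(r) * a1[r] for r in range(order + 1))

    cov = jnp.zeros((x.shape[0], x.shape[0]))
    for r in range(1, order + 1):
        # Order-r part of m m^T: sum_s binom(r, s) (A^s phi_1)(x) (A^(r-s) phi_1)(x)^T
        cross = sum(math.comb(r, s) * jnp.outer(a1[s], a1[r - s]) for s in range(r + 1))
        cov = cov + dt ** r / math.factorial(r) * (powers_2[r](x) - cross)
    return m, cov


# Usage: a stochastic pendulum with additive noise on the velocity (illustrative choice).
def pendulum_drift(x):
    return jnp.stack([x[1], -jnp.sin(x[0])])


def pendulum_dispersion(x):
    return jnp.array([[0.0], [1.0]])  # additive noise, independent of x


x0 = jnp.array([0.3, -0.2])
# order = 1 should reproduce Euler-Maruyama: m1 = x0 + a(x0) dt, cov1 = b b^T dt.
m1, cov1 = tme_mean_and_cov_sketch(x0, 0.01, pendulum_drift, pendulum_dispersion, order=1)
m3, cov3 = tme_mean_and_cov_sketch(x0, 0.01, pendulum_drift, pendulum_dispersion, order=3)
```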