"""
Taylor moment expansion (TME) in JaX.
For math details, please see the docstring of :py:mod:`tme.base_sympy`.
Functions
---------
:py:func:`generator`
Infinitesimal generator. This is a helper function around :py:func:`generator_power`.
:py:func:`generator_power`
Iterations/power of infinitesimal generators.
:py:func:`mean_and_cov`
TME approximation for mean and covariance. In case you just want to compute the mean, use function
:py:func:`expectation` with argument :code:`phi` fed by an identity function.
:py:func:`expectation`
TME approximation for any expectation of the form :math:`\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)]`.
References
----------
See the docstring of :py:mod:`tme.base_sympy`.
Authors
-------
Adrien Corenflos and Zheng Zhao, 2021
"""
# TODO: The logic further down can be improved dramatically if we make diagonal noise specific logic.
try:
import jax as _
except:
raise ImportError("By default the library is not packaged with JaX due to the need to support CPU and GPU users. "
"In order to use it, follow the instructions on https://github.com/google/jax#installation.")
import math
from math import factorial
from typing import Callable, List, Tuple
import jax.numpy as jnp
from jax import jvp, linearize, vmap
__all__ = ['generator',
'generator_power',
'mean_and_cov',
'expectation']
[docs]def generator_power(phi: Callable, drift: Callable, dispersion: Callable, order: int = 1) -> List[Callable]:
r"""Iterations/power of infinitesimal generator.
For math details, see the docstring of :py:func:`tme.base_sympy.generator_power`.
Parameters
----------
phi : Callable (d,) -> (...)
Target function.
drift : Callable (d,) -> (d,)
SDE drift coefficient.
dispersion : Callable (d,) -> (d, w)
SDE dispersion coefficient, where `w` stands for the dimension of the Wiener process.
order : int, optional
Number of generator iterations. Must be >=0. Default is 1, which corresponds to the standard infinitesimal
generator.
Returns
-------
List[Callable]
List of generator functions in ascending power order. Formally, this function returns
:math:`[\phi, \mathcal{A}\phi, \ldots, \mathcal{A}^p\phi]`, where :code:`p` is the order.
Each callable function in this list has exactly the same input-output shape
signature as phi: (d,) -> (...).
Notes
-----
The implementation is due to Adrien Corenflos. Thank you for contributing this.
You may also find a naive implementation of infinitesimal generators and their iterations in the test file
:code:`./test/test_tme_jax.py`.
"""
def jac_part(z, f):
return jvp(f, (z,), (drift(z),))[1]
def hess_prod_1(z, f, b):
_, linearized_f = linearize(f, z)
return vmap(linearized_f, in_axes=1, out_axes=0)(b)
def hess_prod_2(z, f):
b = _format_dispersion(dispersion(z))
_, linearized_f = linearize(lambda zz: hess_prod_1(zz, f, b), z)
return vmap(linearized_f, in_axes=0, out_axes=1)(b.T)
gen_power = phi
list_of_gen_powers = [gen_power]
for _ in range(order):
def gen_power(z, f=gen_power):
return jac_part(z, f) + 0.5 * jnp.einsum("ii...", hess_prod_2(z, f))
list_of_gen_powers.append(gen_power)
return list_of_gen_powers
[docs]def generator(phi: Callable, drift: Callable, dispersion: Callable) -> Callable:
r"""Infinitesimal generator for diffusion processes in Ito's SDE constructions.
.. math::
(\mathcal{A}\phi)(x) = \sum^d_{i=1} a_i(x)\,\frac{\partial \phi}{\partial x_i}(x)
+ \frac{1}{2}\, \sum^d_{i,j=1} \Gamma_{ij}(x) \, \frac{\partial^2 \phi}{\partial x_i \, \partial x_j}(x),
where :math:`\phi\colon \mathbb{R}^d \to \mathbb{R}` must be sufficiently smooth function depending on the
expansion order, and :math:`\Gamma(x) = b(x) \, b(x)^\top`.
This is a helper function around :py:func:`generator_power`.
Parameters
----------
phi : Callable (d,) -> (...)
Target function.
drift : Callable (d,) -> (d,)
SDE drift coefficient.
dispersion : Callable (d,) -> (d, w)
SDE dispersion coefficient, where `w` stands for the dimension of the Wiener process.
Returns
-------
Callable (...)
A callable function which carries out :math:`x \mapsto \mathcal{A}\phi`. The output shape of this function
is the same as :code:`phi`.
"""
return generator_power(phi, drift, dispersion, 1)[1]
[docs]def mean_and_cov(x: jnp.ndarray, dt: float, drift: Callable, dispersion: Callable,
order: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray]:
r"""TME approximation for mean and covariance.
For math details, see the docstring of :py:func:`tme.base_sympy.mean_and_cov`.
Parameters
----------
x : jnp.ndarray (d,)
The state at which the generator is evaluated. (i.e., the :math:`x` in
:math:`\mathbb{E}[X(t + \Delta t) \mid X(t)=x]` and :math:`\mathrm{Cov}[X(t + \Delta t) \mid X(t)=x]`).
dt : float
Time interval.
drift : Callable (d,) -> (d,)
SDE drift coefficient.
dispersion : Callable (d,) -> (d, w)
SDE dispersion coefficient, where `w` stands for the dimension of the Wiener process.
order : int, default=3
Order of TME. Must be >= 1.
Returns
-------
m : jnp.ndarray (d,)
TME approximation of mean :math:`\mathbb{E}[X(t + \Delta t) \mid X(t)=x]`.
cov : jnp.ndarray (d, d)
TME approximation of covariance :math:`\mathrm{Cov}[X(t + \Delta t) \mid X(t)=x]`.
Notes
-----
When `order = 1`, the TME mean and cov approximations are exactly the same with Euler--Maruyama.
"""
# Give generator powers of phi^I and phi^II then evaluate them all
list_of_Aphi_i = generator_power(lambda z: z, drift, dispersion, order)
list_of_Aphi_ii = generator_power(lambda z: jnp.outer(z, z), drift, dispersion, order)
Aphi_i_powers = [func(x) for func in list_of_Aphi_i]
Aphi_ii_powers = [func(x) for func in list_of_Aphi_ii]
# Give the mean approximation
m = x
for r in range(1, order + 1):
m = m + 1 / factorial(r) * Aphi_i_powers[r] * dt ** r
# Give the cov approximation
# r = 1
cov = Aphi_ii_powers[1] - jnp.outer(Aphi_i_powers[0], Aphi_i_powers[1]) \
- jnp.outer(Aphi_i_powers[1], Aphi_i_powers[0])
cov = cov * dt
for r in range(2, order + 1):
coeff = Aphi_ii_powers[r]
for k in range(r + 1):
coeff = coeff - _comb(r, k) * jnp.outer(Aphi_i_powers[k], Aphi_i_powers[r - k])
cov = cov + 1 / factorial(r) * coeff * dt ** r
return m, cov
[docs]def expectation(phi: Callable, x: jnp.ndarray, dt: float, drift: Callable, dispersion: Callable,
order: int = 3) -> jnp.ndarray:
r"""TME approximation of expectation on any target function :math:`\phi`.
For math details, see the docstring of :py:func:`tme.base_sympy.expectation`.
Parameters
----------
phi : Callable (d,) -> (...)
Target function (must be sufficiently smooth depending on the order).
x : jnp.ndarray (d, )
The state at which the generator is evaluated (i.e., the :math:`x` in
:math:`\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)=x]`).
dt : float
Time interval.
drift : Callable (d,) -> (d,)
SDE drift coefficient.
dispersion : Callable (d,) -> (d, w)
SDE dispersion coefficient, where `w` stands for the dimension of the Wiener process.
order : int
Order of TME. Must be >=0. For the relationship between the expansion order and SDE coefficient smoothness, see,
Zhao (2021).
Returns
-------
jnp.ndarray (...)
TME approximation of :math:`\mathbb{E}[\phi(X(t + \Delta t)) \mid X(t)]`. The output shape is consistence with
the input shape of :code:`phi`.
"""
list_of_Aphi = generator_power(phi, drift, dispersion, order)
Aphi = phi(x)
for r in range(1, order + 1):
Aphi += 1 / factorial(r) * list_of_Aphi[r](x) * dt ** r
return Aphi
def _comb(n, k):
try:
return math.comb(n, k)
except AttributeError: # Python version < 3.8 does not have math.comb
return _manual_comb(n, k)
def _manual_comb(n, k):
if k > n // 2:
return _manual_comb(n, n - k)
if k < 0:
return 0
if k == 0:
return 1
return _manual_comb(n - 1, k - 1) + _manual_comb(n - 1, k)
def _format_dispersion(bz):
ndim = jnp.ndim(bz)
if ndim == 0:
return jnp.atleast_2d(bz)
if ndim == 1:
return jnp.expand_dims(bz, 1)
if ndim == 2:
return bz
else:
raise ValueError(f"Dispersion coefficient b(z) must have at most 2 dimensions. {ndim} were passed")