propagate_cov

anri.fwd.propagate_cov(J_func_out, cov_in)

Propagate an input covariance matrix through a function's Jacobian to obtain the output covariance matrix (first-order, linearised error propagation).

Parameters:
  • J_func_out (Iterable[Array]) – The Jacobian blocks returned by calling jax.jacfwd() on a JAX-jitted function, one block per differentiated input.

  • cov_in (Array) – [6,6] The input covariance matrix; build it with get_cov_in(). Its size must equal the total number of input components covered by J_func_out (i.e. the combined number of Jacobian columns).

Returns:

cov_out (jax.Array) – [3,3] Output covariance matrix: the covariance of the outputs of the JAX-jitted function.
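A minimal sketch of how an input like J_func_out can be produced and how its shape relates to the [6,6] cov_in and [3,3] cov_out above. The function f (a cross product of two 3-vectors) and the input values are hypothetical, chosen only so that there are 6 input components and 3 outputs; this does not call anri itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: two 3-vector inputs, one 3-vector output.
@jax.jit
def f(a, b):
    return jnp.cross(a, b)

a = jnp.array([1.0, 0.0, 0.0])
b = jnp.array([0.0, 1.0, 0.0])

# Differentiating with respect to both arguments makes jacfwd return a
# tuple of Jacobian blocks, one per argument, each of shape (n_out, n_in_i).
J_func_out = jax.jacfwd(f, argnums=(0, 1))(a, b)
print([J.shape for J in J_func_out])  # [(3, 3), (3, 3)]

# The input covariance must cover every input component: 3 + 3 = 6,
# giving a [6,6] cov_in and, with 3 outputs, a [3,3] cov_out.
n_in = sum(J.shape[1] for J in J_func_out)
print(n_in)  # 6
```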

Notes

Propagation follows \(\mathbf{\Sigma}^{\text{out}} = \mathbf{J}_{f} \mathbf{\Sigma}^{\text{in}} \mathbf{J}_{f}^T\), where \(\mathbf{J}_{f}\) is the Jacobian of the function.
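The formula can be checked on a linear map \(f(x) = A x\), whose Jacobian is \(A\) everywhere, so first-order propagation is exact. The helper propagate_single is a hypothetical single-Jacobian illustration of the formula, not part of anri.

```python
import jax
import jax.numpy as jnp

def propagate_single(J, cov_in):
    # First-order propagation: Sigma_out = J @ Sigma_in @ J^T
    return J @ cov_in @ J.T

# Linear map f(x) = A x: its Jacobian equals A at every point.
A = jnp.array([[1.0, 2.0],
               [0.0, 3.0]])
f = lambda x: A @ x
J = jax.jacfwd(f)(jnp.zeros(2))  # recovers A

cov_in = jnp.diag(jnp.array([0.1, 0.2]))
cov_out = propagate_single(J, cov_in)
# cov_out is approximately [[0.9, 1.2], [1.2, 1.8]]
```

The result is symmetric, as any covariance matrix must be.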

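This handles multi-dimensional inputs and outputs: for example, if the function returns a 3-vector and one of its inputs is a 3-vector, the Jacobian block for that input is 3x3. The blocks for all inputs together make up \(\mathbf{J}_{f}\).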
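A sketch of one way to combine Jacobian blocks of different shapes, assuming a 1-D function output. propagate_cov_sketch and the function g are hypothetical and are not anri's implementation. A scalar input yields a block of shape (n_out,) from jacfwd, so each block is reshaped to (n_out, n_in_i) before the blocks are joined column-wise.

```python
import jax
import jax.numpy as jnp

def propagate_cov_sketch(J_blocks, cov_in):
    """Hypothetical re-implementation, not anri's actual code.

    Assumes a 1-D output. Each block is reshaped to (n_out, n_in_i), so a
    scalar input contributes one column and a vector input one column per
    component; the blocks are then joined into a single Jacobian.
    """
    n_out = J_blocks[0].shape[0]
    J = jnp.hstack([jnp.reshape(Jb, (n_out, -1)) for Jb in J_blocks])
    return J @ cov_in @ J.T

# Hypothetical function: a 3-vector input r and a scalar input s.
@jax.jit
def g(r, s):
    return s * r

r = jnp.array([1.0, 2.0, 3.0])
s = jnp.array(2.0)
J_blocks = jax.jacfwd(g, argnums=(0, 1))(r, s)
print([Jb.shape for Jb in J_blocks])  # [(3, 3), (3,)]

# 4x4 input covariance: the 3 components of r, then s (uncorrelated here).
cov_in = jnp.diag(jnp.array([0.01, 0.01, 0.01, 0.04]))
cov_out = propagate_cov_sketch(J_blocks, cov_in)
print(cov_out.shape)  # (3, 3)
```

Here the combined Jacobian is \([\,s I \mid r\,]\), so cov_out equals \(0.04\,(I + r r^T)\) for these values.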