{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Forward model with errors\n", "As a forward model is basically just a change-of-basis ($(h,k,l)$ space to detector space), we can propagate an error in the change-of-basis matrix (such as an error in wavelength) through the forward model to understand how our peak positions change as a function of our input parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"0\"\n", "\n", "import jax\n", "jax.config.update(\"jax_enable_x64\", True)\n", "import jax.numpy as jnp\n", "from jax.scipy.spatial.transform import Rotation as jR\n", "from matplotlib import pyplot as plt\n", "\n", "import anri\n", "\n", "import time\n", "\n", "start = time.time()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Crystallography\n", "Let's do a quick and dirty generation of a single peak to forward model. We'll choose BCC iron again:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "struc = anri.crystal.Structure.from_cif(\"../../../tests/data/cif/Fe.cif\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The central value of wavelength\n", "wavelength_cen = 0.2\n", "struc.make_hkls(dsmax=0.6, wavelength=wavelength_cen)\n", "# just take the first peak:\n", "hkl = struc.ringhkls_arr[0]\n", "print(hkl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll generate a single grain at the origin with a known orientation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "U = jnp.eye(3)\n", "UB = U @ struc.B\n", "UBI = jnp.linalg.inv(UB)\n", "print(UBI)\n", "origin_sample = jnp.array([0., 0., 0.])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Experiment parameters\n", "In this simple experiment, all our parameters are 
constant:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# goniometer\n", "chi = 0.0\n", "wedge = 0.0\n", "dty = 0.0\n", "y0 = 0.0\n", "# detector\n", "y_center = 1024.0\n", "z_center = 1024.0\n", "y_size = 75.0\n", "z_size = 75.0\n", "tilt_x = 0.0\n", "tilt_y = 0.0\n", "tilt_z = 0.0\n", "distance = 300e3\n", "o11 = 1\n", "o12 = 0\n", "o21 = 0\n", "o22 = 1\n", "detector_size = 2048 # px" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "det_trans, beam_cen_shift, x_distance_shift = anri.geom.detector_transforms(\n", "y_center,\n", "y_size,\n", "tilt_y,\n", "z_center,\n", "z_size,\n", "tilt_z,\n", "tilt_x,\n", "distance,\n", "o11,\n", "o12,\n", "o21,\n", "o22\n", ")\n", "\n", "sc_lab, fc_lab, norm_lab = anri.geom.detector_basis_vectors_lab(det_trans, beam_cen_shift, x_distance_shift)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's define a simple function in JAX to forward model a single HKL with variable wavelength:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def simulate_one_peak(ubi, hkl, origin, wavelength):\n", " # origin: sample frame\n", " q_sample = jnp.linalg.inv(ubi) @ hkl\n", "\n", " k_in_lab = jnp.array([1., 0., 0.])\n", " k_in_lab_norm = anri.diffract.scale_norm_k(k_in_lab, wavelength)\n", " k_in_sample_norm = anri.geom.lab_to_sample(k_in_lab_norm, 0.0, wedge, chi, dty, y0)\n", "\n", " omega, valid1 = anri.diffract.omega_solns(q_sample, 1.0, k_in_sample_norm)\n", "\n", " q_lab = anri.geom.sample_to_lab(q_sample, omega, wedge, chi, dty, y0)\n", "\n", " _, eta = anri.diffract.q_lab_to_tth_eta(q_lab, wavelength)\n", " jax.debug.print(\"eta (deg): {eta}\", eta=eta)\n", "\n", " origin_lab = anri.geom.sample_to_lab(origin, omega, wedge, chi, dty, y0)\n", "\n", " k_out = anri.diffract.q_lab_to_k_out(q_lab, k_in_lab_norm)\n", "\n", " sc, fc = anri.geom.raytrace_to_det(k_out, 
origin_lab, sc_lab, fc_lab, norm_lab)\n", "\n", " return jnp.array([sc, fc, omega])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sc, fc, omega = simulate_one_peak(UBI, hkl, origin_sample, wavelength_cen)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "ax.scatter(fc, sc, label='peak')\n", "ax.scatter(y_center, z_center, label='beam center')\n", "ax.set(ylim=(0, detector_size), xlim=(0, detector_size), aspect=1, title='Forward model into detector space', xlabel='fc', ylabel='sc')\n", "ax.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Understanding error propagation\n", "As someone who never covered this in detail at university, I find it helpful to work through an extremely simple example. \n", "Let's define a function:\n", "\n", "$\\mathbf {f} :\\mathbb {R} ^{n}\\to \\mathbb {R} ^{m}$\n", "\n", "This notation means that the function $\\mathbf{f}$ takes a length-$n$ vector of real numbers and returns a length-$m$ vector of real numbers.\n", "\n", "We can represent the input as some length-$n$ vector:\n", "\n", "$\\mathbf {x} =(x_{1},\\ldots ,x_{n})\\in \\mathbb {R} ^{n}$\n", "\n", "and the output as a length-$m$ vector:\n", "\n", "$\\mathbf {f} (\\mathbf {x} )=(f_{1}(\\mathbf {x} ),\\ldots ,f_{m}(\\mathbf {x} ))\\in \\mathbb {R} ^{m}$\n", "\n", "We can then define the Jacobian matrix $\\mathbf {J_{f}}$:\n", "\n", "$\\mathbf {J_{f}} ={\\begin{bmatrix}{\\dfrac {\\partial \\mathbf {f} }{\\partial x_{1}}}&\\cdots &{\\dfrac {\\partial \\mathbf {f} }{\\partial x_{n}}}\\end{bmatrix}}={\\begin{bmatrix}\\nabla ^{\\mathsf {T}}f_{1}\\\\\\vdots \\\\\\nabla ^{\\mathsf {T}}f_{m}\\end{bmatrix}}={\\begin{bmatrix}{\\dfrac {\\partial f_{1}}{\\partial x_{1}}}&\\cdots &{\\dfrac {\\partial f_{1}}{\\partial x_{n}}}\\\\\\vdots &\\ddots &\\vdots \\\\{\\dfrac {\\partial f_{m}}{\\partial x_{1}}}&\\cdots &{\\dfrac {\\partial 
f_{m}}{\\partial x_{n}}}\\end{bmatrix}}$\n", "\n", "The above means that $\\mathbf {J_{f}}$ encodes the partial derivatives of each element of the output vector $\\mathbf {f} (\\mathbf {x} )$ with respect to each element of the input vector $\\mathbf {x}$.\n", "\n", "As a very basic example:\n", "\n", "$\\mathbf {f}(x, y) = (x^2, y^2)$\n", "\n", "The Jacobian would then be:\n", "\n", "$\\mathbf {J_{f}} = {\\begin{bmatrix}{\\dfrac {\\partial f_{1}}{\\partial x_{1}}} &{\\dfrac {\\partial f_{1}}{\\partial x_{2}}}\\\\ {\\dfrac {\\partial f_{2}}{\\partial x_{1}}} &{\\dfrac {\\partial f_{2}}{\\partial x_{2}}}\\end{bmatrix}}$\n", "\n", "For $(x_1, x_2) = (x, y)$ and $(f_1, f_2) = (x^2, y^2)$ we then get:\n", "\n", "$\\mathbf {J_{f}} = {\\begin{bmatrix}{\\dfrac {\\partial x^2}{\\partial x}} & {\\dfrac {\\partial x^2}{\\partial y}}\\\\ {\\dfrac {\\partial y^2}{\\partial x}} &{\\dfrac {\\partial y^2}{\\partial y}}\\end{bmatrix}} = \\begin{bmatrix}2x & 0 \\\\ 0 & 2y\\end{bmatrix}$\n", "\n", "Let's now see how to implement this in JAX:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def f(vec_in):\n", " vec_out = jnp.power(vec_in, 2)\n", " return vec_out\n", "\n", "vec_in = jnp.array([2.0, 3.0])\n", "vec_out = f(vec_in)\n", "print(vec_out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can manually define the Jacobian:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def jac_func_manual(vec_in):\n", " vec_out = jnp.array([\n", " [2*vec_in[0], 0],\n", " [0, 2*vec_in[1]]\n", " ])\n", " return vec_out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we define a function in JAX that can give us the Jacobian:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jac_func = jax.jacfwd(f)\n", "\n", "J = jac_func(vec_in)\n", "\n", "print(J)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "And we can check that our manual Jacobian function agrees:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "J_manual = jac_func_manual(vec_in)\n", "\n", "print(J_manual)\n", "\n", "assert jnp.allclose(J, J_manual)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's assume we have some uncertainty in our inputs.\n", "\n", "We can generate a variance-covariance matrix on our inputs:\n", "\n", "$\\mathbf{\\Sigma}^{\\text{in}} = \\begin{pmatrix} \\sigma_x^2 & \\sigma_{xy} \\\\ \\sigma_{yx} & \\sigma_y^2 \\end{pmatrix}$\n", "\n", "We can propagate our errors in the input vector to the errors in the output vector:\n", "\n", "$\\mathbf{\\Sigma}^{\\text{out}} = \\mathbf {J_{f}} \\mathbf{\\Sigma}^{\\text{in}} \\mathbf {J_{f}}^T$\n", "\n", "We can convince ourselves of this via JAX:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sig_x = 1.0\n", "sig_y = 2.0\n", "# assume linearly independent:\n", "cov_in = jnp.array([[sig_x**2, 0], [0, sig_y**2]])\n", "\n", "cov_out = J @ cov_in @ J.T\n", "\n", "print(cov_out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a simple function $f=A^b$, the variance in $f$, $\\sigma_f^2$ goes as:\n", "\n", "$\\sigma_f^2 = \\left(\\frac{f b \\sigma_A}{A}\\right)^2$\n", "\n", "We can check that our output covariance matrix matches:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert jnp.allclose(cov_out[0,0], (vec_out[0] * 2 * sig_x / vec_in[0])**2)\n", "\n", "assert jnp.allclose(cov_out[1,1], (vec_out[1] * 2 * sig_y / vec_in[1])**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have now shown that we can use JAX to propagate errors in parameters through arbitrary functions." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Forward propagating errors\n", "\n", "To keep things simple-ish, let's assume there are two sources of error:\n", "- `origin_sample` - we can move the grain around in space. This shouldn't change the Bragg angle, but it will change the peak position on the detector linearly.\n", "- `wavelength` - this will change all of `(sc, fc, omega)`.\n", "\n", "For now, we'll say the wavelength is perfect, just to inspect the effect of the error in grain position.\n", "\n", "First, we use JAX to get a function that yields the Jacobian of our forward simulation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jac_func = jax.jacfwd(simulate_one_peak, argnums=(2,3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we call the function on the peak position to get the actual Jacobian:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "J_origin, J_wav = jac_func(UBI, hkl, origin_sample, wavelength_cen)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we are investigating two parameters, we get two Jacobians. 
We can combine them together into a single 3x4 matrix (three outputs `(sc, fc, omega)` by four inputs: the three components of `origin_sample` plus `wavelength`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "J_wav = J_wav.reshape(3,1)\n", "J_total = jnp.block([J_origin, J_wav])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we write a quick function that gives us the combined input covariance matrix, given some standard deviations in our experimental parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from jax.scipy.linalg import block_diag\n", "\n", "def get_cov(sig_wav, sig_origin):\n", " cov_beam = jnp.array([[sig_wav**2]])\n", " cov_origin = jnp.eye(3) * (sig_origin**2)\n", "\n", " # combine together into one covariance matrix:\n", " cov_total = block_diag(cov_origin, cov_beam)\n", " return cov_total" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have the Jacobian and the input covariance matrix, we can combine them to project the covariance matrix into detector space:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Standard deviations for beam\n", "sig_wav = 0.0 # same units as wavelength_cen\n", "# Standard deviations for grain position (assume isotropic in all three directions):\n", "sig_origin = 75.0\n", "\n", "cov_total = get_cov(sig_wav, sig_origin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cov_det = J_total @ cov_total @ J_total.T\n", "print('Covariance matrix in detector space:')\n", "print(cov_det)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now investigate the result. 
As we specified the standard deviation of the origin to be equal to the pixel size, we can expect the standard deviations on `(sc, fc)` to be approximately 1 pixel, and for the Bragg angle to be unaffected:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sig_sc, sig_fc, sig_omega = jnp.sqrt(jnp.diag(cov_det))\n", "print('(sig_sc, sig_fc, sig_omega):')\n", "print(sig_sc, sig_fc, sig_omega)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that the errors in detector space are close to 1 pixel, but not exactly 1. This is probably due to the assumption of linearity imposed by the Jacobian-covariance matrix product we used earlier.\n", "\n", "Now let's introduce variable wavelength:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Standard deviations for beam\n", "sig_wav = 1e-5 # more realistic for ID11\n", "# Standard deviations for grain position (assume isotropic in all three directions):\n", "sig_origin = 0.0\n", "\n", "cov_total = get_cov(sig_wav, sig_origin)\n", "\n", "cov_det = J_total @ cov_total @ J_total.T\n", "print('Covariance matrix in detector space:')\n", "print(cov_det)\n", "\n", "sig_sc, sig_fc, sig_omega = jnp.sqrt(jnp.diag(cov_det))\n", "print('(sig_sc, sig_fc, sig_omega):')\n", "print(sig_sc, sig_fc, sig_omega)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that a small change in the wavelength yields a very small change in both detector pixel position and the Bragg angle. 
\n", "Because we simulated a scattering vector with $\\eta = 45^\\circ$, we also expect the changes in `sc` and `fc` to be equal as we change the wavelength.\n", "Combining both together with more realistic numbers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Standard deviations for beam\n", "sig_wav = 1e-5 # more realistic for ID11\n", "# Standard deviations for grain position (assume isotropic in all three directions):\n", "sig_origin = 10.0\n", "\n", "cov_total = get_cov(sig_wav, sig_origin)\n", "\n", "cov_det = J_total @ cov_total @ J_total.T\n", "print('Covariance matrix in detector space:')\n", "print(cov_det)\n", "\n", "sig_sc, sig_fc, sig_omega = jnp.sqrt(jnp.diag(cov_det))\n", "print('(sig_sc, sig_fc, sig_omega):')\n", "print(sig_sc, sig_fc, sig_omega)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we can see that a standard deviation in the wavelength of around $10^{-5}$ (in the same units as `wavelength_cen`) and a standard deviation in grain position of around 10 µm together give an uncertainty in the pixel position of around 0.13 pixels with this current detector setup." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }