{ "cells": [ { "cell_type": "markdown", "id": "75303862", "metadata": {}, "source": [ "# Simple forward model\n", "Now that we have some crystallographic functions and can handle the detector geometry, we can perform a basic forward model of a single crystal to reassure ourselves that this wasn't all for nothing!" ] }, { "cell_type": "code", "execution_count": null, "id": "ce20acf9", "metadata": {}, "outputs": [], "source": [ "import jax\n", "jax.config.update(\"jax_enable_x64\", True)\n", "import jax.numpy as jnp\n", "from jax.scipy.spatial.transform import Rotation as jR\n", "from matplotlib import pyplot as plt\n", "\n", "import anri\n", "\n", "import time\n", "\n", "start = time.time()" ] }, { "cell_type": "markdown", "id": "bb412121", "metadata": {}, "source": [ "In `Anri`, all fundamental functions and transforms are written for single vectors. \n", "This significantly simplifies the functions themselves and keeps them easy to understand. \n", "Additionally, when forward-simulating many grains or voxels, you will likely have more complicated array shapes, so it makes sense to leave the broadcasting (e.g. with `jax.vmap`) to the user or another part of the program for now. 
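As a toy illustration of this division of work, the sketch below writes one hypothetical single-vector function (`rotate_about_z`, which is not part of Anri) and lets `jax.vmap` do the batching: `in_axes=0` maps over the leading axis of an argument, while `in_axes=None` passes that argument unchanged to every call.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-vector function in the same spirit as Anri's
# (NOT part of Anri): rotate one 3-vector about the lab z axis.
def rotate_about_z(v, angle_deg):
    a = jnp.deg2rad(angle_deg)
    c, s = jnp.cos(a), jnp.sin(a)
    R = jnp.array([[c, -s, 0.0],
                   [s, c, 0.0],
                   [0.0, 0.0, 1.0]])
    return R @ v

# Many vectors, one shared angle: map axis 0 of v, broadcast the angle.
rotate_many_shared = jax.vmap(rotate_about_z, in_axes=(0, None))

# Many vectors, one angle per vector: map axis 0 of both arguments.
rotate_many_each = jax.vmap(rotate_about_z, in_axes=(0, 0))

vs = jnp.eye(3)  # three unit vectors (x, y, z), shape (3, 3)
out_shared = rotate_many_shared(vs, 90.0)
out_each = rotate_many_each(vs, jnp.array([90.0, 0.0, 0.0]))
```

The hand-declared `*_vec` functions in the next cell follow the same pattern; for example, `sample_to_lab` is mapped over both the scattering vectors and their omega angles.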
\n", "I'm currently still grappling with the best way to expose vmapped functions in the API, so for now I will manually declare them here:" ] }, { "cell_type": "code", "execution_count": 2, "id": "aa3b1501", "metadata": {}, "outputs": [], "source": [ "# vmap over the per-peak arguments (in_axes=0) and broadcast the shared ones (in_axes=None)\n", "omega_solns_vec = jax.vmap(anri.diffract.omega_solns, in_axes=[0, None, None])\n", "sample_to_lab_vec = jax.vmap(anri.geom.sample_to_lab, in_axes=[0, 0, None, None, None, None])\n", "q_lab_to_k_out_vec = jax.vmap(anri.diffract.q_lab_to_k_out, in_axes=[0, None])\n", "raytrace_to_det_vec = jax.vmap(anri.geom.raytrace_to_det, in_axes=[0, None, None, None, None])\n", "q_lab_to_tth_eta_vec = jax.vmap(anri.diffract.q_lab_to_tth_eta, in_axes=[0, None])" ] }, { "cell_type": "markdown", "id": "76f1dbc4", "metadata": {}, "source": [ "## Crystallography" ] }, { "cell_type": "markdown", "id": "f799e94f", "metadata": {}, "source": [ "Let's load a simple Fe CIF file:" ] }, { "cell_type": "code", "execution_count": null, "id": "ae8d4d5b", "metadata": {}, "outputs": [], "source": [ "struc = anri.crystal.Structure.from_cif(\"../../../tests/data/cif/Fe.cif\")" ] }, { "cell_type": "markdown", "id": "9af46d0a", "metadata": {}, "source": [ "We generate some hkls:" ] }, { "cell_type": "code", "execution_count": 4, "id": "270ebc9a", "metadata": {}, "outputs": [], "source": [ "dsmax = 2.0\n", "wavelength = 0.3\n", "struc.make_hkls(dsmax=dsmax, wavelength=wavelength)" ] }, { "cell_type": "code", "execution_count": 5, "id": "d55e03b6", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
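\n", "<small>shape: (12, 7)</small><table><thead><tr><th>h</th><th>k</th><th>l</th><th>tth</th><th>ds</th><th>intensity</th><th>ring_id</th></tr><tr><td>i64</td><td>i64</td><td>i64</td><td>f64</td><td>f64</td><td>f64</td><td>u32</td></tr></thead><tbody><tr><td>-1</td><td>1</td><td>0</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>-1</td><td>0</td><td>1</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>-1</td><td>0</td><td>-1</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>0</td><td>-1</td><td>-1</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>0</td><td>-1</td><td>1</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr><tr><td>1</td><td>0</td><td>-1</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>-1</td><td>-1</td><td>0</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>1</td><td>0</td><td>1</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>1</td><td>1</td><td>0</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr><tr><td>1</td><td>-1</td><td>0</td><td>8.498266</td><td>0.493956</td><td>1363.448082</td><td>0</td></tr></tbody></table>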
" ], "text/plain": [ "shape: (12, 7)\n", "┌─────┬─────┬─────┬──────────┬──────────┬─────────────┬─────────┐\n", "│ h ┆ k ┆ l ┆ tth ┆ ds ┆ intensity ┆ ring_id │\n", "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ i64 ┆ i64 ┆ f64 ┆ f64 ┆ f64 ┆ u32 │\n", "╞═════╪═════╪═════╪══════════╪══════════╪═════════════╪═════════╡\n", "│ -1 ┆ 1 ┆ 0 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ -1 ┆ 0 ┆ 1 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ -1 ┆ 0 ┆ -1 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ 0 ┆ -1 ┆ -1 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ 0 ┆ -1 ┆ 1 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", "│ 1 ┆ 0 ┆ -1 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ -1 ┆ -1 ┆ 0 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ 1 ┆ 0 ┆ 1 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ 1 ┆ 1 ┆ 0 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "│ 1 ┆ -1 ┆ 0 ┆ 8.498266 ┆ 0.493956 ┆ 1363.448082 ┆ 0 │\n", "└─────┴─────┴─────┴──────────┴──────────┴─────────────┴─────────┘" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "struc.rings_dict[0]" ] }, { "cell_type": "markdown", "id": "9bb8ebde", "metadata": {}, "source": [ "Let's generate a random orientation." 
] }, { "cell_type": "code", "execution_count": null, "id": "45623097", "metadata": {}, "outputs": [], "source": [ "key = jax.random.key(time.time_ns())\n", "random_euler = jax.random.uniform(key, shape=(3,), minval=-90.0, maxval=90.0)\n", "U = jR.from_euler('XYZ', random_euler, degrees=True).as_matrix()\n", "U" ] }, { "cell_type": "markdown", "id": "d157cbde", "metadata": {}, "source": [ "Now we can generate some scattering vectors in the sample frame:" ] }, { "cell_type": "code", "execution_count": null, "id": "51930545", "metadata": {}, "outputs": [], "source": [ "UB = U @ struc.B\n", "q_sample = (UB @ struc.ringhkls_arr.T).T\n", "q_sample.shape" ] }, { "cell_type": "markdown", "id": "68b4760e", "metadata": {}, "source": [ "## Ewald condition" ] }, { "cell_type": "markdown", "id": "f3a6de56", "metadata": {}, "source": [ "Now we can determine the omega angles required to diffract:" ] }, { "cell_type": "code", "execution_count": null, "id": "84c2cb6b", "metadata": {}, "outputs": [], "source": [ "chi = 0.0\n", "wedge = 0.0\n", "dty = 0.0\n", "y0 = 0.0\n", "\n", "# define incoming wavevector in the lab frame\n", "k_in_lab = jnp.array([1., 0, 0])\n", "k_in_lab_norm = anri.diffract.scale_norm_k(k_in_lab, wavelength)\n", "\n", "# map it into the sample frame\n", "k_in_sample_norm = anri.geom.lab_to_sample(k_in_lab_norm, 0.0, wedge, chi, dty, y0)\n", "# etasign +1:\n", "omega1, valid1 = omega_solns_vec(q_sample, 1.0, k_in_sample_norm)\n", "# etasign -1:\n", "omega2, valid2 = omega_solns_vec(q_sample, -1.0, k_in_sample_norm)\n", "omega = jnp.concatenate([omega1, omega2])\n", "valid = jnp.concatenate([valid1, valid2])\n", "q_sample = jnp.concatenate([q_sample, q_sample])\n", "omega_valid = omega[valid]\n", "q_sample_valid = q_sample[valid]" ] }, { "cell_type": "markdown", "id": "b2d6c163", "metadata": {}, "source": [ "## Into the lab frame\n", "With the omega angles determined, we can rotate `q_sample` into the lab frame:" ] }, { "cell_type": "code", "execution_count": 
null, "id": "5e5860b4", "metadata": {}, "outputs": [], "source": [ "q_lab = sample_to_lab_vec(q_sample_valid, omega_valid, wedge, chi, dty, y0)" ] }, { "cell_type": "code", "execution_count": null, "id": "523c42eb", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(8,8))\n", "ax.scatter(q_lab[:, 1], q_lab[:, 2])\n", "ax.set_aspect(1)\n", "ax.set(xlabel='Lab Y', ylabel='Lab Z')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "ee477db3", "metadata": {}, "source": [ "## Into the detector" ] }, { "cell_type": "markdown", "id": "22a17a76", "metadata": {}, "source": [ "Now we can forward-project the peaks onto the detector!\n", "\n", "Let's set up the detector transforms:" ] }, { "cell_type": "code", "execution_count": null, "id": "2ccc1e6d", "metadata": {}, "outputs": [], "source": [ "y_center = 1000.0\n", "z_center = 1000.0\n", "y_size = 75.0\n", "z_size = 75.0\n", "tilt_x = 0.0\n", "tilt_y = 0.0\n", "tilt_z = 0.0\n", "distance = 180e3\n", "o11 = 1\n", "o12 = 0\n", "o21 = 0\n", "o22 = 1\n", "det_trans, beam_cen_shift, x_distance_shift = anri.geom.detector_transforms(\n", " y_center,\n", " y_size,\n", " tilt_y,\n", " z_center,\n", " z_size,\n", " tilt_z,\n", " tilt_x,\n", " distance,\n", " o11,\n", " o12,\n", " o21,\n", " o22\n", ")" ] }, { "cell_type": "markdown", "id": "bc990af7", "metadata": {}, "source": [ "We get the detector basis vectors in the lab frame:" ] }, { "cell_type": "code", "execution_count": null, "id": "69ad790c", "metadata": {}, "outputs": [], "source": [ "sc_lab, fc_lab, norm_lab = anri.geom.detector_basis_vectors_lab(det_trans, beam_cen_shift, x_distance_shift)" ] }, { "cell_type": "markdown", "id": "0132d302", "metadata": {}, "source": [ "Now we can map into detector space:" ] }, { "cell_type": "code", "execution_count": null, "id": "d12b5d1f", "metadata": {}, "outputs": [], "source": [ "origin_lab = jnp.array([0., 0, 0])\n", "\n", "# get the outgoing wavevector\n", "k_out = q_lab_to_k_out_vec(q_lab, 
k_in_lab_norm)\n", "# ray-trace it into the detector\n", "sc, fc = raytrace_to_det_vec(k_out, origin_lab, sc_lab, fc_lab, norm_lab)" ] }, { "cell_type": "markdown", "id": "e03b1bbc", "metadata": {}, "source": [ "## Results" ] }, { "cell_type": "code", "execution_count": null, "id": "79ddae35", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(8,8))\n", "ax.scatter(fc, sc)\n", "ax.set_aspect(1)\n", "ax.set(xlabel='Detector fast', ylabel='Detector slow')\n", "# set some sensible detector limits\n", "ax.set_xlim(0, 2048)\n", "ax.set_ylim(0, 2048)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "dbe9faa9", "metadata": {}, "outputs": [], "source": [ "tth, eta = q_lab_to_tth_eta_vec(q_lab, wavelength)" ] }, { "cell_type": "code", "execution_count": null, "id": "003e85c6", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(8,6))\n", "ax.scatter(tth, eta, label='Peaks')\n", "ax.vlines(struc.ringtth, -25, 25, color='red', label='Unit cell')\n", "ax.set(xlabel=r'$2\\theta$', ylabel=r'$\\eta$')\n", "ax.legend(loc='upper right')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "212ea87f", "metadata": {}, "source": [ "## Index the forward-simulated peaks with ImageD11\n", "As a sanity check, we should be able to index the peaks with ImageD11 and recover the UBI. 
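Indexing works because the UBI, the inverse of $UB$, maps each sample-frame g-vector back to (approximately) integer Miller indices. A minimal sketch of that check follows, using a hypothetical `count_indexed` helper (not ImageD11's own `score` implementation, whose exact tolerance definition may differ) and a toy cubic cell where we assume $B = I/a$ with no 2π factor and an identity orientation.

```python
import jax.numpy as jnp

# Hypothetical helper (NOT ImageD11's score): count g-vectors whose
# hkl estimate h = UBI @ g lies within tol of an integer triple.
def count_indexed(ubi, gvecs, tol):
    # row-wise h = UBI @ g for every g-vector
    hkl = gvecs @ ubi.T
    # Euclidean distance of each estimate from the nearest integer triple
    err = jnp.sqrt(jnp.sum((hkl - jnp.round(hkl)) ** 2, axis=1))
    return int(jnp.sum(err < tol))

# Toy cubic cell (assumption): a = 2.87 Angstrom, B = I / a, U = I, so UB = B.
a = 2.87
UB = jnp.eye(3) / a
UBI = jnp.linalg.inv(UB)

hkls = jnp.array([[1.0, 1.0, 0.0], [0.0, -1.0, 1.0], [2.0, 0.0, 0.0]])
g = hkls @ UB.T  # g = UB @ hkl for each row, as for q_sample above
# spoil the last peak so it no longer indexes
g_noisy = g.at[2].add(jnp.array([0.05, 0.0, 0.0]))

n_good = count_indexed(UBI, g_noisy, tol=0.01)
```

Here the shift of 0.05 in g becomes an error of about 0.14 in h for the third peak, so only the first two peaks count as indexed.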
\n", "Let's prepare the columnfile for ImageD11:" ] }, { "cell_type": "code", "execution_count": null, "id": "c8773595", "metadata": {}, "outputs": [], "source": [ "import ImageD11.columnfile, ImageD11.parameters, ImageD11.unitcell, ImageD11.indexing, ImageD11.grain\n", "\n", "# make an ImageD11 unitcell from our structure\n", "uc = ImageD11.unitcell.unitcell(struc.lattice_parameters, struc.sgno)\n", "\n", "# prepare a minimal columnfile - just add detector positions and omega angles\n", "cf_obs = ImageD11.columnfile.columnfile(new=True)\n", "cf_obs.nrows = fc.shape[0]\n", "cf_obs.addcolumn(fc, 'fc')\n", "cf_obs.addcolumn(sc, 'sc')\n", "cf_obs.addcolumn(omega_valid, 'omega')\n", "\n", "# prepare a parameters object to hold our experiment state\n", "pars = ImageD11.parameters.parameters()\n", "# detector\n", "pars.set('distance', distance)\n", "pars.set('tilt_x', tilt_x)\n", "pars.set('tilt_y', tilt_y)\n", "pars.set('tilt_z', tilt_z)\n", "pars.set('y_size', y_size)\n", "pars.set('z_size', z_size)\n", "pars.set('y_center', y_center)\n", "pars.set('z_center', z_center)\n", "pars.set('o11', o11)\n", "pars.set('o12', o12)\n", "pars.set('o21', o21)\n", "pars.set('o22', o22)\n", "# beam\n", "pars.set('wavelength', wavelength)\n", "# diffractometer\n", "pars.set('chi', chi)\n", "pars.set('wedge', wedge)\n", "pars.set('t_x', 0)\n", "pars.set('t_y', 0)\n", "pars.set('t_z', 0)\n", "pars.set('omegasign', 1)\n", "cf_obs.parameters = pars\n", "print(pars.get_parameters())\n" ] }, { "cell_type": "markdown", "id": "0420945a", "metadata": {}, "source": [ "Now we compute the peak geometry with ImageD11:" ] }, { "cell_type": "code", "execution_count": null, "id": "200658f7", "metadata": {}, "outputs": [], "source": [ "print(cf_obs.titles)\n", "cf_obs.updateGeometry()\n", "print(cf_obs.titles)" ] }, { "cell_type": "markdown", "id": "d2d29e65", "metadata": {}, "source": [ "Sanity check: ImageD11 should compute the same sample-frame g-vectors as our `q_sample_valid`:" ] }, { "cell_type": "code", 
"execution_count": null, "id": "6e6c91db", "metadata": {}, "outputs": [], "source": [ "jnp.abs(jnp.stack([cf_obs.gx, cf_obs.gy, cf_obs.gz], axis=1) - q_sample_valid).max()" ] }, { "cell_type": "markdown", "id": "7dfbc22a", "metadata": {}, "source": [ "Now let's set up our indexer and run it:" ] }, { "cell_type": "code", "execution_count": null, "id": "7e435534", "metadata": {}, "outputs": [], "source": [ "ImageD11.indexing.loglevel = 3\n", "idx = ImageD11.indexing.indexer_from_colfile_and_ucell(cf_obs, uc)\n", "idx.ds_tol = 0.005\n", "idx.assigntorings()\n", "idx.hkl_tol = 0.01\n", "idx.score_all_pairs()" ] }, { "cell_type": "code", "execution_count": null, "id": "407ef3b6", "metadata": {}, "outputs": [], "source": [ "id11_ubi = idx.ubis[0]\n", "id11_grain = ImageD11.grain.grain(id11_ubi)\n", "# B matrices should be very similar:\n", "print(jnp.abs(id11_grain.B - struc.B).max())\n", "# the U matrices that we get back should be the same under symmetry:\n", "dU = id11_grain.U.T @ U\n", "print(dU)" ] }, { "cell_type": "markdown", "id": "df0604f9", "metadata": {}, "source": [ "An even simpler check is whether our `UBI` from `Anri` indexes the g-vectors from ImageD11:" ] }, { "cell_type": "code", "execution_count": null, "id": "6da9ad11", "metadata": {}, "outputs": [], "source": [ "from ImageD11.cImageD11 import score\n", "gve_id11 = jnp.stack([cf_obs.gx, cf_obs.gy, cf_obs.gz], axis=1)\n", "score_result = score(ubi=jnp.linalg.inv(U @ struc.B), gv=gve_id11, tol=0.01)\n", "print(f'Score result: {score_result}, Peaks in dataset: {cf_obs.nrows}')" ] }, { "cell_type": "code", "execution_count": null, "id": "2eb48d77", "metadata": {}, "outputs": [], "source": [ "end = time.time()\n", "print(f'Took {(end - start):.1f} seconds')" ] }, { "cell_type": "markdown", "id": "e02c545c", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", 
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.14.2" } }, "nbformat": 4, "nbformat_minor": 5 }