NumPyro Example: Transmission Spectroscopy

# This should install any required dependencies assuming JAX has already been installed
try:
    import luas
except ImportError:
    !git clone https://github.com/markfortune/luas.git
    %cd luas
    %pip install .
    %cd ..

try:
    import jaxoplanet
except ImportError:
    %pip install -q jaxoplanet==0.0.1

try:
    import numpyro
except ImportError:
    %pip install -q numpyro
    
try:
    import corner
except ImportError:
    %pip install -q corner
    
try:
    import arviz as az
except ImportError:
    %pip install -q arviz

NOTE: This tutorial notebook is currently in development. Expect significant changes in the future

This notebook provides a tutorial on how to use NumPyro to perform spectroscopic transit light curve fitting. We will first go through how to generate transit light curves in jax, then we will create synthetic noise which is correlated in both wavelength and time. Finally, we will use NumPyro to fit the noise-contaminated light curves and recover the input transmission spectrum. You can run this notebook yourself on Google Colab by clicking the rocket icon in the top right corner.

NumPyro has the advantage of being written in jax, which makes implementing some things with luas (such as parallelisation) simpler compared to PyMC. However, it is arguably less user-friendly than PyMC and at times less flexible. For example, unlike PyMC, blocked Gibbs sampling is difficult to implement, and there are fewer options for non-gradient based MCMC methods such as slice sampling.

If you have a GPU available (including a GPU on Google Colab) then jax should detect this and run on it automatically. Running jax.devices() tells us what jax has found to run on. If there is a GPU available but it is not detected, this could indicate jax has not been correctly installed for GPU. Also note that if running on Google Colab, the T4 GPUs have poor double-precision floating point performance and may not show significant speed-ups, while more modern hardware such as NVIDIA Tesla V100s or better tends to show significant performance improvements.

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import os.path
from jaxoplanet.orbits import KeplerianOrbit
from jaxoplanet.light_curves import QuadLightCurve
import arviz as az
import pandas as pd
import luas
import numpyro
import corner
import logging

# Useful to set when looking at retrieved values for many parameters
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

# This helps give more information on what NumPyro is doing during inference
logging.getLogger().setLevel(logging.INFO)

# Running this at the start of the runtime ensures jax uses 64-bit floating point numbers
# as jax uses 32-bit by default
jax.config.update("jax_enable_x64", True)

# This will list available CPUs/GPUs JAX is picking up
print("Available devices:", jax.devices())

# This can be commented out to allow NumPyro to run parallel chains on CPU but it may not improve performance
# numpyro.set_host_device_count(2)
Available devices: [CpuDevice(id=0)]
INFO:absl:Remote TPU is not linked into jax; skipping remote TPU.
INFO:absl:Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:absl:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:absl:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'

First we will use jaxoplanet which has its own documentation and tutorials here. Let’s create a single light curve with quadratic limb darkening. We will use the transit_light_curve function from luas.exoplanet which we have included below to show what it does.

def transit_light_curve(par, t):
    """Uses the package `jaxoplanet <https://github.com/exoplanet-dev/jaxoplanet>`_ to calculate
    transit light curves using JAX assuming quadratic limb darkening and a simple circular orbit.
    
    This particular function will only compute a single transit light curve but JAX's vmap function
    can be used to calculate the transit light curve of multiple wavelength bands at once.
    
    Args:
        par (PyTree): The transit parameters stored in a PyTree/dictionary (see example above).
        t (JAXArray): Array of times to calculate the light curve at.
            
    Returns:
        JAXArray: Array of flux values for each time input.
        
    """
    
    light_curve = QuadLightCurve.init(u1=par["u1"], u2=par["u2"])
    orbit = KeplerianOrbit.init(
        time_transit=par["T0"],
        period=par["P"],
        semimajor=par["a"],
        impact_param=par["b"],
        radius=par["rho"],
    )
    
    flux = (par["Foot"] + 24*par["Tgrad"]*(t-par["T0"]))*(1+light_curve.light_curve(orbit, t)[0])
    
    return flux

from luas.exoplanet import transit_light_curve

# Let's use the literature values used in Gibson et al. (2017) to start us off
mfp = {
    "T0":0.,     # Central transit time
    "P":3.4,     # Period (days)
    "a":8.,      # Semi-major axis to stellar ratio aka a/R*
    "rho":0.1,   # Radius ratio rho aka Rp/R*
    "b":0.5,     # Impact parameter
    "u1":0.5,    # First quadratic limb darkening coefficient
    "u2":0.1,    # Second quadratic limb darkening coefficient
    "Foot":1.,   # Baseline flux out of transit
    "Tgrad":0.   # Gradient in baseline flux (hrs^-1)
}

# Generate 100 evenly spaced time points (in units of days)
N_t = 100
x_t = jnp.linspace(-0.15, 0.15, N_t)

plt.plot(x_t, transit_light_curve(mfp, x_t), "k.")
plt.xlabel("Time (days)")
plt.ylabel("Normalised Flux")
plt.show()
[Figure: simulated transit light curve plotted against time]

Now we want to take our jax function, which generates a single 1D light curve in time, and use it to create separate light curves at each wavelength. We will also want to share some parameters between light curves (e.g. impact parameter b, system scale a/Rs) while varying other parameters with wavelength (e.g. radius ratio Rp/Rs).

We could of course write a for loop but for loops can be slow to compile with jax. A more efficient option is to make use of the jax.vmap function which allows us to “vectorise” our function from 1D to 2D:

# First we must tell JAX which parameters of the function we want to vary for each light curve
# and which we want to be shared between light curves
transit_light_curve_vmap = jax.vmap(
    # First argument is the function to vectorise
    transit_light_curve, 
    
    # Specify which parameters to share and which to vary for each light curve
    in_axes=(
        {
        # If a parameter is to be shared across each light curve then it should be set to None
        "T0":None, "P":None, "a":None, "b":None,

        # Parameters which vary in wavelength are given the dimension of the array to expand along
        # In this case we are expanding from 0D arrays to 1D arrays so this must be 0
        "rho":0, "u1":0, "u2":0, "Foot":0, "Tgrad":0
        },
        # Also must specify that we will share the time array (the second function parameter)
        # between light curves
         None,  
    ),
    
    # Specify the output dimension to expand along, this will default to 0 anyway
    # Will output extra flux values for each light curve as additional rows
    out_axes = 0,
)

# Let's define a wavelength range of our data, this isn't actually used by our mean function
# but we will use it for defining correlation in wavelength later
N_l = 16 # Feel free to vary the number of light curves, even >100 light curves may perform quite efficiently
x_l = np.linspace(4000, 7000, N_l)

# Now we define our 2D transit parameters, let's just assume everything is constant in wavelength
# Note that the parameters which vary in wavelength are now arrays of size N_l
mfp_2D = {
    "T0":0.,                         # Central transit time
    "P":3.4,                         # Period (days)
    "a":8.,                          # Semi-major axis to stellar ratio aka a/R*
    "rho":0.1*np.ones(N_l),          # Radius ratio rho aka Rp/R*
    "b":0.5,                         # Impact parameter
    "u1":np.linspace(0.7, 0.4, N_l), # First quadratic limb darkening coefficient
    "u2":np.linspace(0.1, 0.2, N_l), # Second quadratic limb darkening coefficient
    "Foot":1.*np.ones(N_l),          # Baseline flux out of transit
    "Tgrad":0.*np.ones(N_l)          # Gradient in baseline flux (days^-1)
}

# Call our new vmap of transit_light_curve to simultaneously compute light curves for all wavelengths
transit_model_2D = transit_light_curve_vmap(mfp_2D, x_t)

# Visualise the multiwavelength transit model as a 2D image (each row is one light curve)
plt.imshow(transit_model_2D, aspect = 'auto')
plt.xlabel("Time index")
plt.ylabel("Light curve index")
plt.show()
[Figure: 2D image of the multiwavelength transit model]

The form of a mean function with luas.GP is f(par, x_l, x_t) where par is a PyTree, x_l is a JAXArray of inputs that lie along the wavelength/vertical dimension of the data (e.g. an array of wavelength values) and x_t is a JAXArray of inputs that lie along the time/horizontal dimension (e.g. an array of timestamps). Our transit model does not require wavelength values to be computed but we still write our wrapper function to be of this form.
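
Since the mean function contract is so simple, it can help to see the minimal valid example. Below is a sketch of a trivial zero mean function (this is effectively what luas.GP defaults to when no mean function is given); the name here is purely illustrative.

def zero_mean_fn(par, x_l, x_t):
    # Simplest valid mean function: ignores the parameters and returns a (N_l, N_t) array of zeros
    return jnp.zeros((x_l.shape[-1], x_t.shape[-1]))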

We will also switch to the Kipping (2013) limb darkening parameterisation as it makes it easier to place simple bounds on these parameters and has the benefit of placing a uniform prior on the physically allowed space of limb darkening parameters.

Also feel free to use the luas.exoplanet.transit_2D function instead, which is similar to the function below with the exception that it fits for transit depth \(d = \rho^2\) instead of radius ratio \(\rho\).
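
For context, the Kipping (2013) reparameterisation is simple enough to write out explicitly. Below is a sketch of the mapping that ld_to_kipping and ld_from_kipping implement (the actual luas implementations may differ in details such as input validation):

# Sketch of the Kipping (2013) reparameterisation of quadratic limb darkening
def ld_to_kipping_sketch(u1, u2):
    # q1 = (u1 + u2)^2,  q2 = u1 / (2 (u1 + u2))
    return (u1 + u2)**2, u1/(2*(u1 + u2))

def ld_from_kipping_sketch(q1, q2):
    # u1 = 2 sqrt(q1) q2,  u2 = sqrt(q1) (1 - 2 q2)
    sqrt_q1 = jnp.sqrt(q1)
    return 2*sqrt_q1*q2, sqrt_q1*(1 - 2*q2)

Both q1 and q2 lie in [0, 1], which is what makes it easy to place simple uniform bounds on them.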

from luas.exoplanet import ld_to_kipping, ld_from_kipping

def transit_light_curve_2D(p, x_l, x_t):
    
    # vmap requires that we only pass in the parameters whose vectorisation was explicitly specified in in_axes
    transit_params = ["T0", "P", "a", "rho", "b", "Foot", "Tgrad"]
    mfp = {k:p[k] for k in transit_params}
    
    # Calculate limb darkening coefficients from the Kipping (2013) parameterisation.
    mfp["u1"], mfp["u2"] = ld_from_kipping(p["q1"], p["q2"])
    
    # Use the vmap of transit_light_curve to calculate a 2D array of shape (M, N) of flux values
    # For M wavelengths and N time points.
    return transit_light_curve_vmap(mfp, x_t)


# Switch to Kipping parameterisation
if "u1" in mfp_2D:
    mfp_2D["q1"], mfp_2D["q2"] = ld_to_kipping(mfp_2D["u1"], mfp_2D["u2"])
    del mfp_2D["u1"]
    del mfp_2D["u2"]

# This should produce the same result as transit_light_curve_vmap as we have only changed the limb darkening parameterisation
M = transit_light_curve_2D(mfp_2D, x_l, x_t)

Now that we have our mean function created, it’s time to start building a kernel function. We will try to keep things simple and use a squared exponential kernel for the correlation in both time and wavelength, plus white noise which varies in amplitude between different light curves.

\[ \begin{equation} \mathbf{K}_{ij} = h^2 \exp\left(-\frac{|\lambda_i - \lambda_j|^2}{2 l_\mathrm{\lambda}^2}\right) \exp\left(-\frac{|t_i - t_j|^2}{2 l_{t}^2}\right) + \sigma_{\lambda_i}^2\delta_{\lambda_i \lambda_j} \delta_{t_i t_j} \end{equation} \]

This kernel contains a correlated noise component with height scale \(h\), length scale in wavelength \(l_\mathrm{\lambda}\) and length scale in time \(l_t\). There is also a white noise term which can have different white noise amplitudes at different wavelengths.

We will need to write this kernel as separate wavelength and time kernel functions multiplied together satisfying:

\[ \begin{equation} \text{K}(\Delta \lambda, \Delta t) = \text{K}_{l}(\Delta \lambda) \otimes \text{K}_{t}(\Delta t) + \text{S}_{l}(\Delta \lambda) \otimes \text{S}_{t}(\Delta t) \end{equation} \]

We can do this by choosing the equations below for our kernel functions which generate each component matrix.

\[ \begin{equation} \text{K}_l(\lambda_i, \lambda_j) = h^2 \exp\left(-\frac{|\lambda_i - \lambda_j|^2}{2 l_\mathrm{\lambda}^2}\right) \end{equation} \]
\[ \begin{equation} \text{K}_t(t_i, t_j) = \exp\left(-\frac{|t_i - t_j|^2}{2 l_{t}^2}\right) \end{equation} \]
\[ \begin{equation} \text{S}_l(\lambda_i, \lambda_j) = \sigma_{\lambda_i}^2\delta_{\lambda_i \lambda_j} \end{equation} \]
\[ \begin{equation} \text{S}_t(t_i, t_j) = \delta_{t_i t_j} \end{equation} \]

Note that for numerical stability reasons it’s important that the \(S_l\) and \(S_t\) matrices are well-conditioned (i.e. can be stably inverted) while \(K_l\) and \(K_t\) do not need to be well-conditioned (or even invertible). The matrices which contain values added along the diagonal should be well-conditioned as this “regularises” the matrices, while a squared exponential kernel with nothing added along the diagonal is unlikely to be well-conditioned.

from luas import kernels
from luas import LuasKernel, GeneralKernel

# We implement each of these kernel functions below using the luas.kernels module
# for an implementation of the squared exponential kernel

# The wavelength kernel functions take the wavelength regression variable(s) x_l as input (of shape (N_l) or (d_l, N_l))
def Kl_fn(hp, x_l1, x_l2, wn = True):
    Kl = jnp.exp(2*hp["log_h"])*kernels.squared_exp(x_l1, x_l2, jnp.exp(hp["log_l_l"]))
    return Kl

# The time kernel functions take the time regression variable(s) x_t as input (of shape (N_t) or (d_t, N_t))
def Kt_fn(hp, x_t1, x_t2, wn = True):
    return kernels.squared_exp(x_t1, x_t2, jnp.exp(hp["log_l_t"]))

# For both the Sl and St functions we set a decomp attribute to "diag" because they produce diagonal matrices
# This speeds up the log likelihood calculations as it tells luas these matrices are easy to eigendecompose
# But don't do this for the Kl and Kt functions even if they produce diagonal matrices unless you know what you are doing
# This is because you are telling luas that the transformations of Kl and Kt are diagonal, not Kl and Kt themselves

# If the wn keyword argument is True then white noise should be included (doesn't affect most matrices)
# This is used by gp.predict when performing Gaussian process prediction
# Note it does not matter that Sl is not invertible without white noise
def Sl_fn(hp, x_l1, x_l2, wn = True):
    Sl = jnp.zeros((x_l1.shape[-1], x_l2.shape[-1]))
    
    if wn:
        # If we are including white noise then safe to assume Sl is a square matrix for any calculations in luas.GP
        Sl += jnp.diag(jnp.exp(2*hp["log_sigma"])) # Assumes hp["log_sigma"] is an array of size N_l

    return Sl
Sl_fn.decomp = "diag" # Sl is a diagonal matrix

def St_fn(p, x_t1, x_t2, wn = True):
    return jnp.eye(x_t1.shape[-1])
St_fn.decomp = "diag" # St is a diagonal matrix

# Build a LuasKernel object using these component kernel functions
# The full covariance matrix applied to the data will be K = Kl KRON Kt + Sl KRON St
kernel = LuasKernel(Kl = Kl_fn, Kt = Kt_fn, Sl = Sl_fn, St = St_fn,
                    
                    # Can select whether to use previously calculated eigendecompositions when running MCMC
                    # Performs an additional check in each step to see if each component covariance matrix has changed since last step
                    # Useful when doing blocked Gibbs or if fixing some hyperparameters
                    use_stored_values = True, 
                   )

# We can also create a kernel object with the same kernel as LuasKernel but without the kronecker product optimisations
# We will use this to show they produce the same answers and to compare runtimes
# While it does take the LuasKernel.K function, this simply takes Kl_fn, Kt_fn, Sl_fn and St_fn and kronecker products the results together
# So it builds the same covariance matrix but the log likelihood calculations are completely different
general_kernel = GeneralKernel(K = kernel.K)

The LuasKernel.visualise_covariance_matrix method should be a useful way of visualising each of the four component matrices and ensuring everything you have written looks right.

# Some sample hyperparameters
hp = {
    "log_h":jnp.log(2e-3),    # log height scale of correlated noise
    "log_l_l":jnp.log(1000.), # log wavelength length scale
    "log_l_t":jnp.log(0.011), # log time length scale
    "log_sigma":jnp.log(5e-4)*jnp.ones(N_l), # white noise amplitude at each wavelength
}

kernel.visualise_covariance_matrix(hp, x_l, x_t);
[Figure: visualisation of the four component covariance matrices]
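
To make the Kronecker structure concrete, the full covariance matrix these four functions represent can also be built explicitly. This is a sketch for illustration only; the dense construction below is exactly what the LuasKernel avoids, and it quickly becomes infeasible for larger data sets.

# Explicitly build K = Kl KRON Kt + Sl KRON St using the component kernel functions above
K_full = (jnp.kron(Kl_fn(hp, x_l, x_l), Kt_fn(hp, x_t, x_t))
          + jnp.kron(Sl_fn(hp, x_l, x_l), St_fn(hp, x_t, x_t)))
print("Full covariance matrix shape:", K_full.shape)  # (N_l*N_t, N_l*N_t) = (1600, 1600)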

We can test the numerical stability of both the LuasKernel and GeneralKernel with a simple test where we multiply a random normal vector by the inverse of the covariance matrix, followed by multiplying by the covariance matrix. This should return the original random normal vector we started with. We should expect some level of floating point error, but it is likely to be negligible relative to the noise level in a dataset. So far I have yet to see the LuasKernel perform poorly for any valid covariance matrix. If you are having issues here, it is worth making sure that each kernel function generates a symmetric matrix and that \(S_l\) and \(S_t\) are invertible (a quick condition-number check is sketched after the cell below).

# Generate a random normal vector
random_vec = np.random.normal(size = (N_l, N_t))

# Calculate alpha = K_inv @ random_vec
alpha = kernel.K_inv_by_vec(hp, x_l, x_t, random_vec)

# Calculate K @ alpha which should equal random_vec
K_K_inv_random_vec = kernel.K_by_vec(hp, x_l, x_t, alpha)

# Calculate the numerical errors from the original vector
calc_err = random_vec - K_K_inv_random_vec
print(f"Largest deviations (LuasKernel): {calc_err.min()}, {calc_err.max()}")

# Perform the same calculation with the general_kernel which builds the full covariance matrix
# and uses Cholesky decomposition for inverting the covariance matrix
K_inv_random_vec = general_kernel.K_inv_by_vec(hp, x_l, x_t, random_vec)
K_K_inv_random_vec = general_kernel.K_by_vec(hp, x_l, x_t, K_inv_random_vec)

calc_err = random_vec - K_K_inv_random_vec
print(f"Largest deviations (GeneralKernel): {calc_err.min()}, {calc_err.max()}")
Largest deviations (LuasKernel): -2.942091015256665e-13, 2.4857893521357255e-13
Largest deviations (GeneralKernel): -3.615996391204135e-13, 3.6885078324999654e-13
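
If stability problems do show up, a quick way to check the conditioning requirement discussed above is to look at the condition numbers of the component matrices directly. This is only a sketch using the kernel functions and sample hyperparameters already defined (it builds the dense component matrices, so keep it to small problems):

# Sl and St should be well-conditioned; Kl and Kt are allowed to be close to singular
print("cond(Sl):", jnp.linalg.cond(Sl_fn(hp, x_l, x_l)))
print("cond(St):", jnp.linalg.cond(St_fn(hp, x_t, x_t)))
print("cond(Kl):", jnp.linalg.cond(Kl_fn(hp, x_l, x_l)))
print("cond(Kt):", jnp.linalg.cond(Kt_fn(hp, x_t, x_t)))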

Take a random noise draw from this covariance matrix. Try playing around with these values to see the effect varying each parameter has on the noise generated.

hp = {"log_h":jnp.log(2e-3), "log_l_l":jnp.log(1000.), "log_l_t":jnp.log(0.011),
      "log_sigma":jnp.log(5e-4)*jnp.ones(N_l),}

sim_noise = kernel.generate_noise(hp, x_l, x_t)

plt.imshow(sim_noise, aspect = 'auto', extent = [x_t[0], x_t[-1], x_l[-1], x_l[0]])
plt.xlabel("Time (days)")
plt.ylabel("Wavelength ($\AA$)")
plt.show()
[Figure: simulated correlated noise as a function of time and wavelength]

Combining our light curves and noise model, we can generate synthetic light curves contaminated by systematics correlated in time and wavelength.

def plot_lightcurves(x_t, M, Y, sep = 0.008):
    """Quick function to visualise light curves and the residuals after subtraction of transit model
    
    """
    N_l = x_l.shape[-1]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 12), sharey = True)
    for i in range(N_l):
        ax1.plot(x_t, Y[i, :] + np.arange(0, -sep*N_l, -sep)[i], 'bo', ms = 3)
        
        ax1.plot(x_t, M[i, :] + np.arange(0, -sep*N_l, -sep)[i], 'k-', ms = 3)
        ax2.plot(x_t, Y[i, :] - M[i, :] + np.arange(1, 1-sep*N_l, -sep)[i], 'b.', ms = 3)
        ax2.plot(x_t, np.arange(1, 1-sep*N_l, -sep)[i]*np.ones_like(x_t), 'k-', ms = 3)

u1_sim, u2_sim = jnp.linspace(0.7, 0.4, N_l), jnp.linspace(0.1, 0.2, N_l)
q1_sim, q2_sim = ld_to_kipping(u1_sim, u2_sim)

# All parameters used to generate the synthetic data set with noise
# luas currently assumes all inputs are arrays so convert floats to size 1 JAXArray
# Unlike PyMC, NumPyro is fine taking JAXArrays as input
p_sim = {
    # Mean function parameters
    "T0":0.*jnp.ones(1),                      # Central transit time
    "P":3.4*jnp.ones(1),                      # Period (days)
    "a":8.*jnp.ones(1),                       # Semi-major axis to stellar ratio aka a/R*
    "rho":0.1*jnp.ones(N_l),                  # Radius ratio rho aka Rp/R* for each wavelength
    "b":0.5*jnp.ones(1),                      # Impact parameter
    "q1":q1_sim,                              # First quadratic limb darkening coefficient for each wavelength
    "q2":q2_sim,                              # Second quadratic limb darkening coefficient for each wavelength
    "Foot":1.*jnp.ones(N_l),                  # Baseline flux out of transit for each wavelength
    "Tgrad":0.*jnp.ones(N_l),                 # Gradient in baseline flux for each wavelength (days^-1)
    
    # Hyperparameters
    "log_h":jnp.log(5e-4)*jnp.ones(1),        # log height scale
    "log_l_l":jnp.log(1000.)*jnp.ones(1),     # log length scale in wavelength
    "log_l_t":jnp.log(0.011)*jnp.ones(1),     # log length scale in time
    "log_sigma":jnp.log(5e-4)*jnp.ones(N_l),  # log white noise amplitude for each wavelength
}

transit_signal = transit_light_curve_2D(p_sim, x_l, x_t)
sim_noise = kernel.generate_noise(p_sim, x_l, x_t)
Y = transit_signal*(1 + sim_noise)

plot_lightcurves(x_t, transit_signal, Y)
[Figure: synthetic light curves with the transit model (left) and residuals after subtracting the transit model (right)]

We may also want to define a logPrior function. While NumPyro can also be used to define priors on parameters, it can be useful to let luas.GP handle non-uniform priors when it comes to MCMC tuning (as will be shown later).

The logPrior function input to luas.GP must be of the form logPrior(params) i.e. it may only take the PyTree of mean function parameters and hyperparameters as input.

# Places priors on the system scale and impact parameter
a_mean = p_sim["a"]
a_std = 0.1
b_mean = p_sim["b"]
b_std = 0.01

# Set some limb darkening priors, normally these might be generated from a package like LDTk
u1_mean = u1_sim
u1_std = 0.01
u2_mean = u2_sim
u2_std = 0.01

def logPrior(p):
    logPrior = -0.5*((p["a"] - a_mean)/a_std)**2
    logPrior += -0.5*((p["b"] - b_mean)/b_std)**2
    
    u1, u2 = ld_from_kipping(p["q1"], p["q2"])
    u1_priors = -0.5*((u1 - u1_mean)/u1_std)**2
    u2_priors = -0.5*((u2 - u2_mean)/u2_std)**2
    
    logPrior += u1_priors.sum() + u2_priors.sum()

    return logPrior.sum()

print("Log prior at simulated values:", logPrior(p_sim))
Log prior at simulated values: -1.2518544638517033e-28

We now have enough to define our luas.GP object and use it to try recover the original transmission signal injected into the correlated noise.

from luas import GP
from copy import deepcopy

# Initialise our GP object
# Make sure to include the mean function and log prior function if you're using them
gp = GP(kernel,  # Kernel object to use
        x_l,     # Regression variable(s) along wavelength/vertical dimension
        x_t,     # Regression variable(s) along time/horizontal dimension
        mf = transit_light_curve_2D,  # (optional) mean function to use, defaults to zeros
        logPrior = logPrior           # (optional) log prior function, defaults to zero
       )

# Initialise our starting values as the true simulated values
p_initial = deepcopy(p_sim)

# Convenient function for plotting the data and the GP fit to the data
gp.plot(p_initial, Y)
print("Starting log posterior value:", gp.logP(p_initial, Y))
Starting log posterior value: 9754.21597876404
[Figure: GP fit to the data at the starting parameter values]

Let’s compare the runtime speed-up from the optimisations in the LuasKernel. Comparing runtimes with JAX is a little more involved than normal. First we want to pre-compile the functions using jax.jit so that we aren’t timing how long each function takes to compile. We must then run the JIT-compiled functions once before benchmarking, as compilation happens on the first call for the given array shapes. We can then time the JIT-compiled functions, making sure to include block_until_ready because jax uses asynchronous dispatch (i.e. this waits until the calculation has actually finished).

Note that while the runtime improvement for this small data set may not be substantial, the methods have very different scaling: GeneralKernel.logP scales as \(\mathcal{O}(N_l^3 N_t^3)\) while LuasKernel.logP scales as \(\mathcal{O}(N_l^3 + N_t^3 + N_l N_t (N_l + N_t))\), so for larger data sets the difference can be much greater. GeneralKernel.logP may also crash for larger data sets due to its higher memory requirements.
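
As a rough illustration of that scaling difference, we can simply count the leading terms (ignoring constant factors, so the numbers below are only indicative):

# Compare the leading-order operation counts of the two log likelihood methods
for N_l_ex, N_t_ex in [(16, 100), (64, 1000)]:
    general_cost = N_l_ex**3 * N_t_ex**3
    luas_cost = N_l_ex**3 + N_t_ex**3 + N_l_ex*N_t_ex*(N_l_ex + N_t_ex)
    print(f"N_l = {N_l_ex}, N_t = {N_t_ex}: cost ratio ~ {general_cost/luas_cost:.1e}")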

# Create a similar GP object but using general Cholesky calculations in GeneralKernel instead of using LuasKernel
gp_general = GP(general_kernel,  # Kernel object to use
        x_l,     # Regression variable(s) along wavelength/vertical dimension
        x_t,     # Regression variable(s) along time/horizontal dimension
        mf = transit_light_curve_2D,  # (optional) mean function to use, defaults to zeros
        logPrior = logPrior           # (optional) log prior function, defaults to zero
       )

# Let's JIT compile each function before running (this should already be done in GP.py but is good practice anyway)
general_jit = jax.jit(gp_general.logP)
luas_jit = jax.jit(gp.logP)

# Run each calculation before benchmarking, note they should give the same answer up to floating point errors
print("GeneralKernel Calculation:", general_jit(p_initial, Y))
print("LuasKernel Calculation:", luas_jit(p_initial, Y))

# Benchmark each method
print("\nRuntime for GeneralKernel: ")
%timeit general_jit(p_initial, Y).block_until_ready()

print("Runtime for LuasKernel: ")
%timeit luas_jit(p_initial, Y).block_until_ready()
GeneralKernel Calculation: 9754.215978764041
LuasKernel Calculation: 9754.21597876404

Runtime for GeneralKernel: 
71.3 ms ± 9.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Runtime for LuasKernel: 
12.7 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
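
Since NUTS relies on gradients of the log posterior, it can also be worth confirming that the gradient evaluates and timing it. This is a sketch: jax.grad differentiates gp.logP with respect to the parameter PyTree, which should work here since everything is written in jax.

# JIT-compile the gradient of the log posterior with respect to the parameter PyTree
logP_grad = jax.jit(jax.grad(gp.logP))

# Run once to compile, then benchmark (the result is a PyTree with the same structure as p_initial)
grad_initial = logP_grad(p_initial, Y)
print("Gradient w.r.t. log_h:", grad_initial["log_h"])

print("Runtime for LuasKernel gradient: ")
%timeit logP_grad(p_initial, Y)["log_h"].block_until_ready()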

Now let’s begin using NumPyro to perform a best-fit. We’ve set the true simulated values as our initial values, so we should already be close to the optimal log posterior value, but for real data this won’t always be the case. Similar to the approach in Fortune et al. (2024), we will not perform a separate white light curve analysis but will instead jointly fit all spectroscopic light curves simultaneously.

import numpyro
import numpyro.distributions as dist
from luas.numpyro_ext import LuasNumPyro
from copy import deepcopy

# Let's define some bounds
min_log_l_l = np.log(np.diff(x_l).min())
max_log_l_l = np.log(50*(x_l[-1] - x_l[0]))
min_log_l_t = np.log(np.diff(x_t).min())
max_log_l_t = np.log(3*(x_t[-1] - x_t[0]))

param_bounds = {
                # Bounds from Kipping (2013) are just between 0 and 1
                "q1":[jnp.array([0.]*N_l), jnp.array([1.]*N_l)],
                "q2":[jnp.array([0.]*N_l), jnp.array([1.]*N_l)],
    
                # Can optionally include bounds on other mean function parameters but often they will be well constrained by the data
                "rho":[jnp.array([0.]*N_l), jnp.array([1.]*N_l)],
    
                # Sometimes prior bounds on hyperparameters are important for sampling
                # However their choice can sometimes affect the results so use with caution
                "log_h":   [jnp.log(1e-6)*np.ones(1), jnp.log(1)*jnp.ones(1)],
                "log_l_l": [min_log_l_l*jnp.ones(1), max_log_l_l*jnp.ones(1)],
                "log_l_t": [min_log_l_t*jnp.ones(1), max_log_l_t*jnp.ones(1)],
                "log_sigma":[jnp.log(1e-6)*jnp.ones(N_l), jnp.log(1e-2)*jnp.ones(N_l)],
}

def transit_model(Y):
    # Make a copy of the initial parameters; anything not sampled below remains fixed during sampling
    var_dict = deepcopy(p_initial)
    
    # Specify the parameters we've given bounds for
    var_dict["rho"] = numpyro.sample("rho", dist.Uniform(low = param_bounds["rho"][0],
                                                         high =param_bounds["rho"][1]))
    var_dict["log_h"] = numpyro.sample("log_h", dist.Uniform(low = param_bounds["log_h"][0],
                                                             high = param_bounds["log_h"][1]))
    var_dict["log_l_l"] = numpyro.sample("log_l_l", dist.Uniform(low = param_bounds["log_l_l"][0],
                                                                 high = param_bounds["log_l_l"][1]))
    var_dict["log_l_t"] = numpyro.sample("log_l_t", dist.Uniform(low = param_bounds["log_l_t"][0],
                                                                 high = param_bounds["log_l_t"][1]))
    var_dict["log_sigma"] = numpyro.sample("log_sigma", dist.Uniform(low = param_bounds["log_sigma"][0],
                                                                     high = param_bounds["log_sigma"][1]))
    var_dict["q1"] = numpyro.sample("q1", dist.Uniform(low = param_bounds["q1"][0], high = param_bounds["q1"][1]))
    var_dict["q2"] = numpyro.sample("q2", dist.Uniform(low = param_bounds["q2"][0], high = param_bounds["q2"][1]))
    
    # Specify the unbounded parameters
    var_dict["T0"] = numpyro.sample("T0", dist.ImproperUniform(dist.constraints.real, (),
                                                               event_shape = (1,)))
    var_dict["a"] = numpyro.sample("a", dist.ImproperUniform(dist.constraints.real, (),
                                                             event_shape = (1,)))
    var_dict["b"] = numpyro.sample("b", dist.ImproperUniform(dist.constraints.real, (),
                                                             event_shape = (1,)))
    var_dict["Foot"] = numpyro.sample("Foot", dist.ImproperUniform(dist.constraints.real, (),
                                                                   event_shape = (N_l,)))
    var_dict["Tgrad"] = numpyro.sample("Tgrad", dist.ImproperUniform(dist.constraints.real, (),
                                                                     event_shape = (N_l,)))

    numpyro.sample("log_like", LuasNumPyro(gp = gp, var_dict = var_dict), obs = Y)

Now we are all set up to start performing a best-fit using NumPyro. Some of the variables here, such as step_size and num_steps, may need to be tweaked depending on the problem. Optimisation with NumPyro can be a bit awkward, so it may help to use the numpyro-ext package, which contains extensions to NumPyro including an easier-to-use optimiser.

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
from numpyro.infer.initialization import init_to_value
from jax import random

# Define step size and number of optimisation steps
step_size = 1e-5
num_steps = 5000

# Uses adam optimiser and a Laplace approximation calculated from the hessian of the log posterior as a guide
optimizer = numpyro.optim.Adam(step_size=step_size)
guide = AutoLaplaceApproximation(transit_model, init_loc_fn = init_to_value(values=p_initial))

# Create a Stochastic Variational Inference (SVI) object with NumPyro
svi = SVI(transit_model, guide, optimizer, loss=Trace_ELBO())

# Run the optimiser and get the median parameters
svi_result = svi.run(random.PRNGKey(0), num_steps, Y)
params = svi_result.params
p_fit = guide.median(params)

# Combine best-fit values with fixed values for log posterior calculation
p_opt = deepcopy(p_initial)
p_opt.update(p_fit)

print("Starting log posterior value:", gp.logP(p_initial, Y))
print("New optimised log posterior value:", gp.logP(p_opt, Y))
Starting log posterior value: 9754.21597876404
New optimised log posterior value: 9784.485587030633
100%|█████████████████████████████████████████████████████| 5000/5000 [00:56<00:00, 89.07it/s, init loss: -9640.8718, avg. loss [4751-5000]: -9671.0989]

Although this probably won’t be needed for the simulated data and will likely clip no data points, luas.GP comes with a sigma_clip method which performs 2D Gaussian process regression and can clip outliers that deviate by more than a given significance level, replacing them with the GP predictive mean at those locations (note that we need to maintain a complete grid structure and so cannot simply remove these data points).

# This function will return a JAXArray of the same shape as Y
# but with outliers replaced with interpolated values
Y_clean = gp.sigma_clip(p_opt, # Make sure to perform sigma clipping using a good fit to the data
                        Y,     # Observations JAXArray
                        5.     # Significance level in standard deviations to clip at
                       )
Number of outliers clipped =  0
[Figure: sigma clipping diagnostic plot of the light curves]

For MCMC tuning with large numbers of parameters, it can be very helpful to use the Laplace approximation to select a good choice of tuning matrix or “mass matrix” for No U-Turn Sampling (NUTS). This can often return quite accurate approximations of the covariance matrix of the posterior, especially when most of the parameters are well constrained.

If outliers were clipped then optimisation may need to be run again on the cleaned data before this step, as the Laplace approximation should be performed at the maximum of the posterior. Also note that we use the gp.laplace_approx_with_bounds method instead of gp.laplace_approx because we have bounds on some of our parameters, which means NumPyro will apply a transformation to these parameters that we would like our Laplace approximation to account for. See the gp.laplace_approx_with_bounds documentation for more details.

# Returns the covariance matrix from the Laplace approximation
# Also returns a list of parameter names giving the order the array is in
# This matches the order jax.flatten_util.ravel_pytree flattens the parameter PyTree into
cov_mat, ordered_param_list = gp.laplace_approx_with_bounds(
    p_opt,               # Make sure to use best-fit values
    Y_clean,             # The observations being fit
    param_bounds,        # Specify the same bounds that will be used for the MCMC
    fixed_vars = ["P"],  # Make sure to specify fixed parameters as otherwise they are marginalised over
    return_array = True, # May optionally return a nested PyTree if set to False which can be more readable
    regularise = True,   # Often necessary to regularise values that return negative covariance
    large = False,       # Setting this to True is more memory efficient which may be needed for large data sets
)

# This function will output information on what regularisation has been performed
# And will mention if there are remaining negative values along the diagonal of the covariance matrix
# It does not however check if the covariance matrix is invertible
No regularisation needed to remove negative values along diagonal of covariance matrix.
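
It can also be worth glancing at the approximate 1-sigma uncertainties implied by the Laplace approximation before sampling. This is just the square root of the diagonal of the covariance matrix; note that for the bounded parameters these values are in the transformed space NumPyro samples in.

# Approximate 1-sigma uncertainties from the Laplace approximation
approx_std = jnp.sqrt(jnp.diag(cov_mat))
print("Parameter ordering:", ordered_param_list)
print("Approximate standard deviations:", approx_std)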

We now sample all parameters using NUTS with NumPyro, reusing the model defined above with the same parameters being varied.

from numpyro.infer import MCMC, NUTS
from jax import random

# NumPyro makes use of the jax.random module to handle randomness
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# Define the sampler we will use
nuts_step = NUTS(transit_model, # Our transit model specified earlier
                   init_strategy = init_to_value(values = p_opt), # Often works well to initialise near best-fit values
                   inverse_mass_matrix = cov_mat, # Inverse mass matrix is the same as the tuning covariance matrix
                   adapt_mass_matrix=False, # Often Laplace approximation works better than trying to tune many parameters
                   dense_mass = True,       # Need a dense mass matrix to account for correlations between parameters
                   regularize_mass_matrix = False, # Mass matrix should already be regularised
                  )

mcmc = MCMC(
    nuts_step,         # Sampler to use
    num_warmup=1000,   # Number of warm-up steps
    num_samples=1000,  # Number of samples post warm-up
    num_chains=2,      # Number of chains to run
)

# Run inference with given random seed and any arguments to the NumPyro model
mcmc.run(rng_key, Y)

# Generate an arviz inference data object from MCMC chains
idata = az.from_numpyro(mcmc)

# Saves the inference object
#idata.to_json("MCMC_chains.json");
/var/folders/lp/chmp1tb92f5_x_wdckqwk34c0000gn/T/ipykernel_83950/2518427083.py:16: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  mcmc = MCMC(
sample: 100%|████████████████████████████████████████████████████████████| 2000/2000 [03:47<00:00,  8.79it/s, 15 steps of size 3.93e-01. acc. prob=0.88]
sample: 100%|████████████████████████████████████████████████████████████| 2000/2000 [04:43<00:00,  7.06it/s, 15 steps of size 4.36e-01. acc. prob=0.86]

Print a summary of the samples using arviz. Important checks for convergence are that the effective sample sizes for the bulk of the distribution (ess_bulk) and the tail of the distribution (ess_tail) are at least ~500-1000, and that the Gelman-Rubin r_hat statistic is less than ~1.01 for each parameter.

az.summary(idata, round_to = 4)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
Foot[0] 0.9999 0.0002 0.9996 1.0002 0.0000 0.0000 2855.2725 1357.2726 1.0003
Foot[1] 0.9999 0.0002 0.9996 1.0002 0.0000 0.0000 2767.2503 1272.8925 0.9997
Foot[2] 1.0000 0.0002 0.9997 1.0004 0.0000 0.0000 2944.1248 1494.2252 1.0012
Foot[3] 1.0000 0.0002 0.9997 1.0003 0.0000 0.0000 2709.6956 1357.4303 1.0000
Foot[4] 0.9999 0.0002 0.9996 1.0002 0.0000 0.0000 2937.2029 1377.1856 0.9998
Foot[5] 1.0000 0.0002 0.9997 1.0004 0.0000 0.0000 2903.3169 1503.3115 1.0004
Foot[6] 1.0001 0.0002 0.9998 1.0005 0.0000 0.0000 3000.9795 1400.1627 1.0003
Foot[7] 1.0002 0.0002 0.9998 1.0005 0.0000 0.0000 2891.4591 1193.4521 0.9994
Foot[8] 1.0002 0.0002 0.9998 1.0005 0.0000 0.0000 3023.0210 1364.1188 0.9991
Foot[9] 1.0001 0.0002 0.9998 1.0005 0.0000 0.0000 2977.4864 1505.8880 0.9993
Foot[10] 1.0002 0.0002 0.9998 1.0005 0.0000 0.0000 3079.4165 1473.2620 0.9993
Foot[11] 1.0002 0.0002 0.9998 1.0005 0.0000 0.0000 3239.4060 1578.0688 1.0010
Foot[12] 1.0002 0.0002 0.9999 1.0006 0.0000 0.0000 3386.2669 1531.7540 0.9995
Foot[13] 1.0002 0.0002 0.9998 1.0005 0.0000 0.0000 3435.6724 1478.0745 1.0000
Foot[14] 1.0003 0.0002 1.0000 1.0006 0.0000 0.0000 3230.3339 1557.3770 0.9992
Foot[15] 1.0001 0.0002 0.9997 1.0004 0.0000 0.0000 3300.5015 1527.3001 0.9995
T0[0] 0.0002 0.0002 -0.0003 0.0007 0.0000 0.0000 2816.1484 1235.8575 1.0023
Tgrad[0] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 2356.5919 1226.1130 1.0027
Tgrad[1] 0.0000 0.0001 -0.0001 0.0002 0.0000 0.0000 2436.4942 1337.8885 1.0034
Tgrad[2] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 2598.0396 1173.4813 1.0059
Tgrad[3] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 2526.2233 1389.3658 1.0081
Tgrad[4] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 2349.6528 1195.2554 1.0013
Tgrad[5] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 2487.6681 1328.1977 1.0023
Tgrad[6] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 2483.8617 1426.2372 1.0023
Tgrad[7] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 3214.8212 1318.6139 1.0023
Tgrad[8] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 2798.9538 1572.8386 1.0008
Tgrad[9] 0.0001 0.0001 -0.0001 0.0002 0.0000 0.0000 3180.8237 1534.3567 1.0016
Tgrad[10] 0.0000 0.0001 -0.0001 0.0002 0.0000 0.0000 3166.4570 1327.9221 1.0004
Tgrad[11] 0.0001 0.0001 -0.0000 0.0002 0.0000 0.0000 3160.4413 1458.0786 1.0008
Tgrad[12] 0.0001 0.0001 -0.0001 0.0002 0.0000 0.0000 2654.7390 1357.1880 1.0000
Tgrad[13] 0.0001 0.0001 -0.0001 0.0002 0.0000 0.0000 2377.5940 1387.9053 1.0023
Tgrad[14] 0.0001 0.0001 -0.0001 0.0002 0.0000 0.0000 2356.6214 1103.5015 1.0030
Tgrad[15] 0.0000 0.0001 -0.0001 0.0002 0.0000 0.0000 2531.8985 1397.1884 1.0000
a[0] 7.9816 0.0523 7.8880 8.0839 0.0010 0.0007 2937.6388 1452.3881 0.9999
b[0] 0.5007 0.0089 0.4848 0.5177 0.0002 0.0001 3316.3946 1213.1770 1.0036
log_h[0] -7.6997 0.1251 -7.9167 -7.4493 0.0031 0.0022 1643.7101 1425.3674 1.0014
log_l_l[0] 6.8737 0.1269 6.6408 7.1205 0.0031 0.0022 1687.7423 1520.1483 1.0027
log_l_t[0] -4.5493 0.0934 -4.7295 -4.3838 0.0022 0.0015 1905.9643 1064.3458 1.0013
log_sigma[0] -7.6621 0.0797 -7.8056 -7.5083 0.0014 0.0010 3194.5955 1481.8683 1.0011
log_sigma[1] -7.6336 0.0770 -7.7837 -7.4902 0.0014 0.0010 3230.0714 1334.5348 1.0022
log_sigma[2] -7.5084 0.0763 -7.6390 -7.3544 0.0015 0.0011 2684.2917 1280.8033 1.0018
log_sigma[3] -7.6340 0.0765 -7.7732 -7.4931 0.0014 0.0010 2953.9162 1453.1273 0.9998
log_sigma[4] -7.5690 0.0764 -7.7193 -7.4303 0.0014 0.0010 2970.6738 1302.8707 1.0005
log_sigma[5] -7.4744 0.0746 -7.6176 -7.3403 0.0015 0.0011 2548.1874 1099.9578 0.9994
log_sigma[6] -7.4936 0.0740 -7.6205 -7.3498 0.0015 0.0011 2500.4812 1358.2160 1.0002
log_sigma[7] -7.6422 0.0785 -7.7816 -7.4839 0.0014 0.0010 3221.6164 981.9666 0.9999
log_sigma[8] -7.5316 0.0737 -7.6552 -7.3829 0.0013 0.0009 3126.1902 1433.0986 1.0064
log_sigma[9] -7.5061 0.0734 -7.6345 -7.3638 0.0014 0.0010 2939.1994 1588.6753 0.9993
log_sigma[10] -7.5260 0.0757 -7.6643 -7.3824 0.0014 0.0010 2940.2757 1072.3901 1.0007
log_sigma[11] -7.5614 0.0749 -7.7002 -7.4192 0.0015 0.0011 2475.9589 1273.0881 0.9997
log_sigma[12] -7.6051 0.0720 -7.7401 -7.4760 0.0014 0.0010 2738.3931 1411.4207 1.0020
log_sigma[13] -7.6576 0.0763 -7.7999 -7.5271 0.0016 0.0011 2441.4020 1327.6626 1.0002
log_sigma[14] -7.7081 0.0784 -7.8533 -7.5632 0.0016 0.0011 2390.2847 1537.9369 1.0003
log_sigma[15] -7.5013 0.0811 -7.6543 -7.3551 0.0016 0.0011 2801.5700 1482.6434 1.0001
q1[0] 0.6562 0.0222 0.6160 0.6974 0.0004 0.0003 2662.6074 1224.2983 1.0006
q1[1] 0.6104 0.0200 0.5761 0.6509 0.0004 0.0003 3082.8559 1254.7199 1.0049
q1[2] 0.6009 0.0203 0.5647 0.6403 0.0004 0.0003 3073.1412 1202.0142 1.0008
q1[3] 0.5747 0.0207 0.5365 0.6143 0.0004 0.0003 3195.1062 1513.1594 1.0004
q1[4] 0.5652 0.0202 0.5262 0.6025 0.0004 0.0003 2771.1884 1459.2294 1.0004
q1[5] 0.5342 0.0199 0.4938 0.5689 0.0004 0.0003 2735.5569 1276.6026 1.0011
q1[6] 0.5180 0.0189 0.4804 0.5512 0.0003 0.0002 2934.1109 1475.9894 1.0014
q1[7] 0.4991 0.0176 0.4640 0.5313 0.0003 0.0002 2753.6192 1368.6650 1.0010
q1[8] 0.4792 0.0190 0.4404 0.5124 0.0004 0.0002 2828.6046 1318.5772 1.0070
q1[9] 0.4680 0.0183 0.4334 0.5006 0.0003 0.0002 3715.3402 1319.4225 1.0001
q1[10] 0.4431 0.0183 0.4102 0.4790 0.0003 0.0002 3137.4949 1441.6509 1.0021
q1[11] 0.4282 0.0178 0.3936 0.4605 0.0003 0.0002 3235.9747 1378.4871 1.0000
q1[12] 0.4090 0.0181 0.3760 0.4442 0.0003 0.0003 2695.6963 1216.5499 1.0031
q1[13] 0.3938 0.0170 0.3624 0.4249 0.0003 0.0002 3234.3868 1410.7988 1.0061
q1[14] 0.3669 0.0163 0.3362 0.3974 0.0003 0.0002 3298.4932 1433.1160 1.0023
q1[15] 0.3621 0.0164 0.3313 0.3930 0.0003 0.0002 3036.7235 1324.9190 0.9996
q2[0] 0.4356 0.0054 0.4255 0.4456 0.0001 0.0001 2492.3464 1160.0995 1.0011
q2[1] 0.4329 0.0056 0.4220 0.4428 0.0001 0.0001 2947.5023 1259.2959 1.0009
q2[2] 0.4263 0.0053 0.4165 0.4362 0.0001 0.0001 3800.9091 1234.8561 1.0002
q2[3] 0.4215 0.0054 0.4113 0.4315 0.0001 0.0001 3032.6797 1552.6290 1.0009
q2[4] 0.4144 0.0058 0.4035 0.4247 0.0001 0.0001 3155.3336 1472.3759 1.0020
q2[5] 0.4096 0.0057 0.3991 0.4202 0.0001 0.0001 3326.7279 1433.9862 1.0011
q2[6] 0.4028 0.0058 0.3913 0.4129 0.0001 0.0001 3146.2362 1441.9461 1.0049
q2[7] 0.3964 0.0055 0.3860 0.4066 0.0001 0.0001 2850.8866 1306.9175 1.0008
q2[8] 0.3900 0.0061 0.3790 0.4017 0.0001 0.0001 3155.2973 1409.5124 1.0015
q2[9] 0.3820 0.0058 0.3713 0.3923 0.0001 0.0001 3129.1055 1505.4143 1.0027
q2[10] 0.3750 0.0058 0.3643 0.3854 0.0001 0.0001 3058.2123 1364.1660 0.9994
q2[11] 0.3671 0.0059 0.3563 0.3777 0.0001 0.0001 3454.1448 1383.9611 1.0016
q2[12] 0.3593 0.0058 0.3489 0.3702 0.0001 0.0001 3026.0483 1460.3061 1.0002
q2[13] 0.3510 0.0061 0.3399 0.3621 0.0001 0.0001 3217.6055 1382.8442 1.0004
q2[14] 0.3434 0.0062 0.3319 0.3549 0.0001 0.0001 3085.8008 1343.1320 1.0002
q2[15] 0.3330 0.0062 0.3211 0.3449 0.0001 0.0001 2988.5567 1338.1487 1.0039
rho[0] 0.0986 0.0014 0.0960 0.1012 0.0000 0.0000 2888.5851 1374.5222 1.0006
rho[1] 0.0996 0.0014 0.0969 0.1022 0.0000 0.0000 3063.8956 1285.9244 0.9998
rho[2] 0.1001 0.0015 0.0973 0.1027 0.0000 0.0000 3270.9619 1454.7833 1.0008
rho[3] 0.1004 0.0014 0.0976 0.1028 0.0000 0.0000 2763.8603 1346.7925 1.0017
rho[4] 0.0997 0.0014 0.0969 0.1023 0.0000 0.0000 3141.6593 1376.1186 0.9999
rho[5] 0.1005 0.0014 0.0978 0.1031 0.0000 0.0000 2532.2877 1472.6989 1.0009
rho[6] 0.1015 0.0014 0.0989 0.1042 0.0000 0.0000 3135.3080 1552.3700 1.0014
rho[7] 0.1014 0.0014 0.0988 0.1039 0.0000 0.0000 2276.6345 1494.8645 1.0000
rho[8] 0.1010 0.0014 0.0985 0.1036 0.0000 0.0000 2267.5925 1556.7423 0.9995
rho[9] 0.1008 0.0014 0.0982 0.1034 0.0000 0.0000 2135.0970 1513.5275 1.0002
rho[10] 0.1006 0.0014 0.0979 0.1031 0.0000 0.0000 2219.9527 1341.5090 0.9999
rho[11] 0.1010 0.0014 0.0982 0.1033 0.0000 0.0000 2186.4506 1356.6147 1.0006
rho[12] 0.1019 0.0013 0.0997 0.1044 0.0000 0.0000 2449.9208 1485.0208 0.9992
rho[13] 0.1013 0.0013 0.0988 0.1037 0.0000 0.0000 2407.0695 1258.6415 1.0001
rho[14] 0.1030 0.0013 0.1004 0.1052 0.0000 0.0000 2506.0257 1313.2882 1.0011
rho[15] 0.1014 0.0014 0.0988 0.1039 0.0000 0.0000 2980.4067 1232.7941 1.0012
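
Rather than scanning the table by eye, the summary can also be filtered programmatically. az.summary returns a pandas DataFrame, so this is just a DataFrame filter using the rough thresholds mentioned above:

# Flag any parameters with low effective sample sizes or a high r_hat
summary_df = az.summary(idata)
flagged = summary_df[(summary_df["r_hat"] > 1.01)
                     | (summary_df["ess_bulk"] < 500)
                     | (summary_df["ess_tail"] < 500)]
print(flagged if len(flagged) else "No parameters flagged for convergence issues.")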

The arviz.plot_trace function can be very useful for visualising the marginal posterior distributions of the different parameters as well as to examine the chains and diagnose any possible convergence issues.

trace_plot = az.plot_trace(idata)
plt.tight_layout()
[Figure: trace plots of the marginal posteriors and MCMC chains]

corner is great for generating nice corner plots. Unfortunately, when fitting many light curves the number of parameters can make visualising the full corner plot quite unwieldy, but we can always plot subsets of parameters instead.

from corner import corner

# Select the parameters to include in the plot
params = ["T0", "a", "rho", "b", "log_h", "log_l_l", "log_l_t"]

# Plot each of the two chains separately
idata_corner1 = idata.sel(chain=[0])
idata_corner2 = idata.sel(chain=[1])

# Plot first chain
fig1 = corner(idata_corner1, smooth = 0.4, var_names = params)

# Plot second chain along with truth values
fig2 = corner(idata_corner2, quantiles=[0.16, 0.5, 0.84], title_fmt = None, title_kwargs={"fontsize": 23},
              label_kwargs={"fontsize": 23}, show_titles=True, smooth = 0.4, color = "r", fig = fig1,
              top_ticks = True, max_n_ticks = 2, labelpad = 0.16, var_names = params,
              truths = p_sim, truth_color = "k",
              )
[Figure: corner plot of the selected parameters for both chains]

If you want to examine the chains for any parameters you can use idata.posterior[param] to get an xarray.DataArray object and can use the to_numpy method to convert to a NumPy array. Below we retrieve the mean and covariance matrix of the radius ratio parameters and then visualise the transmission spectrum as well as the covariance matrix of the transmission spectrum.

# NumPy array of MCMC samples of shape (N_chains, N_draws, N_l)
rho_chains = idata.posterior["rho"].to_numpy()

# Get the mean radius ratio values averaged over all chains and draws
rho_mean = rho_chains.mean((0, 1))

# Calculate the covariance matrix of each chain and average together
N_chains = rho_chains.shape[0]
rho_cov = jnp.zeros((N_l, N_l))
for i in range(N_chains):
    rho_cov += jnp.cov(rho_chains[i, :, :].T)
rho_cov /= N_chains

# Standard deviation given by sqrt of diagonal of covariance matrix
rho_std_dev = jnp.sqrt(jnp.diag(rho_cov))

# We can plot our recovered spectrum against the simulated true spectrum
plt.errorbar(x_l, rho_mean, yerr = rho_std_dev, fmt = 'k.', label = "Recovered Spectrum")
plt.plot(x_l, p_sim["rho"], 'r--', label = "True Spectrum")

# Select a few samples from the MCMC to help visualise the correlation between values
rho_draws = rho_chains[0, 0:1000:100, :].T
plt.plot(x_l, rho_draws, 'k-', alpha = 0.3)
plt.xlabel("Wavelength ($\AA$)")
plt.ylabel(r"Radius Ratio $\rho$")
plt.legend()
plt.show()

plt.title("Tranmission spectrum covariance matrix")
plt.imshow(rho_cov, extent = [x_l[0], x_l[-1], x_l[-1], x_l[0]])
plt.colorbar()
plt.xlabel("Wavelength ($\AA$)")
plt.ylabel("Wavelength ($\AA$)")
plt.show()
[Figures: recovered vs true transmission spectrum with MCMC draws; transmission spectrum covariance matrix]

When there are significant correlations involved it can be hard to tell whether the recovered spectrum is actually consistent with the simulated spectrum. The reduced chi-squared statistic \(\chi_r^2\) is useful for determining this. It is calculated using:

\[ \begin{equation} \chi_\mathrm{r}^2 = (\bar{\vec{\rho}}_\mathrm{ret} - \vec{\rho}_\mathrm{inj})^T \mathbf{K}_\mathrm{\vec{\rho}; ret}^{-1} (\bar{\vec{\rho}}_\mathrm{ret} - \vec{\rho}_\mathrm{inj}) / N_l, \end{equation} \]

where \(\bar{\vec{\rho}}_\mathrm{ret}\) is our retrieved mean transmission spectrum, \(\vec{\rho}_\mathrm{inj}\) is the injected transmission spectrum and \(\mathbf{K}_\mathrm{\vec{\rho}; ret}\) is our retrieved covariance matrix of the transmission spectrum.

r = rho_mean - p_sim["rho"]
chi2_r = r.T @ jnp.linalg.inv(rho_cov) @ r / N_l

print("Reduced chi-squared value:", chi2_r)

# The reduced chi-squared distribution always has mean 1 but its standard deviation
# depends on the number of degrees of freedom
print("Distribution of reduced chi-squared: mu = 1, sigma =", jnp.sqrt(2/N_l))
Reduced chi-squared value: 1.4175618575614424
Distribution of reduced chi-squared: mu = 1, sigma = 0.3535533905932738