# This should install any required dependencies assuming JAX has already been installed
# See https://jax.readthedocs.io/en/latest/installation.html for details on installing JAX for your system
try:
import luas
except ImportError:
!git clone https://github.com/markfortune/luas.git
%cd luas
%pip install .
%cd ..
try:
import jaxoplanet
except ImportError:
%pip install -q jaxoplanet
try:
import pymc
except ImportError:
%pip install -q pymc
try:
import corner
except ImportError:
%pip install -q corner
try:
import arviz as az
except ImportError:
%pip install -q arviz
PyMC Example: Transmission Spectroscopy
This notebook provides a tutorial on how to use PyMC to perform spectroscopic transit light curve fitting. We will first go through how to generate transit light curves in jax, then we will create synthetic noise which is correlated in both wavelength and time. Finally, we will use PyMC to fit the noise-contaminated light curves and recover the input transmission spectrum. You can run this notebook yourself on Google Colab by clicking the rocket icon in the top right corner.
If you have a GPU available (including a GPU on Google Colab) then jax should detect this and run on it automatically. Running jax.devices() tells us what jax has found to run on. If there is a GPU available but it is not detected, this could indicate jax has not been correctly installed for GPU. Also note that if running on Google Colab, the T4 GPUs have poor performance at double precision floating point and may not show significant speed-ups, while more modern hardware such as NVIDIA Tesla V100s or better tends to show significant performance improvements.
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import os.path
import logging
import jaxoplanet
import arviz as az
import pandas as pd
from astropy.constants import M_sun, R_sun, G
import astropy.units as u
# Useful to set when looking at retrieved values for many parameters
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
# This helps give more information on what PyMC is doing during inference
logging.getLogger().setLevel(logging.INFO)
# Running this at the start of the runtime ensures jax uses 64-bit floating point numbers
# as jax uses 32-bit by default
jax.config.update("jax_enable_x64", True)
# This will list available CPUs/GPUs JAX is picking up
print("Available devices:", jax.devices())
Available devices: [CpuDevice(id=0)]
First we will use jaxoplanet, which has its own documentation and tutorials here. Let’s create a single light curve with quadratic limb darkening. We will use the transit_light_curve function from luas.exoplanet, which we have included below to show what it does.
Note: parameters which will be shared between light curves are assumed to be size-1 arrays here, as this will make it easier to implement multi-light curve fitting with PyMC later.
# Calculates solar density in kg/m^3 to convert between fitting for stellar density or a/R*
solar_density = ((M_sun/R_sun**3)/(u.kg/u.m**3)).si
def transit_light_curve(par, t):
"""Uses the package `jaxoplanet <https://github.com/exoplanet-dev/jaxoplanet>`_ to calculate
transit light curves using JAX assuming quadratic limb darkening and a simple circular orbit.
Note:
If using this function then make sure to cite jaxoplanet separately as it is an independent
package from luas.
This particular function will only compute a single transit light curve but JAX's vmap function
can be used to calculate the transit light curve of multiple wavelength bands at once.
It assumes that the central transit time `T0`, period `P`, semi-major axis to stellar ratio a/R* = `a`
and impact parameter `b` are size-1 arrays as this makes it easier to implement with vmap
and PyMC. Feel free to vary this however, for example it can easily be modified to fit for T0 separately
for each light curve.
.. code-block:: python
>>> from luas.exoplanet import transit_light_curve
>>> import jax.numpy as jnp
>>> par = {
>>> ... "T0":0.*jnp.ones(1), # Central transit time (days)
>>> ... "P":3.4*jnp.ones(1), # Period (days)
>>> ... "a":8.2*jnp.ones(1), # Semi-major axis to stellar ratio (aka a/R*)
>>> ... "rho":0.1, # Radius ratio (aka Rp/R* or rho)
>>> ... "b":0.5*jnp.ones(1), # Impact parameter
>>> ... # Uses standard quadratic limb darkening parameterisation:
>>> ... # I(r) = 1 − u1(1 − mu) − u2(1 − mu)^2
>>> ... "u1":0.5, # First quadratic limb darkening coefficient
>>> ... "u2":0.1, # Second quadratic limb darkening coefficient
>>> ... "Foot":1., # Baseline flux out of transit
>>> ... "Tgrad":0. # Gradient in baseline flux (hrs^-1)
>>> }
>>> t = jnp.linspace(-0.1, 0.1, 100)
>>> flux = transit_light_curve(par, t)
Args:
par (PyTree): The transit parameters stored in a PyTree/dictionary (see example above).
t (JAXArray): Array of times to calculate the light curve at.
Returns:
JAXArray: Array of flux values for each time input.
"""
# Calculates stellar density in kg/m^3 using par["a"] = a/R*
# Can modify this function to explicitly fit for the stellar density if desired
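    # In SI units this is rho_s = 3*pi*(a/R*)^3 / (G*P^2), with the period converted from days to seconds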
rho_s = 3*jnp.pi*par["a"][0]**3/(G.value*(par["P"][0]*86400)**2)
# Creates an object describing the star
# This code actually sets the stellar radius as 1 solar radius
# It gives the density relative to solar density
# This actually gives a different mass for the central star but this does not affect the transit model
# It effectively just creates an analogous system scaled in distance by a factor R_sun/R*
# This avoids having to input a value for R* which is irrelevant for transit calculations
# This does not affect a/R*, Rp/R* or b as they are all dimensionless quantities
central = jaxoplanet.orbits.keplerian.Central(density=rho_s/solar_density,radius=1.)
# Define the planetary body
body = jaxoplanet.orbits.keplerian.Body(
period=par["P"][0],
time_transit=par["T0"][0],
radius=par["rho"],
impact_param=par["b"][0],
        eccentricity=0., # Can optionally include eccentricity here
omega_peri = 0.,
)
# Creates an orbit object with both the central `star` object and the `body` planet object
orbit = jaxoplanet.orbits.keplerian.OrbitalBody(central = central, body = body)
# Define light curve function `lc` using the quadratic limb darkening coefficients u1 and u2
lc = jaxoplanet.light_curves.limb_dark.light_curve(orbit, [par["u1"], par["u2"]])
# Calculates the transit light curve flux dip with a baseline of one (default from jaxoplanet is zero)
flux = 1 + lc(t)
# Scales the transit model with a linear baseline
baseline = par["Foot"] + 24*par["Tgrad"]*(t - par["T0"][0])
return baseline*flux
# Use some simple parameters typical of a hot Jupiter
# Note the size-1 arrays on parameters which will later be shared between light curves
mfp = {
"T0":0.*jnp.ones(1), # Central transit time
"P":3.*jnp.ones(1), # Period (days)
"a":8.*jnp.ones(1), # Semi-major axis to stellar ratio aka a/R*
"rho":0.1, # Radius ratio rho aka Rp/R*
"b":0.5*jnp.ones(1), # Impact parameter
"u1":0.5, # First quadratic limb darkening coefficient
"u2":0.1, # Second quadratic limb darkening coefficient
"Foot":1., # Baseline flux out of transit
"Tgrad":0. # Gradient in baseline flux (hrs^-1)
}
# Generate 100 evenly spaced time points (in units of days)
N_t = 100
x_t = jnp.linspace(-0.15, 0.15, N_t)
plt.plot(x_t, transit_light_curve(mfp, x_t), "k.")
plt.xlabel("Time (days)")
plt.ylabel("Normalised Flux")
plt.show()
Now we want to take our jax function, which has been written to generate a 1D light curve in time, and instead create separate light curves in each wavelength. We will also want to share some parameters between light curves (e.g. impact parameter b, system scale a/Rs) while varying other parameters for each wavelength (e.g. radius ratio Rp/Rs).
We could of course write a for loop, but for loops can be slow to compile with jax. A more efficient option is to make use of the jax.vmap function, which allows us to “vectorise” our function from 1D to 2D. A toy sketch of how in_axes works is shown below, followed by the vmap of transit_light_curve itself:
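Before applying this to the transit model, here is a minimal toy sketch of how in_axes works with a dictionary (PyTree) argument; the function scale_and_shift and its parameters are made up purely for illustration.
# Toy function analogous to transit_light_curve(par, t): returns a 1D array over time
def scale_and_shift(par, t):
    return par["scale"]*t + par["offset"]

# Vary "scale" for each light curve (axis 0) while sharing "offset" and the time array
scale_and_shift_vmap = jax.vmap(scale_and_shift, in_axes=({"scale": 0, "offset": None}, None))

toy_par = {"scale": jnp.array([1., 2., 3.]), "offset": 0.5}
toy_t = jnp.linspace(0., 1., 5)
# Output has shape (3, 5): one row of flux values per value of "scale"
print(scale_and_shift_vmap(toy_par, toy_t).shape)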
# First we must tell JAX which parameters of the function we want to vary for each light curve
# and which we want to be shared between light curves
transit_light_curve_vmap = jax.vmap(
# First argument is the function to vectorise
transit_light_curve,
# Specify which parameters to share and which to vary for each light curve
in_axes=(
{
# If a parameter is to be shared across each light curve then it should be set to None
"T0":None, "P":None, "a":None, "b":None,
# Parameters which vary in wavelength are given the dimension of the array to expand along
# In this case we are expanding from 0D arrays to 1D arrays so this must be 0
"rho":0, "u1":0, "u2":0, "Foot":0, "Tgrad":0
},
# Also must specify that we will share the time array (the second function parameter)
# between light curves
None,
),
# Specify the output dimension to expand along, this will default to 0 anyway
# Will output extra flux values for each light curve as additional rows
out_axes = 0,
)
# Let's define a wavelength range of our data, this isn't actually used by our mean function
# but we will use it for defining correlation in wavelength later
N_l = 16 # Feel free to vary the number of light curves, even >100 light curves may perform quite efficiently
x_l = np.linspace(4000, 7000, N_l)
# Now we define our 2D transit parameters, let's just assume most parameters are constant in wavelength
# We are letting the limb darkening parameters vary linearly in wavelength, feel free to replace
# with generated limb-darkening parameters from a package of your choice
# Note we are now using arrays of length N_l for each parameter which varies in wavelength
mfp_2D = {
"T0":0.*jnp.ones(1), # Central transit time
"P":3.*jnp.ones(1), # Period (days)
"a":8.*jnp.ones(1), # Semi-major axis to stellar ratio aka a/R*
"rho":0.1*jnp.ones(N_l), # Radius ratio rho aka Rp/R*
"b":0.5*jnp.ones(1), # Impact parameter
"u1":jnp.linspace(0.7, 0.4, N_l), # First quadratic limb darkening coefficient
"u2":jnp.linspace(0.1, 0.2, N_l), # Second quadratic limb darkening coefficient
"Foot":1.*jnp.ones(N_l), # Baseline flux out of transit
"Tgrad":0.*jnp.ones(N_l) # Gradient in baseline flux (days^-1)
}
# Call our new vmap of transit_light_curve to simultaneously compute light curves for all wavelengths
transit_model_2D = transit_light_curve_vmap(mfp_2D, x_t)
# Visualise the transit light curves
plt.imshow(transit_model_2D, aspect = 'auto', extent = [x_t[0], x_t[-1], x_l[-1], x_l[0]])
plt.xlabel("Time (days)")
plt.ylabel(r"Wavelength ($\AA$)")
plt.show()
The form of a mean function with luas.GP is f(par, x_l, x_t), where par is a PyTree, x_l is a JAXArray of inputs that lie along the wavelength/vertical dimension of the data (e.g. an array of wavelength values) and x_t is a JAXArray of inputs that lie along the time/horizontal dimension (e.g. an array of timestamps). Our transit model does not require wavelength values to be computed, but we still write our wrapper function to be of this form.
We will also switch to the Kipping (2013) limb darkening parameterisation as it makes it easier to place simple bounds on these parameters and has the benefit of placing a uniform prior on the physically allowed space of limb darkening parameters.
Also feel free to use the luas.exoplanet.transit_2D function instead, which is similar to the function below except that it fits for transit depth \(d = \rho^2\) instead of radius ratio \(\rho\).
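For reference, the Kipping (2013) mapping between the quadratic coefficients \((u_1, u_2)\) and the sampled parameters \((q_1, q_2)\), which is what ld_to_kipping and ld_from_kipping are assumed to implement here, is:
\[q_1 = (u_1 + u_2)^2, \qquad q_2 = \frac{u_1}{2(u_1 + u_2)}\]
with the inverse mapping
\[u_1 = 2\sqrt{q_1}\,q_2, \qquad u_2 = \sqrt{q_1}\,(1 - 2 q_2)\]
so that sampling \(q_1, q_2\) uniformly in \([0, 1]\) covers the physically allowed region of limb darkening parameters.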
from luas.exoplanet import ld_to_kipping, ld_from_kipping
def transit_light_curve_2D(p, x_l, x_t):
    # vmap requires that we only pass in the parameters whose vectorisation has been explicitly defined via in_axes
transit_params = ["T0", "P", "a", "rho", "b", "Foot", "Tgrad"]
mfp = {k:p[k] for k in transit_params}
# Calculate limb darkening coefficients from the Kipping (2013) parameterisation.
mfp["u1"], mfp["u2"] = ld_from_kipping(p["q1"], p["q2"])
# Use the vmap of transit_light_curve to calculate a 2D array of shape (M, N) of flux values
# For M wavelengths and N time points.
return transit_light_curve_vmap(mfp, x_t)
# Switch to Kipping parameterisation
if "u1" in mfp_2D:
mfp_2D["q1"], mfp_2D["q2"] = ld_to_kipping(mfp_2D["u1"], mfp_2D["u2"])
del mfp_2D["u1"]
del mfp_2D["u2"]
# This should produce the same result as transit_light_curve_vmap as we have only changed the limb darkening parameterisation
M = transit_light_curve_2D(mfp_2D, x_l, x_t)
Now that we have our mean function created, it’s time to start building a kernel function. We will try to keep things simple and use a squared exponential kernel for correlation in both time and wavelength, plus white noise which varies in amplitude between different light curves.
This kernel contains a correlated noise component with height scale \(h\), length scale in wavelength \(l_\mathrm{\lambda}\) and length scale in time \(l_t\). There is also a white noise term which can have different white noise amplitudes at different wavelengths.
We will need to write this kernel as separate wavelength and time kernel functions which are multiplied together (via Kronecker products), with the full covariance matrix satisfying:
\[\mathbf{K} = \mathbf{K}_l \otimes \mathbf{K}_t + \mathbf{S}_l \otimes \mathbf{S}_t\]
We can do this by choosing the equations below for our kernel functions, which generate each component matrix:
\[[\mathbf{K}_l]_{ij} = h^2 \exp\left(-\frac{(\lambda_i - \lambda_j)^2}{2 l_\lambda^2}\right), \qquad [\mathbf{K}_t]_{ij} = \exp\left(-\frac{(t_i - t_j)^2}{2 l_t^2}\right)\]
\[[\mathbf{S}_l]_{ij} = \sigma_i^2\,\delta_{ij}, \qquad [\mathbf{S}_t]_{ij} = \delta_{ij}\]
where \(h\) is the height scale, \(l_\lambda\) and \(l_t\) are the length scales in wavelength and time, and \(\sigma_i\) is the white noise amplitude of light curve \(i\).
Note that for numerical stability reasons it’s important that the \(S_l\) and \(S_t\) matrices are well-conditioned (i.e. can be stably inverted), while \(K_l\) and \(K_t\) do not need to be well-conditioned (or even invertible); positive semi-definite matrices are fine for these. The matrices which have values added along the diagonal should be well-conditioned, as this “regularises” them, while a squared exponential kernel with nothing added along the diagonal is unlikely to be well-conditioned.
from luas import kernels
from luas import LuasKernel, GeneralKernel
# We implement each of these kernel functions below using the luas.kernels module
# for an implementation of the squared exponential kernel
# The wavelength kernel functions take the wavelength regression variable(s) x_l as input (of shape (N_l) or (d_l, N_l))
def Kl_fn(hp, x_l1, x_l2, wn = True):
Kl = jnp.exp(2*hp["log_h"])*kernels.squared_exp(x_l1, x_l2, jnp.exp(hp["log_l_l"]))
return Kl
# The time kernel functions take the time regression variable(s) x_t as input (of shape (N_t) or (d_t, N_t))
def Kt_fn(hp, x_t1, x_t2, wn = True):
return kernels.squared_exp(x_t1, x_t2, jnp.exp(hp["log_l_t"]))
# For both the Sl and St functions we set a decomp attribute to "diag" because they produce diagonal matrices
# This speeds up the log likelihood calculations as it tells luas these matrices are easy to eigendecompose
# But don't do this for the Kl and Kt functions even if they produce diagonal matrices unless you know what you are doing
# This is because you are telling luas that the transformations of Kl and Kt are diagonal, not Kl and Kt themselves
# If the wn keyword argument is True then white noise should be included (doesn't affect most matrices)
# This is used by gp.predict when performing Gaussian process prediction
# Note it does not matter that Sl is not invertible without white noise
def Sl_fn(hp, x_l1, x_l2, wn = True):
Sl = jnp.zeros((x_l1.shape[-1], x_l2.shape[-1]))
if wn:
# If we are including white noise then safe to assume Sl is a square matrix for any calculations in luas.GP
Sl += jnp.diag(jnp.exp(2*hp["log_sigma"])) # Assumes hp["log_sigma"] is an array of size N_l
return Sl
Sl_fn.decomp = "diag" # Sl is a diagonal matrix
def St_fn(p, x_t1, x_t2, wn = True):
return jnp.eye(x_t1.shape[-1])
St_fn.decomp = "diag" # St is a diagonal matrix
# Build a LuasKernel object using these component kernel functions
# The full covariance matrix applied to the data will be K = Kl KRON Kt + Sl KRON St
kernel = LuasKernel(Kl = Kl_fn, Kt = Kt_fn, Sl = Sl_fn, St = St_fn,
# Can select whether to use previously calculated eigendecompositions when running MCMC
# Performs an additional check in each step to see if each component covariance matrix has changed since last step
# Useful when doing blocked Gibbs or if fixing some hyperparameters
use_stored_values = True,
)
# We can also add a general kernel using the same kernel function as the LuasKernel but without the kronecker product optimisations
# We will use this to show they produce the same answers and to compare runtimes
# While it is given the LuasKernel.K function, this simply evaluates Kl_fn, Kt_fn, Sl_fn and St_fn and Kronecker products the results together
# So it builds the same covariance matrix but the log likelihood calculations are completely different
general_kernel = GeneralKernel(K = kernel.K)
The LuasKernel.visualise_covariance_matrix method should be a useful way of visualising each of the four component matrices and checking that everything you have written looks right. It uses matplotlib.pyplot.pcolormesh to visualise the covariance matrices, which can however produce strange-looking plots for large matrices.
hp = {
"log_h":np.log(2e-3),
"log_l_l":np.log(1000.),
"log_l_t":np.log(0.011),
"log_sigma":np.log(5e-4)*np.ones(N_l),
}
kernel.visualise_covariance_matrix(hp, x_l, x_t);
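As an optional extra check of the conditioning discussed above, we can compare the condition numbers of \(K_l\) and \(S_l\) directly. This is a small sketch; jnp.linalg.cond here computes the ratio of the largest to smallest singular value.
# The regularised Sl matrix should typically have a much smaller condition number than Kl
print("Condition number of Kl:", jnp.linalg.cond(Kl_fn(hp, x_l, x_l)))
print("Condition number of Sl:", jnp.linalg.cond(Sl_fn(hp, x_l, x_l)))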
We can test the numerical stability of both the LuasKernel and GeneralKernel with a simple test: multiply a random normal vector by the inverse of the covariance matrix, then multiply the result by the covariance matrix. This should return the original random normal vector we started with. We should expect some level of floating point error, but it is likely to be negligible relative to the noise level in a dataset. So far I have yet to see the LuasKernel perform poorly for any valid covariance matrix. If you are having issues here, it is worth making sure that each kernel function generates a symmetric covariance matrix and that \(S_l\) and \(S_t\) are invertible.
# Generate a random normal vector
random_vec = np.random.normal(size = (N_l, N_t))
# Calculate alpha = K_inv @ random_vec
alpha = kernel.K_inv_by_vec(hp, x_l, x_t, random_vec)
# Calculate K @ alpha which should equal random_vec
K_K_inv_random_vec = kernel.K_by_vec(hp, x_l, x_t, alpha)
# Calculate the numerical errors from the original vector
calc_err = random_vec - K_K_inv_random_vec
print(f"Largest deviations (LuasKernel): {calc_err.min()}, {calc_err.max()}")
# Perform the same calculation with the general_kernel which builds the full covariance matrix
# and uses Cholesky decomposition for inverting the covariance matrix
K_inv_random_vec = general_kernel.K_inv_by_vec(hp, x_l, x_t, random_vec)
K_K_inv_random_vec = general_kernel.K_by_vec(hp, x_l, x_t, K_inv_random_vec)
calc_err = random_vec - K_K_inv_random_vec
print(f"Largest deviations (GeneralKernel): {calc_err.min()}, {calc_err.max()}")
Largest deviations (LuasKernel): -3.4572344986827375e-13, 4.5047299224165727e-13
Largest deviations (GeneralKernel): -3.0453417565468044e-13, 3.235189893757706e-13
Take a random noise draw from this covariance matrix. Try playing around with these values to see the effect varying each parameter has on the noise generated.
hp = {"log_h":jnp.log(2e-3),
"log_l_l":jnp.log(1000.),
"log_l_t":jnp.log(0.011),
"log_sigma":jnp.log(5e-4)*jnp.ones(N_l),}
sim_noise = kernel.generate_noise(hp, x_l, x_t)
plt.imshow(sim_noise, aspect = 'auto', extent = [x_t[0], x_t[-1], x_l[-1], x_l[0]])
plt.xlabel("Time (days)")
plt.ylabel(r"Wavelength ($\AA$)")
plt.show()
Combining our light curves and noise model, we can generate synthetic light curves contaminated by systematics correlated in time and wavelength.
def plot_lightcurves(x_t, M, Y, sep = 0.008):
"""Quick function to visualise light curves and the residuals after subtraction of transit model
"""
N_l = x_l.shape[-1]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 12), sharey = True)
for i in range(N_l):
ax1.plot(x_t, Y[i, :] + np.arange(0, -sep*N_l, -sep)[i], 'bo', ms = 3)
ax1.plot(x_t, M[i, :] + np.arange(0, -sep*N_l, -sep)[i], 'k-', ms = 3)
ax2.plot(x_t, Y[i, :] - M[i, :] + np.arange(1, 1-sep*N_l, -sep)[i], 'b.', ms = 3)
ax2.plot(x_t, np.arange(1, 1-sep*N_l, -sep)[i]*np.ones_like(x_t), 'k-', ms = 3)
ax1.set_xlabel("Time (days)")
ax2.set_xlabel("Time (days)")
ax1.set_ylabel("Relative flux")
plt.tight_layout()
u1_sim, u2_sim = np.linspace(0.7, 0.4, N_l), np.linspace(0.1, 0.2, N_l)
q1_sim, q2_sim = ld_to_kipping(u1_sim, u2_sim)
# All parameters used to generate the synthetic data set with noise
# PyMC will not work with jax arrays as inputs so we are using NumPy arrays
# Our PyMC wrapper also assumes that each input is an array so convert floats to size 1 NumPy arrays
p_sim = {
# Mean function parameters
"T0":0.*np.ones(1), # Central transit time
"P":3.*np.ones(1), # Period (days)
"a":8.*np.ones(1), # Semi-major axis to stellar ratio aka a/R*
"rho":0.1*np.ones(N_l), # Radius ratio rho aka Rp/R* for each wavelength
"b":0.5*np.ones(1), # Impact parameter
"q1":q1_sim, # First quadratic limb darkening coefficient for each wavelength
"q2":q2_sim, # Second quadratic limb darkening coefficient for each wavelength
"Foot":1.*np.ones(N_l), # Baseline flux out of transit for each wavelength
"Tgrad":0.*np.ones(N_l), # Gradient in baseline flux for each wavelength (days^-1)
# Hyperparameters
"log_h":np.log(5e-4)*np.ones(1), # log height scale
"log_l_l":np.log(1000.)*np.ones(1), # log length scale in wavelength
"log_l_t":np.log(0.011)*np.ones(1), # log length scale in time
"log_sigma":np.log(5e-4)*np.ones(N_l), # log white noise amplitude for each wavelength
}
transit_signal = transit_light_curve_2D(p_sim, x_l, x_t)
sim_noise = kernel.generate_noise(p_sim, x_l, x_t)
Y = transit_signal*(1 + sim_noise)
plot_lightcurves(x_t, transit_signal, Y)
We may also want to define a logPrior function. While PyMC can also be used to define priors on parameters, it can be useful to let luas.GP handle non-uniform priors when it comes to MCMC tuning (as will be shown later).
The logPrior function input to luas.GP should be written in JAX and must be of the form logPrior(params), i.e. it may only take the PyTree of mean function parameters and hyperparameters as input.
a_mean = p_sim["a"]
a_std = 0.1
b_mean = p_sim["b"]
b_std = 0.01
# Set some limb darkening priors, normally these might be generated from a package like LDTk
u1_mean = u1_sim
u1_std = 0.01
u2_mean = u2_sim
u2_std = 0.01
def logPrior(p):
logPrior = -0.5*((p["a"] - a_mean)/a_std)**2
logPrior += -0.5*((p["b"] - b_mean)/b_std)**2
u1, u2 = ld_from_kipping(p["q1"], p["q2"])
u1_priors = -0.5*((u1 - u1_mean)/u1_std)**2
u2_priors = -0.5*((u2 - u2_mean)/u2_std)**2
logPrior += u1_priors.sum() + u2_priors.sum()
return logPrior.sum()
print("Log prior at simulated values:", logPrior(p_sim))
Log prior at simulated values: -1.6370404527291505e-28
We now have enough to define our luas.GP object and use it to try to recover the original transmission signal injected into the correlated noise.
from luas import GP
from copy import deepcopy
# Initialise our GP object
# Make sure to include the mean function and log prior function if you're using them
gp = GP(kernel, # Kernel object to use
x_l, # Regression variable(s) along wavelength/vertical dimension
x_t, # Regression variable(s) along time/horizontal dimension
mf = transit_light_curve_2D, # (optional) mean function to use, defaults to zeros
logPrior = logPrior # (optional) log prior function, defaults to zero
)
# Initialise our starting values as the true simulated values
p_initial = deepcopy(p_sim)
# Convenient function for plotting the data and the GP fit to the data
gp.plot(p_initial, Y)
print("Starting log posterior value:", gp.logP(p_initial, Y))
Starting log posterior value: 9759.09679679811
Let’s compare the runtime speed-up from the optimisations in the LuasKernel. Comparing runtimes with JAX is a little more involved than normal. First we want to pre-compile the functions using jax.jit so that we aren’t timing how long each function takes to compile. We then must make sure we run the JIT-compiled functions once before benchmarking, as the compilation step will be done based on the array sizes given in that first run. We can then time the JIT-compiled functions, but make sure to include block_until_ready since jax uses asynchronous dispatch (i.e. this waits until the calculation has actually finished).
Note that while the runtime improvement for this small data set may not be substantial, the methods have very different scaling, with GeneralKernel.logP scaling as \(\mathcal{O}(N_l^3 N_t^3)\) and LuasKernel.logP scaling as \(\mathcal{O}(N_l^3 + N_t^3 + N_l N_t (N_l + N_t))\), so for larger data sets the difference may be much larger. GeneralKernel.logP may also crash for larger data sets due to its higher memory requirements.
# Create a similar GP object but using general Cholesky calculations in GeneralKernel instead of using LuasKernel
gp_general = GP(general_kernel, # Kernel object to use
x_l, # Regression variable(s) along wavelength/vertical dimension
x_t, # Regression variable(s) along time/horizontal dimension
mf = transit_light_curve_2D, # (optional) mean function to use, defaults to zeros
logPrior = logPrior # (optional) log prior function, defaults to zero
)
# Let's JIT compile each function before running (this should already be done in GP.py but is good practice anyway)
general_logP_jit = jax.jit(gp_general.logP)
luas_logP_jit = jax.jit(gp.logP)
# Run each calculation before benchmarking, note they should give the same answer up to floating point errors
print("GeneralKernel log posterior Calculation:", general_logP_jit(p_initial, Y))
print("LuasKernel log posterior Calculation:", luas_logP_jit(p_initial, Y))
# Benchmark each method
print("\nRuntime for GeneralKernel: ")
%timeit general_logP_jit(p_initial, Y).block_until_ready()
print("Runtime for LuasKernel: ")
%timeit luas_logP_jit(p_initial, Y).block_until_ready()
GeneralKernel log posterior Calculation: 9759.096796798116
LuasKernel log posterior Calculation: 9759.09679679811
Runtime for GeneralKernel:
42.1 ms ± 6.91 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Runtime for LuasKernel:
2.31 ms ± 42.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Now let’s begin using PyMC to perform a best-fit. We’ve set the true simulated values as our initial values so we should already be close to the optimal log posterior value but for real data this won’t always be the case. Similar to the approach in Fortune et al. (2024) we will not be performing any white light curve analysis but instead will be joint-fitting all spectroscopic light curves simultaneously.
import pymc as pm
from luas.pymc_ext import LuasPyMC
# Let's define some bounds
min_log_l_l = np.log(np.diff(x_l).min())
max_log_l_l = np.log(50*(x_l[-1] - x_l[0]))
min_log_l_t = np.log(np.diff(x_t).min())
max_log_l_t = np.log(3*(x_t[-1] - x_t[0]))
# Note with PyMC that the parameter bounds must also be NumPy arrays
param_bounds = {
    # Bounds from Kipping (2013) are just between 0 and 1
"q1":[np.array([0.]*N_l), np.array([1.]*N_l)],
"q2":[np.array([0.]*N_l), np.array([1.]*N_l)],
# Can optionally include bounds on other mean function parameters
# but often they will be well constrained by the data
"rho":[np.array([0.]), np.array([1.]*N_l)],
# Sometimes prior bounds on hyperparameters are important for sampling
# However their choice can sometimes affect the results so use with caution
"log_h": [np.log(1e-6)*np.ones(1), np.log(1)*np.ones(1)],
"log_l_l": [min_log_l_l*np.ones(1), max_log_l_l*np.ones(1)],
"log_l_t": [min_log_l_t*np.ones(1), max_log_l_t*np.ones(1)],
"log_sigma":[np.log(1e-6)*np.ones(N_l), np.log(1e-2)*np.ones(N_l)],
}
# Make a wrapper function which returns a PyMC model with a given set of fixed parameters and observations Y
def transit_model(p_fixed, Y):
with pm.Model() as model:
# Makes of copy of any parameters to be kept fixed during sampling
var_dict = deepcopy(p_fixed)
# Specify the parameters we've given bounds for
var_dict["rho"] = pm.Uniform('rho', lower=param_bounds["rho"][0],
upper=param_bounds["rho"][1], shape=N_l)
var_dict["q1"] = pm.Uniform('q1', lower=param_bounds["q1"][0],
upper=param_bounds["q1"][1], shape=N_l)
var_dict["q2"] = pm.Uniform('q2', lower=param_bounds["q2"][0],
upper=param_bounds["q2"][1], shape=N_l)
var_dict["log_h"] = pm.Uniform("log_h", lower=param_bounds["log_h"][0],
upper=param_bounds["log_h"][1], shape=1)
var_dict["log_l_l"] = pm.Uniform("log_l_l", lower = param_bounds["log_l_l"][0],
upper = param_bounds["log_l_l"][1], shape=1)
var_dict["log_l_t"] = pm.Uniform("log_l_t", lower = param_bounds["log_l_t"][0],
upper = param_bounds["log_l_t"][1], shape=1)
var_dict["log_sigma"] = pm.Uniform('log_sigma', lower=param_bounds["log_sigma"][0],
upper=param_bounds["log_sigma"][1], shape=N_l)
# Specify the unbounded parameters
var_dict["T0"] = pm.Flat('T0', shape=1)
var_dict["a"] = pm.Flat('a', shape=1)
var_dict["b"] = pm.Flat('b', shape=1)
var_dict["Foot"] = pm.Flat('Foot', shape=N_l)
var_dict["Tgrad"] = pm.Flat('Tgrad', shape=N_l)
# PyMC wrapper for luas log posterior calculations
# Requires the gp object, a dictionary of each model parameter and the observations Y
LuasPyMC("log_like", gp = gp, var_dict = var_dict, Y = Y)
# Will need to return both the model and the model variables for inference
return model, var_dict
# Initialise our model, p_initial will specify any fixed values like the period P
model, var_dict = transit_model(p_initial, Y)
Now we are all set up to start performing a best-fit using PyMC.
# PyMC requires the dictionary of starting values to only include variables in the model
# So we must remove the period parameter P as we keep it fixed
p_pymc = deepcopy(p_initial)
del p_pymc["P"]
# Use PyMC's maximum a posteriori (MAP) optimisation function
map_estimate = pm.find_MAP(
model = model, # PyMC model to optimise
include_transformed = False, # If this is true it will also output the PyMC transformed values of bounded parameters
start = p_pymc, # Starting point of optimisations
maxeval = 5000, # Maximum steps to run (normally will converge before this)
)
# Create a new dictionary of optimised values which includes our fixed parameters
p_opt = deepcopy(p_initial)
p_opt.update(map_estimate)
print("Starting log posterior value:", gp.logP(p_initial, Y))
print("New optimised log posterior value:", gp.logP(p_opt, Y))
MAP ━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16% 0:00:25 logp = 9,753.2, ||grad|| = 332.76
Starting log posterior value: 9759.09679679811
New optimised log posterior value: 9794.993893939803
Although this probably won’t be needed for the simulated data and will likely clip no data points, luas.GP comes with a sigma_clip method which performs 2D Gaussian process regression and can clip outliers that deviate by more than a given significance level, replacing them with the GP predictive mean at those locations (note we need to maintain a complete grid structure and so cannot simply remove these data points).
# This function will return a JAXArray of the same shape as Y
# but with outliers replaced with interpolated values
Y_clean = gp.sigma_clip(p_opt, # Make sure to perform sigma clipping using a good fit to the data
Y, # Observations JAXArray
5. # Significance level in standard deviations to clip at
)
Number of outliers clipped = 0
For MCMC tuning with large numbers of parameters, it can be very helpful to use the Laplace approximation to select a good choice of tuning matrix or “mass matrix” for No U-Turn Sampling (NUTS). This can often return quite accurate approximations of the covariance matrix of the posterior, especially when most of the parameters are well constrained.
If outliers were clipped then optimisation may need to be run again on the cleaned data before this step, as the Laplace approximation should be performed at the maximum of the posterior. Also note that we use the gp.laplace_approx_with_bounds method instead of gp.laplace_approx because we have bounds on some of our parameters, which means PyMC will perform a transformation on these parameters that we would like our Laplace approximation to take account of. See the gp.laplace_approx_with_bounds documentation for more details.
# Returns the covariance matrix returned by the Laplace approximation
# Also returns a list of parameters which is the order the array is in
# This matches the way jax.flatten_util.ravel_pytree will sort the parameter PyTree into
cov_mat, ordered_param_list = gp.laplace_approx_with_bounds(
p_opt, # Make sure to use best-fit values
Y_clean, # The observations being fit
param_bounds, # Specify the same bounds that will be used for the MCMC
fixed_vars = ["P"], # Make sure to specify fixed parameters as otherwise they are marginalised over
return_array = True, # May optionally return a nested PyTree if set to False which can be more readable
regularise = True, # Often necessary to regularise values that return negative covariance
large = False, # Setting this to True is more memory efficient which may be needed for large data sets
)
# This function will output information on what regularisation has been performed
# And will mention if there are remaining negative values along the diagonal of the covariance matrix
# It does not however check if the covariance matrix is invertible
No regularisation needed to remove negative values along diagonal of covariance matrix.
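Before sampling, it can be worth a quick optional sanity check of the Laplace approximation. This is a small sketch which converts the returned covariance matrix into a correlation matrix and visualises it.
# Convert the Laplace approximated covariance matrix into a correlation matrix
lap_std = np.sqrt(np.diag(cov_mat))
lap_corr = cov_mat / np.outer(lap_std, lap_std)

# Visualise the approximate correlations between all sampled parameters
plt.imshow(lap_corr, vmin = -1, vmax = 1, cmap = "RdBu_r")
plt.colorbar(label = "Approximate correlation")
plt.show()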
Build our model again using the same parameters being varied and sample all parameters using NUTS. You may want to run more than 2 chains; each chain takes about 1 minute to run on my M1 MacBook.
PyMC appears to be unable to run JAX functions in parallel, which is why cores = 1 has been set. If you want to run chains in parallel, the current suggestion is to run separate programs in parallel which each run a chain with PyMC and save each arviz inference data object separately. These can later be loaded in and combined with arviz.concat, as sketched below.
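A rough sketch of that workflow is below; it assumes the same model, NUTS_step and map_estimate as in this notebook, and chain_idx would be set differently in each separate run.
# Run this in a separate script/process for each chain, changing chain_idx each time
chain_idx = 0
idata_single = pm.sample(model = model, step = NUTS_step, initvals = map_estimate,
                         draws = 1000, tune = 1000, chains = 1, cores = 1)
idata_single.to_json(f"MCMC_chain_{chain_idx}.json")

# Afterwards, load each chain back in and combine them along the chain dimension
idata_combined = az.concat([az.from_json(f"MCMC_chain_{i}.json") for i in range(2)], dim = "chain")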
# Initialise our PyMC model
model, var_dict = transit_model(p_opt, Y_clean)
# The NUTS step takes as input the variables created when initialising the model
# We also sort these variables in the same order our Laplace approximated covariance matrix is in
NUTS_model_vars = [var_dict[par] for par in ordered_param_list]
NUTS_step = pm.NUTS(NUTS_model_vars, scaling = cov_mat, is_cov = True, model = model)
# Begin MCMC sampling
idata = pm.sample(
model = model, # PyMC model to sample
step = NUTS_step, # Sampling steps to use (can be list for blocked Gibbs sampling)
initvals = map_estimate, # Starting point of inference (will jitter around this location)
draws = 1000, # Number of samples from MCMC post warm-up
tune = 1000, # Number of tuning steps
chains = 2, # Number of chains to run
cores = 1, # This will probably fail if not set to 1 as I don't think PyMC can parallelise JAX functions
)
# Saves the inference object
#idata.to_json("MCMC_chains.json");
Sampling chain 1, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:56
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 109 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
Print a summary of the samples using arviz. Important values to look at for convergence are that the effective sample size of the bulk of the distribution (ess_bulk) and of the tail of the distribution (ess_tail) are at least ~500-1000, and that the Gelman-Rubin r_hat statistic is less than ~1.01 for each parameter.
az.summary(idata, round_to = 4)
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
Foot[0] | 1.0003 | 0.0002 | 1.0000 | 1.0006 | 0.0000 | 0.0000 | 2568.8559 | 1421.7658 | 1.0084 |
Foot[1] | 1.0001 | 0.0002 | 0.9997 | 1.0004 | 0.0000 | 0.0000 | 2511.1392 | 1183.4623 | 1.0027 |
Foot[2] | 1.0000 | 0.0002 | 0.9997 | 1.0004 | 0.0000 | 0.0000 | 2494.9458 | 1401.6450 | 1.0047 |
Foot[3] | 1.0000 | 0.0002 | 0.9997 | 1.0004 | 0.0000 | 0.0000 | 2383.1096 | 1296.7246 | 1.0067 |
Foot[4] | 1.0002 | 0.0002 | 0.9998 | 1.0005 | 0.0000 | 0.0000 | 2481.9216 | 1257.1202 | 1.0019 |
Foot[5] | 1.0002 | 0.0002 | 0.9998 | 1.0005 | 0.0000 | 0.0000 | 2651.9453 | 1336.8312 | 1.0021 |
Foot[6] | 1.0001 | 0.0002 | 0.9997 | 1.0004 | 0.0000 | 0.0000 | 2370.5094 | 1363.7136 | 1.0027 |
Foot[7] | 1.0000 | 0.0002 | 0.9997 | 1.0003 | 0.0000 | 0.0000 | 2332.3084 | 1320.6002 | 0.9997 |
Foot[8] | 1.0001 | 0.0002 | 0.9997 | 1.0004 | 0.0000 | 0.0000 | 2156.5550 | 1116.2035 | 1.0007 |
Foot[9] | 1.0002 | 0.0002 | 0.9998 | 1.0005 | 0.0000 | 0.0000 | 2250.5764 | 1306.3162 | 1.0006 |
Foot[10] | 1.0002 | 0.0002 | 0.9999 | 1.0006 | 0.0000 | 0.0000 | 2336.7156 | 1287.4453 | 1.0002 |
Foot[11] | 1.0001 | 0.0002 | 0.9998 | 1.0005 | 0.0000 | 0.0000 | 2365.4528 | 1421.7490 | 1.0003 |
Foot[12] | 1.0001 | 0.0002 | 0.9998 | 1.0004 | 0.0000 | 0.0000 | 2181.6595 | 1345.3129 | 1.0024 |
Foot[13] | 1.0001 | 0.0002 | 0.9998 | 1.0005 | 0.0000 | 0.0000 | 2411.6309 | 1325.8110 | 1.0003 |
Foot[14] | 1.0001 | 0.0002 | 0.9999 | 1.0005 | 0.0000 | 0.0000 | 2364.2161 | 1435.2439 | 1.0006 |
Foot[15] | 1.0000 | 0.0002 | 0.9997 | 1.0003 | 0.0000 | 0.0000 | 2332.7552 | 1411.0359 | 1.0006 |
T0[0] | 0.0002 | 0.0002 | -0.0003 | 0.0006 | 0.0000 | 0.0000 | 2098.5143 | 1282.5951 | 1.0052 |
Tgrad[0] | -0.0000 | 0.0001 | -0.0002 | 0.0001 | 0.0000 | 0.0000 | 2291.0770 | 1473.1655 | 0.9997 |
Tgrad[1] | -0.0000 | 0.0001 | -0.0002 | 0.0001 | 0.0000 | 0.0000 | 2284.0208 | 1378.1096 | 0.9999 |
Tgrad[2] | 0.0000 | 0.0001 | -0.0001 | 0.0001 | 0.0000 | 0.0000 | 1750.7975 | 1365.5722 | 0.9997 |
Tgrad[3] | 0.0000 | 0.0001 | -0.0001 | 0.0001 | 0.0000 | 0.0000 | 2195.3107 | 1402.9512 | 0.9999 |
Tgrad[4] | -0.0000 | 0.0001 | -0.0001 | 0.0001 | 0.0000 | 0.0000 | 1819.0453 | 1189.8162 | 1.0005 |
Tgrad[5] | 0.0000 | 0.0001 | -0.0001 | 0.0002 | 0.0000 | 0.0000 | 1969.8015 | 1218.6252 | 0.9996 |
Tgrad[6] | -0.0000 | 0.0001 | -0.0001 | 0.0001 | 0.0000 | 0.0000 | 2163.3999 | 1233.2407 | 0.9995 |
Tgrad[7] | 0.0000 | 0.0001 | -0.0001 | 0.0001 | 0.0000 | 0.0000 | 1935.9656 | 1418.8840 | 0.9995 |
Tgrad[8] | -0.0000 | 0.0001 | -0.0001 | 0.0001 | 0.0000 | 0.0000 | 2093.1542 | 1113.8053 | 0.9998 |
Tgrad[9] | -0.0000 | 0.0001 | -0.0002 | 0.0001 | 0.0000 | 0.0000 | 2016.9345 | 1316.5325 | 1.0001 |
Tgrad[10] | -0.0001 | 0.0001 | -0.0002 | 0.0001 | 0.0000 | 0.0000 | 2187.6532 | 1302.8707 | 1.0006 |
Tgrad[11] | -0.0001 | 0.0001 | -0.0002 | 0.0000 | 0.0000 | 0.0000 | 2222.5749 | 1437.9308 | 1.0003 |
Tgrad[12] | -0.0001 | 0.0001 | -0.0002 | 0.0000 | 0.0000 | 0.0000 | 2338.8928 | 1580.7240 | 1.0004 |
Tgrad[13] | -0.0000 | 0.0001 | -0.0001 | 0.0001 | 0.0000 | 0.0000 | 2373.5204 | 1418.2395 | 1.0004 |
Tgrad[14] | -0.0000 | 0.0001 | -0.0002 | 0.0001 | 0.0000 | 0.0000 | 2333.9502 | 1521.5675 | 1.0011 |
Tgrad[15] | 0.0000 | 0.0001 | -0.0001 | 0.0001 | 0.0000 | 0.0000 | 2444.5417 | 1505.4629 | 1.0018 |
a[0] | 7.9616 | 0.0515 | 7.8680 | 8.0637 | 0.0011 | 0.0007 | 2451.4170 | 1220.6379 | 0.9997 |
b[0] | 0.4964 | 0.0084 | 0.4811 | 0.5123 | 0.0002 | 0.0001 | 2724.9167 | 1133.3360 | 1.0011 |
log_h[0] | -7.6577 | 0.1229 | -7.9029 | -7.4398 | 0.0030 | 0.0021 | 1679.8672 | 1549.2493 | 1.0006 |
log_l_l[0] | 7.1001 | 0.1204 | 6.8611 | 7.3196 | 0.0024 | 0.0017 | 2573.7636 | 1448.2776 | 1.0023 |
log_l_t[0] | -4.6736 | 0.0857 | -4.8258 | -4.5029 | 0.0023 | 0.0016 | 1389.2453 | 1260.0876 | 1.0009 |
log_sigma[0] | -7.5703 | 0.0771 | -7.7148 | -7.4252 | 0.0015 | 0.0011 | 2654.7477 | 1430.2079 | 1.0021 |
log_sigma[1] | -7.5988 | 0.0732 | -7.7402 | -7.4657 | 0.0015 | 0.0011 | 2424.3258 | 1443.5802 | 1.0026 |
log_sigma[2] | -7.5466 | 0.0739 | -7.6881 | -7.4145 | 0.0015 | 0.0011 | 2335.6107 | 1371.4859 | 1.0009 |
log_sigma[3] | -7.5288 | 0.0747 | -7.6704 | -7.3942 | 0.0014 | 0.0010 | 2750.6689 | 1303.2878 | 0.9996 |
log_sigma[4] | -7.6258 | 0.0733 | -7.7623 | -7.4898 | 0.0015 | 0.0010 | 2461.2220 | 1364.8402 | 0.9997 |
log_sigma[5] | -7.5641 | 0.0753 | -7.6908 | -7.4107 | 0.0015 | 0.0011 | 2621.1654 | 1303.6388 | 1.0041 |
log_sigma[6] | -7.6956 | 0.0744 | -7.8344 | -7.5528 | 0.0014 | 0.0010 | 2695.3671 | 1603.2739 | 0.9996 |
log_sigma[7] | -7.6074 | 0.0741 | -7.7488 | -7.4748 | 0.0014 | 0.0010 | 3015.8249 | 1399.4540 | 1.0005 |
log_sigma[8] | -7.7249 | 0.0761 | -7.8730 | -7.5912 | 0.0015 | 0.0010 | 2743.3701 | 1221.7114 | 1.0021 |
log_sigma[9] | -7.4949 | 0.0741 | -7.6316 | -7.3488 | 0.0015 | 0.0011 | 2362.4276 | 1398.3925 | 0.9996 |
log_sigma[10] | -7.5739 | 0.0763 | -7.7179 | -7.4351 | 0.0014 | 0.0010 | 2897.9551 | 1326.5063 | 0.9992 |
log_sigma[11] | -7.5653 | 0.0743 | -7.7108 | -7.4329 | 0.0014 | 0.0010 | 2893.0355 | 1728.3525 | 0.9995 |
log_sigma[12] | -7.5239 | 0.0755 | -7.6712 | -7.3821 | 0.0015 | 0.0011 | 2503.7824 | 1481.2519 | 0.9996 |
log_sigma[13] | -7.5393 | 0.0724 | -7.6743 | -7.4043 | 0.0015 | 0.0010 | 2511.3206 | 1569.7349 | 1.0009 |
log_sigma[14] | -7.5806 | 0.0772 | -7.7282 | -7.4429 | 0.0014 | 0.0010 | 2948.5326 | 1432.7472 | 1.0001 |
log_sigma[15] | -7.6419 | 0.0778 | -7.7815 | -7.4842 | 0.0014 | 0.0010 | 3069.2149 | 1529.3539 | 1.0005 |
q1[0] | 0.6307 | 0.0220 | 0.5914 | 0.6716 | 0.0004 | 0.0003 | 3649.3482 | 1442.8696 | 1.0027 |
q1[1] | 0.6145 | 0.0211 | 0.5773 | 0.6562 | 0.0004 | 0.0003 | 2879.1622 | 1431.0045 | 1.0033 |
q1[2] | 0.6045 | 0.0214 | 0.5658 | 0.6451 | 0.0004 | 0.0003 | 2566.3279 | 1631.7329 | 1.0050 |
q1[3] | 0.5863 | 0.0208 | 0.5489 | 0.6252 | 0.0004 | 0.0003 | 2397.7685 | 1160.4560 | 1.0002 |
q1[4] | 0.5709 | 0.0206 | 0.5346 | 0.6118 | 0.0004 | 0.0003 | 2935.0916 | 1319.7897 | 1.0013 |
q1[5] | 0.5290 | 0.0194 | 0.4908 | 0.5641 | 0.0004 | 0.0003 | 3019.4506 | 1693.9809 | 0.9999 |
q1[6] | 0.5181 | 0.0181 | 0.4844 | 0.5526 | 0.0003 | 0.0002 | 2988.2775 | 1523.5617 | 1.0008 |
q1[7] | 0.4968 | 0.0189 | 0.4580 | 0.5288 | 0.0003 | 0.0002 | 2998.2415 | 1545.8909 | 1.0015 |
q1[8] | 0.4861 | 0.0181 | 0.4518 | 0.5197 | 0.0003 | 0.0002 | 2983.0330 | 1288.0465 | 1.0040 |
q1[9] | 0.4616 | 0.0184 | 0.4287 | 0.4964 | 0.0004 | 0.0003 | 2487.5862 | 1184.0799 | 1.0053 |
q1[10] | 0.4359 | 0.0182 | 0.4049 | 0.4728 | 0.0004 | 0.0003 | 2398.9733 | 1414.1966 | 1.0027 |
q1[11] | 0.4311 | 0.0175 | 0.3973 | 0.4628 | 0.0003 | 0.0002 | 2747.9582 | 1455.2737 | 1.0018 |
q1[12] | 0.4151 | 0.0173 | 0.3845 | 0.4475 | 0.0003 | 0.0002 | 2812.9675 | 1641.3601 | 1.0007 |
q1[13] | 0.3864 | 0.0167 | 0.3555 | 0.4180 | 0.0003 | 0.0002 | 3080.0628 | 1491.9718 | 0.9995 |
q1[14] | 0.3743 | 0.0167 | 0.3435 | 0.4052 | 0.0003 | 0.0002 | 2904.8921 | 1497.7196 | 1.0012 |
q1[15] | 0.3632 | 0.0160 | 0.3331 | 0.3934 | 0.0003 | 0.0002 | 2841.0206 | 1503.1798 | 0.9996 |
q2[0] | 0.4387 | 0.0057 | 0.4281 | 0.4493 | 0.0001 | 0.0001 | 3265.4910 | 1415.4637 | 1.0000 |
q2[1] | 0.4329 | 0.0056 | 0.4227 | 0.4433 | 0.0001 | 0.0001 | 2647.6563 | 1329.0594 | 0.9999 |
q2[2] | 0.4262 | 0.0056 | 0.4166 | 0.4374 | 0.0001 | 0.0001 | 2839.7408 | 1092.2083 | 1.0007 |
q2[3] | 0.4202 | 0.0057 | 0.4093 | 0.4311 | 0.0001 | 0.0001 | 2815.9297 | 1451.6574 | 1.0002 |
q2[4] | 0.4139 | 0.0056 | 0.4041 | 0.4245 | 0.0001 | 0.0001 | 2681.2282 | 1253.3551 | 1.0005 |
q2[5] | 0.4101 | 0.0057 | 0.4002 | 0.4217 | 0.0001 | 0.0001 | 2509.9180 | 1552.9818 | 1.0025 |
q2[6] | 0.4027 | 0.0055 | 0.3924 | 0.4127 | 0.0001 | 0.0001 | 2851.3817 | 1426.3421 | 1.0026 |
q2[7] | 0.3964 | 0.0060 | 0.3852 | 0.4078 | 0.0001 | 0.0001 | 2784.8464 | 1186.5276 | 1.0002 |
q2[8] | 0.3889 | 0.0059 | 0.3789 | 0.4009 | 0.0001 | 0.0001 | 2822.5834 | 1503.2954 | 1.0037 |
q2[9] | 0.3823 | 0.0060 | 0.3719 | 0.3937 | 0.0001 | 0.0001 | 3120.4048 | 1258.7688 | 1.0007 |
q2[10] | 0.3758 | 0.0061 | 0.3649 | 0.3882 | 0.0001 | 0.0001 | 3029.2490 | 1533.7491 | 1.0003 |
q2[11] | 0.3670 | 0.0059 | 0.3552 | 0.3772 | 0.0001 | 0.0001 | 3336.3813 | 1366.7378 | 1.0010 |
q2[12] | 0.3590 | 0.0059 | 0.3478 | 0.3704 | 0.0001 | 0.0001 | 2757.6425 | 1575.6199 | 1.0013 |
q2[13] | 0.3521 | 0.0065 | 0.3404 | 0.3649 | 0.0001 | 0.0001 | 3045.5868 | 1395.6879 | 1.0002 |
q2[14] | 0.3426 | 0.0061 | 0.3319 | 0.3545 | 0.0001 | 0.0001 | 3256.1995 | 1573.3843 | 1.0033 |
q2[15] | 0.3334 | 0.0061 | 0.3222 | 0.3449 | 0.0001 | 0.0001 | 2856.5031 | 1531.6771 | 1.0011 |
rho[0] | 0.1028 | 0.0014 | 0.0998 | 0.1051 | 0.0000 | 0.0000 | 2094.7572 | 1374.5222 | 1.0010 |
rho[1] | 0.1017 | 0.0014 | 0.0992 | 0.1044 | 0.0000 | 0.0000 | 2145.5999 | 1631.5551 | 1.0006 |
rho[2] | 0.1024 | 0.0014 | 0.0997 | 0.1051 | 0.0000 | 0.0000 | 2028.5462 | 1513.9781 | 1.0008 |
rho[3] | 0.1025 | 0.0014 | 0.0998 | 0.1050 | 0.0000 | 0.0000 | 2060.1680 | 1364.2151 | 1.0003 |
rho[4] | 0.1021 | 0.0014 | 0.0996 | 0.1048 | 0.0000 | 0.0000 | 2527.9318 | 1480.8127 | 1.0001 |
rho[5] | 0.1015 | 0.0014 | 0.0989 | 0.1042 | 0.0000 | 0.0000 | 2328.7281 | 1448.0348 | 1.0011 |
rho[6] | 0.1011 | 0.0014 | 0.0985 | 0.1036 | 0.0000 | 0.0000 | 2439.6588 | 1312.9622 | 1.0004 |
rho[7] | 0.1006 | 0.0014 | 0.0981 | 0.1034 | 0.0000 | 0.0000 | 2277.2191 | 1434.8873 | 1.0003 |
rho[8] | 0.1012 | 0.0014 | 0.0984 | 0.1036 | 0.0000 | 0.0000 | 2633.7476 | 1304.6553 | 0.9997 |
rho[9] | 0.1011 | 0.0014 | 0.0984 | 0.1039 | 0.0000 | 0.0000 | 1993.1624 | 1413.2107 | 1.0009 |
rho[10] | 0.1017 | 0.0014 | 0.0990 | 0.1044 | 0.0000 | 0.0000 | 2343.9795 | 1290.0111 | 1.0017 |
rho[11] | 0.1014 | 0.0014 | 0.0989 | 0.1041 | 0.0000 | 0.0000 | 2404.2997 | 1281.3425 | 1.0009 |
rho[12] | 0.1003 | 0.0014 | 0.0975 | 0.1028 | 0.0000 | 0.0000 | 2230.9473 | 1387.6509 | 0.9998 |
rho[13] | 0.1012 | 0.0014 | 0.0985 | 0.1040 | 0.0000 | 0.0000 | 2422.1314 | 1339.8146 | 1.0006 |
rho[14] | 0.1016 | 0.0014 | 0.0990 | 0.1041 | 0.0000 | 0.0000 | 2369.3996 | 1284.6837 | 0.9996 |
rho[15] | 0.1001 | 0.0014 | 0.0974 | 0.1027 | 0.0000 | 0.0000 | 2383.5405 | 1007.5391 | 1.0002 |
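With this many parameters it can be easier to check these diagnostics programmatically. A small optional sketch using the pandas DataFrame returned by az.summary:
summary = az.summary(idata, round_to = 4)
# Flag any parameters which fail the rough convergence guidelines above
flagged = summary[(summary["r_hat"] > 1.01) | (summary["ess_bulk"] < 500) | (summary["ess_tail"] < 500)]
print(f"{len(flagged)} parameters flagged for possible convergence issues")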
The arviz.plot_trace function can be very useful for visualising the marginal posterior distributions of the different parameters, as well as for examining the chains and diagnosing any possible convergence issues.
trace_plot = az.plot_trace(idata)
plt.tight_layout()
corner is great for generating nice corner plots. Unfortunately, when fitting many light curves the number of parameters can make visualising the full corner plot quite unwieldy, but we can always plot subsets of parameters instead.
from corner import corner
# Select the parameters to include in the plot
params = ["T0", "a", "rho", "b", "log_h", "log_l_l", "log_l_t"]
# Plot each of the two chains separately
idata_corner1 = idata.sel(chain=[0])
idata_corner2 = idata.sel(chain=[1])
# Plot first chain
fig1 = corner(idata_corner1, smooth = 0.4, var_names = params)
# Plot second chain along with truth values
fig2 = corner(idata_corner2, quantiles=[0.16, 0.5, 0.84], title_fmt = None, title_kwargs={"fontsize": 23},
label_kwargs={"fontsize": 23}, show_titles=True, smooth = 0.4, color = "r", fig = fig1,
top_ticks = True, max_n_ticks = 2, labelpad = 0.16, var_names = params,
truths = p_sim, truth_color = "k",
)
If you want to examine the chains for any parameter you can use idata.posterior[param] to get an xarray.DataArray object, then use its to_numpy method to convert it to a NumPy array. Below we retrieve the mean and covariance matrix of the radius ratio parameters and then visualise the transmission spectrum as well as the covariance matrix of the transmission spectrum.
# NumPy array of MCMC samples of shape (N_chains, N_draws, N_l)
rho_chains = idata.posterior["rho"].to_numpy()
# Get the mean radius ratio values averaged over all chains and draws
rho_mean = rho_chains.mean((0, 1))
# Calculate the covariance matrix of each chain and average together
N_chains = rho_chains.shape[0]
rho_cov = jnp.zeros((N_l, N_l))
for i in range(N_chains):
    rho_cov += jnp.cov(rho_chains[i, :, :].T)
rho_cov /= N_chains
# Standard deviation given by sqrt of diagonal of covariance matrix
rho_std_dev = jnp.sqrt(jnp.diag(rho_cov))
# We can plot our recovered spectrum against the simulated true spectrum
plt.errorbar(x_l, rho_mean, yerr = rho_std_dev, fmt = 'k.', label = "Recovered Spectrum")
plt.plot(x_l, p_sim["rho"], 'r--', label = "True Spectrum")
# Select a few samples from the MCMC to help visualise the correlation between values
rho_draws = rho_chains[0, 0:1000:100, :].T
plt.plot(x_l, rho_draws, 'k-', alpha = 0.3)
plt.xlabel(r"Wavelength ($\AA$)")
plt.ylabel(r"Radius Ratio $\rho$")
plt.legend()
plt.show()
plt.title("Tranmission spectrum covariance matrix")
plt.imshow(rho_cov, extent = [x_l[0], x_l[-1], x_l[-1], x_l[0]])
plt.colorbar()
plt.xlabel(r"Wavelength ($\AA$)")
plt.ylabel(r"Wavelength ($\AA$)")
plt.show()
It can be hard to tell when there are significant correlations involved whether the recovered spectrum is actually consistent with the simulated spectrum. The reduced chi-squared statistic \(\chi_r^2\) is useful for determining this. It is calculated using:
\[\chi_r^2 = \frac{1}{N_l} \left(\bar{\vec{\rho}}_\mathrm{ret} - \vec{\rho}_\mathrm{inj}\right)^T \mathbf{K}_\mathrm{\vec{\rho}; ret}^{-1} \left(\bar{\vec{\rho}}_\mathrm{ret} - \vec{\rho}_\mathrm{inj}\right)\]
where \(\bar{\vec{\rho}}_\mathrm{ret}\) is our retrieved mean transmission spectrum, \(\vec{\rho}_\mathrm{inj}\) is the injected transmission spectrum and \(\mathbf{K}_\mathrm{\vec{\rho}; ret}\) is our retrieved covariance matrix of the transmission spectrum.
r = rho_mean - p_sim["rho"]
chi2_r = r.T @ jnp.linalg.inv(rho_cov) @ r / N_l
print("Reduced chi-squared value:", chi2_r)
# The reduced chi-squared distribution always has mean 1 but has different standard deviations
# depending on the number of degrees of freedom
print("Distribution of reduced chi-squared: mean = 1, std. dev. =", jnp.sqrt(2/N_l))
Reduced chi-squared value: 1.0776260016856025
Distribution of reduced chi-squared: mean = 1, std. dev. = 0.3535533905932738
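As an optional extra step, the corresponding tail probability can be computed from the chi-squared survival function; this small sketch assumes scipy is installed (it is not imported elsewhere in this notebook).
from scipy import stats
# The total (unreduced) chi-squared is chi2_r*N_l and has N_l degrees of freedom
p_value = stats.chi2.sf(chi2_r*N_l, df = N_l)
print("Probability of a chi-squared value at least this large:", p_value)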