
# Working with JAX

luas is implemented in jax, which allows it to efficiently calculate gradients of the log-likelihood for use with gradient-based optimisation and MCMC methods. It also makes it easy to run code on a GPU, which can provide significant computational advantages.

However, jax can be challenging to work with without some initial context. “How to Think in JAX” and “JAX - The Sharp Bits” are excellent resources for learning the specifics of programming in jax. If your needs with luas are not too complex, then the tutorials section may be sufficient to get up and running. A quick summary of issues you may first encounter with jax is included below.

## A Brief Summary of Common Issues with JAX

jax is designed around pure functional programming, so functions should not have side-effects, i.e. they should not modify variables outside their own scope. For example, classes such as GP and LuasKernel have a few attributes which are never modified by any of their methods. You should also never modify these attributes directly; instead, initialise a new object if you need to change them.
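A minimal sketch of this pattern is shown below. SimpleKernel here is a hypothetical class used purely for illustration, not part of the luas API:

```python
import jax.numpy as jnp

# Hypothetical immutable container (not the luas API): attributes are set
# once at initialisation and never modified by any method.
class SimpleKernel:
    def __init__(self, variance):
        self.variance = variance

    def evaluate(self, x1, x2):
        # Pure method: reads attributes but never modifies them
        return self.variance * jnp.exp(-0.5 * (x1[:, None] - x2[None, :])**2)

kernel = SimpleKernel(variance=1.0)
# Avoid: kernel.variance = 2.0
# Instead, create a fresh object with the new value:
kernel = SimpleKernel(variance=2.0)
```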

jax defaults to single (32-bit) precision, but matrix decompositions typically require double precision. Therefore, it is always recommended to include jax.config.update("jax_enable_x64", True) at the start of a script/notebook to ensure jax uses double-precision floats.
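For example, the following applies the setting at the start of a script and verifies that new arrays are created in double precision:

```python
import jax
jax.config.update("jax_enable_x64", True)  # run before creating any arrays

import jax.numpy as jnp
x = jnp.ones(3)
print(x.dtype)  # float64 once x64 is enabled, float32 otherwise
```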

jax.numpy closely matches NumPy in many respects but there are some differences. One difference is that jax arrays are immutable, so updating the values of array elements requires a different syntax. Instead of mat[i] = value you will need to use mat = mat.at[i].set(value).
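A short example of this syntax, including the related .at[i].add update:

```python
import jax.numpy as jnp

mat = jnp.zeros(5)
# mat[2] = 7.0 would raise a TypeError: jax arrays are immutable
mat = mat.at[2].set(7.0)  # returns an updated copy of the array
mat = mat.at[0].add(1.0)  # in-place-style addition works similarly
print(mat)  # [1. 0. 7. 0. 0.]
```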

jax provides the option to compile functions with Just-in-Time (JIT) compilation. This can provide significant runtime improvements, as the same compiled code may be reused for multiple calculations, but it comes with stricter requirements on how control flow is used. In particular, Python control flow cannot depend on the value of an array element, although it may depend on the shape of an array, e.g. if mat[i] == 0: will fail to compile but if mat.size == 10: is fine.
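A minimal illustration of both cases, along with jnp.where as a common workaround when logic genuinely needs to depend on array values:

```python
import jax
import jax.numpy as jnp

@jax.jit
def safe_branch(mat):
    # Fine: mat.size is known when the function is traced
    if mat.size == 10:
        return mat * 2.0
    return mat

@jax.jit
def bad_branch(mat):
    # Fails: branching on a traced value, unknown at compile time
    if mat[0] == 0:
        return mat * 2.0
    return mat

@jax.jit
def where_branch(mat):
    # Value-dependent logic can instead use jnp.where
    return jnp.where(mat[0] == 0, mat * 2.0, mat)

mat = jnp.arange(10.0)
print(safe_branch(mat))   # works
# bad_branch(mat)         # raises a tracing error
print(where_branch(mat))  # works
```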

jax makes it easy to calculate derivatives using the jax.grad and jax.hessian functions, provided the function being differentiated returns a scalar value. For example, the gradient of a function f(x) can be calculated with jax.grad(f)(x). This works not just when x is an array but also when x is a PyTree (e.g. a dict of arrays), returning a PyTree of gradients in the same form as x. Note that if you want to take the gradient of a function which returns an array, you can do so by modifying the function so that it returns only particular elements of the array.
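A short example, where f is a hypothetical function taking a dict-like PyTree and returning a scalar:

```python
import jax
import jax.numpy as jnp

def f(p):
    # p is a PyTree (here a dict); f must return a scalar for jax.grad
    return jnp.sum(p["a"]**2) + 3.0 * p["b"]

params = {"a": jnp.array([1.0, 2.0]), "b": 0.5}
grads = jax.grad(f)(params)
print(grads)  # gradients returned in the same dict structure as params

def g(x):
    return x**2  # array-valued, so cannot be passed to jax.grad directly

# Instead, differentiate a single element of the output:
grad_of_first = jax.grad(lambda x: g(x)[0])(jnp.array([1.0, 2.0]))
```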

for loops and while loops can be very slow to compile. If possible, it is generally faster to make use of jax.vmap to “vectorise” a function. For example, if we have a function f(x) which takes a scalar value x and outputs a 1D array y, we can use y_2D = jax.vmap(f, in_axes=0)(x_1D) to create a function which takes a 1D array x_1D and outputs a 2D array y_2D, serving a similar purpose to a for loop.
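A minimal sketch matching this description, where f is a hypothetical example function:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Takes a scalar x, returns a 1D array of shape (3,)
    return x * jnp.array([1.0, 2.0, 3.0])

x_1D = jnp.arange(4.0)
y_2D = jax.vmap(f, in_axes=0)(x_1D)
print(y_2D.shape)  # (4, 3): one row of f(x) per element of x_1D
```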