Source code for luas.jax_convenience_fns

import numpy as np
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from tqdm import tqdm
from copy import deepcopy
from typing import Callable, Optional, Tuple, Union
from .luas_types import PyTree, JAXArray

__all__ = [
    "get_corr_mat",
    "order_list",
    "order_dict",
    "array_to_pytree_2D",
    "pytree_to_array_2D",
    "varying_params_wrapper",
    "large_hessian_calc",
    "sigmoid",
    "transf_from_unbounded_params",
    "transf_to_unbounded_params",
]

def get_corr_mat(cov_mat: JAXArray, zero_diag: Optional[bool] = False) -> JAXArray:
    """Given a covariance matrix will return its corresponding correlation matrix.
    
    Args:
        cov_mat (JAXArray): Covariance matrix to convert
        zero_diag (bool, optional): Option to set diagonal to zeros to help visualise
            off-diagonal correlations.
        
    Returns:
        JAXArray: The input covariance matrix converted to a correlation matrix.
        
    """
    
    # Get inverse sqrt of diagonal elements
    d = jnp.diag(jnp.sqrt(jnp.reciprocal(jnp.diag(cov_mat))))
    
    # Divides each covariance element by the standard deviation of each pair of parameters
    corr_mat = d @ cov_mat @ d
    
    # Setting the diagonal to zero can be useful for visualising small correlation
    if zero_diag:
        corr_mat = corr_mat*(1 - jnp.eye(corr_mat.shape[0]))
        
    return corr_mat


def order_list(par_list: list) -> list:
    """Orders a list in the same way ``jax.flatten_util.ravel_pytree`` will order dictionary keys
    to form an array.
    
    Args:
        par_list (:obj:`list` of :obj:`str`): List to sort
        
    Returns:
        :obj:`list` of :obj:`str`: The list sorted to match jax.flatten_util.ravel_pytree.
        
    """
    
    # Create a dictionary with values ordered 0 to len(par_list) - 1
    map_dict = {par:i for (i, par) in enumerate(par_list)}
    
    # Run ravel_pytree on this dict to get the order it sorts in
    ind_order, f = ravel_pytree(map_dict)
    
    # Create a new list which sorts the old list in the way ravel_pytree flattens to arrays.
    par_ordered = [par_list[par_ind] for par_ind in ind_order]
    
    return par_ordered

    
def order_dict(par_dict: dict) -> Tuple[list, list]:
    """Takes a PyTree and returns two lists ordered to match how ``jax.flatten_util.ravel_pytree``
    would sort the keys, one list for the keys in the input dictionary and another list for the values.
    This function is useful for determining how to sort a dictionary of ``PyMC`` tensor variables into an array
    as ``jax.flatten_util.ravel_pytree`` will not work on a PyTree containing ``PyMC`` tensor variables.
    
    Args:
        par_dict (PyTree): PyTree to sort by keys
        
    Returns:
        (:obj:`list` of :obj:`str`, :obj:`list` of :obj:`str`): Returns a tuple of two lists, the first a list of keys
        sorted to match ``jax.flatten_util.ravel_pytree`` and the second a list of values also sorted to match
        ``jax.flatten_util.ravel_pytree``.
        
    """
    
    # Since the values may not be JAX variables, create a similar dict with integer values
    key_list = list(par_dict.keys())
    map_dict = {par:i for (i, par) in enumerate(key_list)}
    ind_order, f = ravel_pytree(map_dict)

    # Use the order the integers are sorted into to order the list of keys and list of values
    par_keys_ordered = [key_list[par_ind] for par_ind in ind_order]
    par_values_ordered = [par_dict[par] for par in par_keys_ordered]
    
    return par_keys_ordered, par_values_ordered


[docs] def array_to_pytree_2D(p: PyTree, arr_2D: JAXArray) -> PyTree: """Takes a 2D JAXArray (e.g. a covariance matrix) where the rows and columns are sorted according to ``jax.flatten_util.ravel_pytree`` would sort the input PyTree of parameters ``p`` and returns a nested PyTree of the array sorted into the parameter values. Args: p (PyTree): Parameters used for log likelihood calculations, used to describe the order of the array according to how ``jax.flatten_util.ravel_pytree`` will flatten into an array. arr_2D (JAXArray): A 2D array where the rows and columns are both sorted in the order in which ``jax.flatten_util.ravel_pytree`` flattens the parameter PyTree p. Returns: PyTree: A nested PyTree where the input 2D array ``arr_2D`` has been rearranged to its corresponding parameters. """ # First get the function which can convert an array into a PyTree like p in the right order p_arr, make_p_dict = ravel_pytree(p) # Use this function to sort the numbers from 0 to N for N parameters # This will be used to sort the array into a nested PyTree in the right order coord_dict = make_p_dict(jnp.arange(p_arr.size)) pytree_2D = {} # Loop through the parameters lying along each row for k1 in p.keys(): pytree_2D[k1] = {} # Loop through the parameters lying along each column for k2 in p.keys(): # Select the array elements corresponding to the row and column parameters pytree_2D[k1][k2] = arr_2D[jnp.ix_(coord_dict[k1], coord_dict[k2])] return pytree_2D
[docs] def pytree_to_array_2D( p: PyTree, pytree_2D: JAXArray, param_order: Optional[list] = None, )-> PyTree: """Inverse of ``array_to_pytree_2D``, takes a nested PyTree (e.g. describing a covariance matrix) where the keys correspond to the row and column of a 2D array with a value defined for each parameter with every other parameter. Sorts the nested PyTree into this 2D array sorted according to param_order, defaults to the order ``jax.flatten_util.ravel_pytree`` will sort dictionary keys into when forming an array. Args: p (PyTree): Parameters used for log likelihood calculations, used to describe the order of the array according to how ``jax.flatten_util.ravel_pytree`` will flatten into an array. pytree_2D (PyTree): A nested PyTree describing a 2D array keyed by the pair of parameters along each row and column. Returns: JAXArray: A 2D array which the nested PyTree has been sorted into. """ # First must solve for the order in the array to place each parameter into if param_order is None: # Defaults to sorting how ravel_pytree will sort param_order = order_list(list(p.keys())) p_arr, make_p_dict = ravel_pytree(p) coord_dict = make_p_dict(jnp.arange(p_arr.size)) # Get the total number of parameters N_par = p_arr.size else: # If a specified order has been given must solve for the locations in an array # each parameter should go to coord_dict = {} N_par = 0 for par in param_order: coord_dict[par] = jnp.arange(N_par, N_par + p[par].size) N_par += p[par].size # Loop through each element in the nested PyTree and use the array locations already found arr_2D = jnp.zeros((N_par, N_par)) for k1 in param_order: for k2 in param_order: arr_2D = arr_2D.at[jnp.ix_(coord_dict[k1], coord_dict[k2])].set(pytree_2D[k1][k2]) return arr_2D
[docs] def varying_params_wrapper( p: PyTree, vars: Optional[list] = None, fixed_vars: Optional[list] = None, to_numpy: Optional[bool] = True ) -> Tuple[PyTree, Callable]: """Often useful to take a PyTree of parameters and return a subset of the (key, value) pairs which are to be varied. E.g. ``PyMC`` requires start values at initialisation of optimisation/inference to only include parameters which are to be fit for. Also returns a function which can take the subset of parameters being varied and return the full set of parameters with the fixed parameters added back in. Note: By default will return parameter values in NumPy arrays as this is required for inference with ``PyMC``. This can be turned off by setting ``to_numpy = False``. Args: p (PyTree): The full set of parameters used for likelihood calculations potentially including both mean function parameters and hyperparameters. vars (:obj:`list` of :obj:`str`, optional): The list of keys names corresponding to the parameters being varied which we want to include in the output parameter PyTree. If specified in addition to fixed_vars will raise an ``Exception``. fixed_vars (:obj:`list` of :obj:`str`, optional): Alternative to vars, may specify instead the parameters being kept fixed which will be excluded from the output parameter PyTree. If specified in addition to vars will raise an Exception. to_numpy (bool, optional): Converts parameter values from input parameter PyTree to ``NumPy`` arrays as this is required for inference with ``PyMC``. Defaults to ``True``. Returns: (PyTree, Callable): Returns a tuple where the first element is the input PyTree ``p`` but containing only the ``(key, value)`` pairs of parameters to be varied, the second element is a function which takes the output parameter PyTree containing only the parameters being varied and adds back in the fixed parameters without overwritting the parameters being varied. """ if to_numpy: # PyMC cannot deal with values being JAXArrays so must convert to NumPy to_array = np.array else: # Do nothing to parameter values if to_numpy = False to_array = lambda p: p if vars is not None and fixed_vars is None: # If just vars specified p_vary = {par:to_array(p[par]) for par in vars} if vars is None and fixed_vars is not None: # If just fixed_vars specified p_vary = {par:to_array(p[par]) for par in p.keys() if par not in fixed_vars} elif vars is None and fixed_vars is None: # If neither vars nor fixed_vars specified just keep all parameters p_vary = {par:to_array(p[par]) for par in p.keys()} elif vars is not None and fixed_vars is not None: # If both vars and fixed vars specified raise an Exception raise Exception("Both vars and fixed_vars cannot be defined!") # Create a function which takes the subset of parameters in p_vary and # adds back in the fixed parameters p_fixed = deepcopy(p) def make_p(p_vary): p_fixed.update(p_vary) return p_fixed return p_vary, make_p
[docs] def large_hessian_calc( fn: Callable, p: PyTree, *args, block_size: Optional[int] = 50, return_array: Optional[bool] = True, jit: Optional[bool] = True, **kwargs, ) -> Union[JAXArray, PyTree]: """Breaks up the calculation of the hessian of a large matrix into groups of rows to reduce the memory cost. Useful for large data sets when ``jax.hessian`` applied to a log likelihood function can cause the code to crash. This function should work for any arbitrary function however for which ``jax`` can calculate second-order derivatives and for functions where the first argument is a PyTree which the derivative is to be calculated with respect to. Args: fn (Callable): Function to calculate the hessian of. Should be of the form ``f(p, *args, **kwargs)`` where ``p`` is a PyTree. p (PyTree): Parameters to calculate the derivative with respect to. All parameters within ``p`` will have the derivative be taken with respect to. block_size (int, optional): The number of groups of rows to calculate the second derivatives for at once. Large numbers will have a higher memory cost but may result in a shorter runtime. Defaults to 50. return_array (bool, optional): Whether to return the hessian as an array of shape ``(N_{par}, N_{par})`` for ``N_{par}`` total parameters in ``p`` or as a nested PyTree where the hessian values for parameters named par1 and par2 would be given by ``hessian_pytree[par1][par2]`` and ``hessian_pytree[par2][par1]``. Defaults to True. jit (bool, optional): Whether to JIT compile the hessian function, can speed up the calculation assuming the function can be JIT compiled. Defaults to ``True``. Returns: JAXArray or PyTree: Depending on whether ``return_array`` is set to ``True`` or ``False`` will either return a JAXArray or PyTree giving the hessian with respect to each parameter in ``p``. If returning a JAXArray then the parameters will be ordered in the same order that ``jax.flatten_util.ravel_pytree`` will order the input PyTree ``p``. """ # To break calculation into rows must first write a wrapper function which takes as input an array of parameters p_arr, make_p_dict = ravel_pytree(p) fn_arr_wrapper = lambda p_arr: fn(make_p_dict(p_arr), *args, **kwargs) # We then create a function which returns the gradient of specific array elements grad_wrapper = lambda p_arr, i: jax.grad(fn_arr_wrapper)(p_arr)[i] # Finally we create a function which takes the derivative of a group of first derivatives # vmap allows us to take this gradient wrt an array of gradients as jax.grad only works # on functions which return scalars hessian_wrapper = jax.vmap(jax.grad(grad_wrapper), in_axes = (None, 0)) if jit: hessian_wrapper = jax.jit(hessian_wrapper) N_par = p_arr.size hess_arr = jnp.zeros((N_par, N_par)) # Iterates through blocks of rows solving second derivatives for i in tqdm(range(0, N_par, block_size)): rows = jnp.arange(i, i+block_size) hess_arr = hess_arr.at[rows, :].set(hessian_wrapper(p_arr, rows)) # If the size of the array is not an even multiple of block_size will calculate the remaining rows if N_par % block_size > 0: rows = jnp.arange(i+block_size, N_par) hess_arr = hess_arr.at[rows, :].set(hessian_wrapper(p_arr, rows)) # Defaults to returning a JAXArray of parameters but can return a nested PyTree if return_array: return hess_arr else: return array_to_pytree_2D(p, hess_arr)
def sigmoid(x): """A sigmoid function. When dealing with parameters bounded between limits [a, b], PyMC and NumPyro vary a parameter between (-jnp.inf, jnp.inf) and use a sigmoid transform followed by an affine transform to map this interval to the interval (a, b). Args: x (JAXArray): A variable in the unbounded transformed space used by PyMC and NumPyro Returns: JAXArray: The sigmoid function applied to x, constraining it to the interval (0, 1) """ return 1 / (1 + jnp.exp(-x)) def transf_to_unbounded(par, bounds): """Transformation used by PyMC and NumPyro to convert a parameter which lies within the interval (bounds[0], bounds[1]) to the interval (-jnp.inf, jnp.inf). Args: p (JAXArray): A parameter value which lies in the interval (bounds[0], bounds[1]) bounds (:obj:`list` of JAXArray): A list containing the lower bound of p as its first element and the upper bound of p as its second element Returns: JAXArray: The unbounded transformed value of p used by PyMC and NumPyro when sampling. """ return jnp.log(par - bounds[0]) - jnp.log(bounds[1] - par) def transf_from_unbounded(x, bounds): """Inverse of transformation transf_to_unbounded. Converts an unbounded parameter lying within the interval (-jnp.inf, jnp.inf) to the bounded interval (bounds[0], bounds[1]). Args: x (JAXArray): A parameter value which lies in the interval (-jnp.inf, jnp.inf). bounds (:obj:`list` of JAXArray): A list containing the lower bound of the parameter as its first element and the upper bound of the parameter as its second element. Returns: JAXArray: The bounded parameter used for log likelihood calculations. """ sigmoid_x = sigmoid(x) return sigmoid_x * (bounds[1] - bounds[0]) + bounds[0]
[docs] def transf_to_unbounded_params(p, param_bounds): """Replicates the transformation used by ``PyMC`` and ``NumPyro`` to convert a PyTree of parameters - some or all of which lie within the bounds given by ``param_bounds`` - to unbounded parameters using the transformation given in ``transf_to_unbounded``. Used for calculating the Laplace approximation of the posterior with respect to the transformed parameters used by ``PyMC`` and ``NumPyro`` for sampling. Examples: For a single parameter ``p["d"]`` which lies between bounds ``(a, b)``, ``param_bounds`` should be of the form: ``param_bounds = {"d":[a, b]}`` where ``a`` and ``b`` have the same shape as ``p["d"]``. Args: p (PyTree): All parameters used for log likelihood calculations potentially including additional unbounded parameters. param_bounds (PyTree): Contains any bounds for the parameters in ``p``. Returns: JAXArray: All parameters in ``p`` with the unbounded transformed values for any parameter in ``param_bounds``. Should match the transformed parameters being sampled by ``PyMC`` and ``NumPyro``. """ # Copy any parameters which do not lie within bounds specified in param_bounds p_unbounded = deepcopy(p) # Loop through each parameter in the bounds for par in param_bounds.keys(): # Check if the parameter is actually being varied if par in p_unbounded.keys(): # Transforms to unbounded space p_unbounded[par] = transf_to_unbounded(p[par], param_bounds[par]) return p_unbounded
[docs] def transf_from_unbounded_params(p_unbounded, param_bounds): """Inverse of the transformation in transf_to_unbounded_params. Converts parameters being sampled by ``PyMC`` and ``NumPyro`` to the parameters which lie between bounds described in ``param_bounds`` which are used for log likelihood calculations. Used for calculating the Laplace approximation of the posterior with respect to the transformed parameters used by ``PyMC`` and ``NumPyro`` for sampling. Examples: For a single parameter ``p["d"]`` which should lie between bounds ``(a, b)``, ``param_bounds`` should be of the form: ``param_bounds = {"d":[a, b]}`` where ``a`` and ``b`` have the same shape as ``p["d"]``. Args: p_unbounded (PyTree): All parameters used for sampling by ``PyMC`` and ``NumPyro`` to convert for log likelihood calculations which lie in an transformed unbounded space. param_bounds (PyTree): Contains any bounds for the parameters in ``p``. Returns: JAXArray: All parameters used for log likelihood calculations potentially including additional unbounded parameters. """ # Copy any parameters which do not lie within bounds specified in param_bounds p_bounded = deepcopy(p_unbounded) # Loop through each parameter in the bounds for par in param_bounds.keys(): # Check if the parameter is actually being varied if par in p_bounded.keys(): # Transforms back to bounded space p_bounded[par] = transf_from_unbounded(p_unbounded[par], param_bounds[par]) return p_bounded