# Source code for luas.LuasKernel

import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
from jax import grad, value_and_grad, hessian, vmap, custom_jvp, jit
from jax.flatten_util import ravel_pytree
from copy import deepcopy
from tqdm import tqdm
from typing import Callable, Tuple, Union, Any, Optional
from functools import partial

from .luas_types import Kernel, PyTree, JAXArray, Scalar
from .kronecker_fns import kron_prod, logdetK_calc, r_K_inv_r, K_inv_vec, logdetK_calc_hessianable
from .jax_convenience_fns import array_to_pytree_2D, get_corr_mat


# Names exported by a star-import of this module: the three decomposition helpers
# defined below and the kernel class itself.
__all__ = [
    "diag_eigendecomp",
    "decomp_S",
    "decomp_K_tilde",
    "LuasKernel",
]

# Ensure we are using double precision floats as JAX uses single precision by default
# NOTE: this runs at import time, so importing this module changes the JAX configuration
# for the importing program as a side effect.
jax.config.update("jax_enable_x64", True)


def diag_eigendecomp(K: JAXArray) -> Tuple[JAXArray, JAXArray]:
    r"""Eigendecompose ``K`` under the assumption that it is a diagonal matrix.

    The eigenvalues are read straight off the diagonal of ``K`` and the eigenvector matrix
    is taken to be the identity, so no numerical eigensolver is run. This can be significantly
    faster than ``jax.numpy.linalg.eigh`` when a matrix is known in advance to be diagonal.

    Note:
        No check is made that ``K`` really is diagonal; any off-diagonal entries are
        silently ignored.

    Args:
        K (JAXArray): The (assumed diagonal) square matrix to decompose.

    Returns:
        (JAXArray, JAXArray): A tuple of the eigenvalues (the diagonal of ``K``) and the
        eigenvector matrix (an identity matrix with ``K.shape[0]`` rows).

    """
    n_rows = K.shape[0]
    eigenvalues = jnp.diag(K)
    eigenvectors = jnp.eye(n_rows)
    return eigenvalues, eigenvectors


def decomp_S(
    S: JAXArray,
    eigen_fn: Optional[Callable] = jnp.linalg.eigh
) -> Tuple[JAXArray, JAXArray]:
    r"""Eigendecompose ``S`` and build an inverse square root factor from the result.

    This needs to be performed on the ``Sl`` and ``St`` matrices for log likelihood
    calculations with the :class:`LuasKernel`.

    With eigenvalues :math:`\lambda` and eigenvector matrix :math:`Q` returned by ``eigen_fn``,
    the factor returned is :math:`Q \Lambda^{-1/2}` (the eigenvector columns scaled by the
    reciprocal square roots of the eigenvalues).

    Args:
        S: The matrix to decompose.
        eigen_fn: Function returning ``(eigenvalues, eigenvectors)`` for ``S``.
            Defaults to ``jax.numpy.linalg.eigh``.

    Returns:
        (JAXArray, JAXArray): A tuple of the eigenvalues of ``S`` and the inverse square
        root factor of ``S``.

    """
    eigvals, eigvecs = eigen_fn(S)

    # Reciprocal square roots of the eigenvalues, placed on a diagonal to scale the eigenvectors
    inv_sqrt_eigvals = jnp.sqrt(jnp.reciprocal(eigvals))
    S_inv_sqrt = jnp.matmul(eigvecs, jnp.diag(inv_sqrt_eigvals))

    return eigvals, S_inv_sqrt


def decomp_K_tilde(
    K: JAXArray,
    S_inv_sqrt: JAXArray,
    eigen_fn: Optional[Callable] = jnp.linalg.eigh
) -> Tuple[JAXArray, JAXArray]:
    r"""Transforms ``K`` using the inverse square root factor of ``S``, eigendecomposes the
    transformed matrix ``K_tilde`` and builds the corresponding ``W`` matrix.

    Needs to be separately performed on the ``Kl`` and ``Kt`` matrices for log likelihood
    calculations with the :class:`LuasKernel`.

    For ``K = Kl`` this function forms ``Kl_tilde`` using:

    .. math::

        \tilde{K}_l = (S_l^{-\frac{1}{2}})^T K_l S_l^{-\frac{1}{2}}

    It then eigendecomposes ``Kl_tilde`` and calculates the required ``W_l`` matrix with:

    .. math::
        W_l = S_l^{-\frac{1}{2}} Q_{\tilde{K}_l}

    and similarly for ``K = Kt``.

    Args:
        K: The matrix to transform and decompose.
        S_inv_sqrt: The inverse square root factor of the matching ``S`` matrix, as returned
            by :func:`decomp_S`.
        eigen_fn: Function returning ``(eigenvalues, eigenvectors)`` for ``K_tilde``.
            Defaults to ``jax.numpy.linalg.eigh``.

    Returns:
        (JAXArray, JAXArray): A tuple of the eigenvalues of the transformed matrix ``K_tilde``
        and the ``W`` matrix (``S_inv_sqrt`` multiplied by the eigenvectors of ``K_tilde``).

    """
    # K_tilde = S_inv_sqrt^T K S_inv_sqrt
    transformed = S_inv_sqrt.T @ K @ S_inv_sqrt

    eigvals, eigvecs = eigen_fn(transformed)

    return eigvals, S_inv_sqrt @ eigvecs


class LuasKernel(Kernel):
    r"""Kernel class which solves for the log likelihood for any covariance matrix which is the
    sum of two kronecker products of the covariance matrix in each of two dimensions i.e. the
    full covariance matrix K is given by:

    .. math::

        K = K_l \otimes K_t + S_l \otimes S_t

    although we can avoid calculating ``K`` for many calculations implemented here.

    The ``Kl`` and ``Sl`` functions should both return ``(N_l, N_l)`` matrices which will be the
    covariance matrices in the wavelength/vertical direction. The ``Kt`` and ``St`` functions
    should both return ``(N_t, N_t)`` matrices which will be the covariance matrices in the
    time/horizontal direction.

    .. code-block:: python

        >>> from luas import LuasKernel, kernels
        >>> def Kl_fn(hp, x_l1, x_l2, wn = True):
        >>> ...     return hp["h"]**2*kernels.squared_exp(x_l1, x_l2, hp["l_l"])
        >>> def Kt_fn(hp, x_t1, x_t2, wn = True):
        >>> ...     return kernels.squared_exp(x_t1, x_t2, hp["l_t"])
        >>> # ... And similarly for Sl_fn, St_fn
        >>> kernel = LuasKernel(Kl = Kl_fn, Kt = Kt_fn, Sl = Sl_fn, St = St_fn)

    See https://luas.readthedocs.io/en/latest/tutorials.html for more detailed tutorials on how
    to use.

    Note:
        Each kernel function may carry a ``decomp`` attribute. If it is the string ``"diag"``
        it is replaced on the function object by :func:`diag_eigendecomp`; if the attribute is
        missing it is set to ``jax.numpy.linalg.eigh``. The functions passed in are therefore
        modified in place.

    Args:
        Kl (Callable): Function which returns the covariance matrix Kl, should be of the form
            ``Kl(hp, x_l1, x_l2, wn = True)``.
        Kt (Callable): Function which returns the covariance matrix Kt, should be of the form
            ``Kt(hp, x_t1, x_t2, wn = True)``.
        Sl (Callable): Function which returns the covariance matrix Sl, should be of the form
            ``Sl(hp, x_l1, x_l2, wn = True)``.
        St (Callable): Function which returns the covariance matrix St, should be of the form
            ``St(hp, x_t1, x_t2, wn = True)``.
        use_stored_values (bool, optional): Whether to perform checks if any of the component
            covariance matrices have changed and to make use of previously stored values for the
            decomposition of those matrices if they're the same. If ``False`` then will not
            perform these checks and will compute the eigendecomposition of all matrices for
            every calculation.

    Raises:
        TypeError: If any of ``Kl``, ``Kt``, ``Sl`` or ``St`` is not callable (including when
            one of them is left at its ``None`` default).

    """

    def __init__(
        self,
        Kl: Callable = None,
        Kt: Callable = None,
        Sl: Callable = None,
        St: Callable = None,
        use_stored_values: Optional[bool] = True,
    ):
        # All four kernel functions are required even though the signature gives them None
        # defaults; fail early with a clear message instead of an AttributeError further down.
        for name, fn in (("Kl", Kl), ("Kt", Kt), ("Sl", Sl), ("St", St)):
            if not callable(fn):
                raise TypeError(
                    f"LuasKernel requires a callable for {name} but got {fn!r}"
                )

        self.Kl = Kl
        self.Kt = Kt
        self.Sl = Sl
        self.St = St

        # Have different decomposition functions depending on whether previous stored values
        # are to be used to avoid recalculating eigendecompositions
        if use_stored_values:
            self.decomp_fn = self.eigendecomp_use_stored_values
        else:
            self.decomp_fn = self.eigendecomp_no_stored_values

        # Identify how to eigendecompose each of the matrices
        # Note that for Kl and Kt it will actually be the transformed matrices
        # Kl_tilde and Kt_tilde being eigendecomposed
        for fn in [self.Sl, self.St, self.Kl, self.Kt]:
            if hasattr(fn, "decomp"):
                # Only the string "diag" is translated; any other value is left as supplied
                if fn.decomp == "diag":
                    fn.decomp = diag_eigendecomp
            else:
                fn.decomp = jnp.linalg.eigh

        if self.Kl.decomp == diag_eigendecomp and not self.Sl.decomp == diag_eigendecomp:
            print(
                "NOTE: The transformation of Kl is set to be diagonal but the matrix Sl is not "
                "set to diagonal. This may be possible for example if Kl is a scalar times the "
                "identity matrix or Kl shares the same eigenvectors as Sl but it is not true if "
                "Kl is any general diagonal matrix. Alternatively perhaps Sl is also diagonal "
                "and you forgot to add Sl.decomp = 'diag'. Be careful to ensure the "
                "transformation of Kl is diagonal or else log likelihood values will be incorrect!"
            )

        if self.Kt.decomp == diag_eigendecomp and not self.St.decomp == diag_eigendecomp:
            print(
                "NOTE: The transformation of Kt is set to be diagonal but the matrix St is not "
                "set to diagonal. This may be possible for example if Kt is a scalar times the "
                "identity matrix or Kt shares the same eigenvectors as St but it is not true if "
                "Kt is any general diagonal matrix. Alternatively perhaps St is also diagonal "
                "and you forgot to add St.decomp = 'diag'. Be careful to ensure the "
                "transformation of Kt is diagonal or else log likelihood values will be incorrect!"
            )

        # Specify how each of the 4 matrices will be decomposed
        # The decomp attribute is looked up when the lambda is called, not when it is created
        self.Sl_decomp_fn = lambda Sl: decomp_S(Sl, eigen_fn = self.Sl.decomp)
        self.St_decomp_fn = lambda St: decomp_S(St, eigen_fn = self.St.decomp)
        self.Kl_tilde_decomp_fn = lambda Kl, Sl_inv_sqrt: decomp_K_tilde(Kl, Sl_inv_sqrt, eigen_fn = self.Kl.decomp)
        self.Kt_tilde_decomp_fn = lambda Kt, St_inv_sqrt: decomp_K_tilde(Kt, St_inv_sqrt, eigen_fn = self.Kt.decomp)

[docs] def eigendecomp_no_stored_values( self, hp: PyTree, x_l: JAXArray, x_t: JAXArray, stored_values: Optional[PyTree] = {}, ) -> PyTree: r"""Required calculations for the decomposition of the overall matrix ``K`` where the previously stored decomposition of ``K`` cannot be used for the calculation of a new decomposition. This avoids checking if any of the matrices have changed but may result in performing the same eigendecomposition calculations multiple times. We can decompose the inverse of ``K`` into the matrices: .. math:: K^{-1} = [W_l \otimes W_t] D^{-1} [W_l^T \otimes W_t^T] Where this function will calculate ``W_l``, ``W_t`` and ``D_inv`` and stored them in the ``stored_values`` PyTree for future log likelihood calculations. Note: Values still need to be stored for any log likelihood calculations so this method does not save memory over ``eigendecomp_use_stored_values``. It may however reduce runtimes by avoiding checking if matrices have changed so it could be beneficial if all hyperparameters are being varied simultaneously for each calculation. Args: hp (Pytree): Hyperparameters needed to build the covariance matrices ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function parameters are also included. x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l`` different wavelength/vertical regression variables. x_t (JAXArray): Array containing time/horizontal dimension regression variable(s) for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t`` different time/horizontal regression variables. stored_values (PyTree): This may contain stored values from the decomposition of ``K`` but this method will not make use of it. This dictionary will simply be overwritten with new stored values from the decomposition of ``K``. 
Returns: PyTree: Stored values from the decomposition of the covariance matrices. For :class:`LuasKernel` this consists of values computed using the eigendecomposition of each matrix and also the log determinant of ``K``. """ # Calculate each of the four component matrices and decompose them into the required matrices for log likelihood calculations stored_values["Sl"] = self.Sl(hp, x_l, x_l) stored_values["lam_Sl"], stored_values["Sl_inv_sqrt"] = self.Sl_decomp_fn(stored_values["Sl"]) stored_values["St"] = self.St(hp, x_t, x_t) stored_values["lam_St"], stored_values["St_inv_sqrt"] = self.St_decomp_fn(stored_values["St"]) # See decomp_K_tilde for how W_l and the eigenvalues of Kl_tilde are calculated stored_values["Kl"] = self.Kl(hp, x_l, x_l) stored_values["lam_Kl_tilde"], stored_values["W_l"] = self.Kl_tilde_decomp_fn(stored_values["Kl"], stored_values["Sl_inv_sqrt"]) stored_values["Kt"] = self.Kt(hp, x_t, x_t) stored_values["lam_Kt_tilde"], stored_values["W_t"] = self.Kt_tilde_decomp_fn(stored_values["Kt"], stored_values["St_inv_sqrt"]) # D is needed for calculation the log determinant of K D = jnp.outer(stored_values["lam_Kl_tilde"], stored_values["lam_Kt_tilde"]) + 1. # D^-1 is needed for calculating K^-1 r stored_values["D_inv"] = jnp.reciprocal(D) # Computes the log determinant of K lam_S = jnp.outer(stored_values["lam_Sl"], stored_values["lam_St"]) stored_values["logdetK"] = jnp.log(jnp.multiply(D, lam_S)).sum() return stored_values
[docs] def eigendecomp_use_stored_values( self, hp: PyTree, x_l: JAXArray, x_t: JAXArray, stored_values: Optional[PyTree] = {}, rtol: Optional[Scalar] = 1e-12, atol: Optional[Scalar] = 1e-12, ) -> PyTree: r"""Required calculations for the decomposition of the overall matrix ``K`` where the previously stored decomposition of ``K`` may be used for the calculation of a new decomposition. This checking if any of the matrices have changed and if they are similar within the given tolerances a previously computed eigendecomposition can be used to avoid recalculating it. This can provide significant runtime savings if some hyperparameters are being kept fixed including if blocked Gibbs sampling is being used on groups of hyperparameters. We can decompose the inverse of ``K`` into the matrices: .. math:: K^{-1} = [W_l \otimes W_t] D^{-1} [W_l^T \otimes W_t^T] Where this function will calculate ``W_l``, ``W_t`` and ``D_inv`` and stored them in the stored_values PyTree for future log likelihood calculations. Args: hp (Pytree): Hyperparameters needed to build the covariance matrices ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function parameters are also included. x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l`` different wavelength/vertical regression variables. x_t (JAXArray): Array containing time/horizontal dimension regression variable(s) for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t`` different time/horizontal regression variables. stored_values (PyTree): Stored values from the decomposition of the covariance matrices. For :class:`LuasKernel` this consists of values computed using the eigendecomposition of each matrix and also the log determinant of ``K``. 
rtol (Scalar): The relative tolerance value any of the component covariance matrices must be within in order for the matrix to be considered unchanged and stored values for its decomposition to be used. atol (Scalar): The absolute tolerance values any of the component covariance matrices must be within in order for the matrix to be considered unchanged and stored values for its decomposition to be used. Returns: PyTree: Stored values from the decomposition of the covariance matrices. For :class:`LuasKernel` this consists of values computed using the eigendecomposition of each matrix and also the log determinant of ``K``. """ stored_values = deepcopy(stored_values) # Calculate each of the four component matrices Sl = self.Sl(hp, x_l, x_l) St = self.St(hp, x_t, x_t) Kl = self.Kl(hp, x_l, x_l) Kt = self.Kt(hp, x_t, x_t) if stored_values: # Check if any of the 4 component matrices have changed from their values in stored_values # Note JAX requires the two possible outputs of the conditional to be functions # so we use functions which just return True or False Sl_diff = jax.lax.cond(jnp.allclose(Sl, stored_values["Sl"], rtol = rtol, atol = atol), lambda : False, lambda : True) St_diff = jax.lax.cond(jnp.allclose(St, stored_values["St"], rtol = rtol, atol = atol), lambda : False, lambda : True) # Note that if Sl is different than Kl_tilde is also almost certainly different # so even if Kl hasn't changed we still need to recompute the decomposition of Kl_tilde and similarly for Kt Kl_tilde_diff = jax.lax.cond(jnp.allclose(Kl, stored_values["Kl"], rtol = rtol, atol = atol), lambda : Sl_diff, lambda : True) Kt_tilde_diff = jax.lax.cond(jnp.allclose(Kt, stored_values["Kt"], rtol = rtol, atol = atol), lambda : St_diff, lambda : True) else: Sl_diff = St_diff = Kl_tilde_diff = Kt_tilde_diff = True N_l = x_l.shape[-1] N_t = x_t.shape[-1] # JAX requires that the two outputs of any conditional statements have the same shape # so must define matrices of same shape as their actual 
values even though they will be overwritten stored_values["lam_Sl"] = jnp.zeros(N_l) stored_values["Sl_inv_sqrt"] = jnp.zeros((N_l, N_l)) stored_values["lam_St"] = jnp.zeros(N_t) stored_values["St_inv_sqrt"] = jnp.zeros((N_t, N_t)) stored_values["lam_Kl_tilde"] = jnp.zeros(N_l) stored_values["W_l"] = jnp.zeros((N_l, N_l)) stored_values["lam_Kt_tilde"] = jnp.zeros(N_t) stored_values["W_t"] = jnp.zeros((N_t, N_t)) # For each of the 4 component matrices conditionally decompose them if they have changed since the last calculation stored_values["lam_Sl"], stored_values["Sl_inv_sqrt"] = jax.lax.cond(Sl_diff, self.Sl_decomp_fn, lambda *args: (stored_values["lam_Sl"], stored_values["Sl_inv_sqrt"]), Sl) stored_values["lam_St"], stored_values["St_inv_sqrt"] = jax.lax.cond(St_diff, self.St_decomp_fn, lambda *args: (stored_values["lam_St"], stored_values["St_inv_sqrt"]), St) stored_values["lam_Kl_tilde"], stored_values["W_l"] = jax.lax.cond(Kl_tilde_diff, self.Kl_tilde_decomp_fn, lambda *args: (stored_values["lam_Kl_tilde"], stored_values["W_l"]), Kl, stored_values["Sl_inv_sqrt"]) stored_values["lam_Kt_tilde"], stored_values["W_t"] = jax.lax.cond(Kt_tilde_diff, self.Kt_tilde_decomp_fn, lambda *args: (stored_values["lam_Kt_tilde"], stored_values["W_t"]), Kt, stored_values["St_inv_sqrt"]) # D is needed for calculation the log determinant of K D = jnp.outer(stored_values["lam_Kl_tilde"], stored_values["lam_Kt_tilde"]) + 1. # D^-1 is needed for calculating K^-1 r stored_values["D_inv"] = jnp.reciprocal(D) # Computes the log determinant of K lam_S = jnp.outer(stored_values["lam_Sl"], stored_values["lam_St"]) stored_values["logdetK"] = jnp.log(jnp.multiply(D, lam_S)).sum() # Store in order to perform checks for the next call of this function stored_values["Sl"] = Sl stored_values["St"] = St stored_values["Kl"] = Kl stored_values["Kt"] = Kt return stored_values
[docs] def logL( self, hp: PyTree, x_l: JAXArray, x_t: JAXArray, R: JAXArray, stored_values: PyTree, ) -> Tuple[Scalar, PyTree]: """Computes the log likelihood using the method originally presented in Rakitsch et al. (2013) and also outlined in Fortune at al. (2024). Also returns stored values from the matrix decomposition. Note: Calculating the hessian of this function with ``jax.hessian`` may not produce numerically stable results. ``LuasKernel.logL_hessianable`` is recommended is values of the hessian are needed. This method typically outperforms ``LuasKernel.logL_hessianable`` in runtime for gradient calculations however. Args: hp (Pytree): Hyperparameters needed to build the covariance matrices ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function parameters are also included. x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l`` different wavelength/vertical regression variables. x_t (JAXArray): Array containing time/horizontal dimension regression variable(s) for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t`` different time/horizontal regression variables. R (JAXArray): Residuals to be fit calculated from the observed data by subtracting the deterministic mean function. Must have the same shape as the observed data (N_l, N_t). stored_values (PyTree): Stored values from the decomposition of the covariance matrices. For :class:`LuasKernel` this consists of values computed using the eigendecomposition of each matrix and also the log determinant of ``K``. Returns: (Scalar, PyTree): A tuple where the first element is the value of the log likelihood. The second element is a PyTree which contains stored values from the decomposition of the covariance matrix. 
""" # Calculate the decomposition of K stored_values = self.decomp_fn(hp, x_l, x_t, stored_values = stored_values) # Use functions with custom derivatives to accurately calculate the log # likelihood and its gradient rKr = r_K_inv_r(R, stored_values) logdetK = logdetK_calc(stored_values) logL = - 0.5 * logdetK - 0.5 * R.size * jnp.log(2*jnp.pi) -0.5 * rKr return logL, stored_values
[docs] def logL_hessianable( self, hp: PyTree, x_l: JAXArray, x_t: JAXArray, R: JAXArray, stored_values: PyTree, ) -> Tuple[Scalar, PyTree]: """Computes the log likelihood using the method originally presented in Rakitsch et al. (2013) and also outlined in Fortune at al. (2024). Note: The hessian of this log likelihood function can be calculated using ``jax.hessian`` and should be more numerically stable for this than ``LuasKernel.logL``. However, this function is slower for calculating the gradients of the log likelihood so ``LuasKernel.logL`` is preferred unless the hessian is needed. Also returns stored values from the matrix decomposition. Args: hp (Pytree): Hyperparameters needed to build the covariance matrices ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function parameters are also included. x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l`` different wavelength/vertical regression variables. x_t (JAXArray): Array containing time/horizontal dimension regression variable(s) for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t`` different time/horizontal regression variables. R (JAXArray): Residuals to be fit calculated from the observed data by subtracting the deterministic mean function. Must have the same shape as the observed data (N_l, N_t). stored_values (PyTree): Stored values from the decomposition of the covariance matrices. For :class:`LuasKernel` this consists of values computed using the eigendecomposition of each matrix and also the log determinant of ``K``. Returns: (Scalar, PyTree): A tuple where the first element is the value of the log likelihood. The second element is a PyTree which contains stored values from the decomposition of the covariance matrix. 
""" # Calculate the decomposition of K stored_values = self.decomp_fn(hp, x_l, x_t, stored_values = stored_values) # Use functions with custom derivatives to accurately calculate the log # likelihood, its gradient and hessian rKr = r_K_inv_r(R, stored_values) logdetK = logdetK_calc_hessianable(stored_values) logL = -0.5 * rKr - 0.5 * logdetK - 0.5 * R.size * jnp.log(2*jnp.pi) return logL, stored_values
[docs] def predict( self, hp: PyTree, x_l: JAXArray, x_l_pred: JAXArray, x_t: JAXArray, x_t_pred: JAXArray, R: JAXArray, M_s: JAXArray, wn = True, return_std_dev = True, ) -> Tuple[JAXArray, JAXArray]: r"""Performs GP regression and computes the GP predictive mean and the GP predictive uncertainty as the standard devation at each location or else can return the full covariance matrix. Requires the input kernel function ``K`` to have a ``wn`` keyword argument that defines the kernel when white noise is included (``wn = True``) and when white noise isn't included (``wn = False``). Currently assumes the same input hyperparameters for both the observed and predicted locations. The predicted locations ``x_l_pred`` and ``x_t_pred`` may deviate from the observed locations ``x_l`` and ``x_t`` however. The GP predictive mean is defined as: .. math:: \mathbb{E}[\vec{y}_*] = \vec{\mu}_* + \mathbf{K}_*^T \mathbf{K}^{-1} \vec{r} And the GP predictive covariance is given by: .. math:: Var[\vec{y}_*] = \mathbf{K}_{**} - \mathbf{K}_*^T \mathbf{K}^{-1} \mathbf{K}_* Note: The calculation of the full predictive covariance matrix when ``return_std_dev = False`` is still experimental and may come with numerically stability issues. It is also very memory intensive and may cause code to crash. Future updates to luas may improve this. Args: hp (Pytree): Hyperparameters needed to build the covariance matrices ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function parameters are also included. x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l`` different wavelength/vertical regression variables. x_l_pred (JAXArray): Array containing wavelength/vertical dimension regression variable(s) for the prediction locations (which may be the same as the observed locations). 
May be of shape ``(N_l_pred,)`` or ``(d_l,N_l_pred)`` for ``d_l`` different wavelength/vertical regression variables. x_t (JAXArray): Array containing time/horizontal dimension regression variable(s) for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t`` different time/horizontal regression variables. x_t_pred (JAXArray): Array containing time/horizontal dimension regression variable(s) for the prediction locations (which may be the same as the observed locations). May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t`` different time/horizontal regression variables. R (JAXArray): Residuals to be fit, equal to the observed data minus the deterministic mean function. Must have the same shape as the observed data ``(N_l, N_t)``. M_s (JAXArray): Mean function evaluated at the locations of the predictions ``x_l_pred``, ``x_t_pred``. Must have shape ``(N_l_pred, N_t_pred)`` where ``N_l_pred`` is the number of wavelength/vertical dimension predictions and ``N_t_pred`` the number of time/horizontal dimension predictions. wn (bool, optional): Whether to include white noise in the uncertainty at the predicted locations. Defaults to True. return_std_dev (bool, optional): If ``True`` will return the standard deviation of uncertainty at the predicted locations. Otherwise will return the full predictive covariance matrix. Defaults to True. Returns: (JAXArray, JAXArray): Returns a tuple of two elements, where the first element is the GP predictive mean at the prediction locations, the second element is either the standard deviation of the predictions if ``return_std_dev = True``, otherwise it will be the full covariance matrix of the predicted values. 
""" # Calculate the decomposition of K stored_values = self.decomp_fn(hp, x_l, x_t) # Calculate the covariance between the observed and predicted points Kl_s = self.Kl(hp, x_l, x_l_pred, wn = False) Kt_s = self.Kt(hp, x_t, x_t_pred, wn = False) Sl_s = self.Sl(hp, x_l, x_l_pred, wn = False) St_s = self.St(hp, x_t, x_t_pred, wn = False) # Calculate the covariance between predicted points with other predicted points Kl_ss = self.Kl(hp, x_l_pred, x_l_pred, wn = wn) Kt_ss = self.Kt(hp, x_t_pred, x_t_pred, wn = wn) Sl_ss = self.Sl(hp, x_l_pred, x_l_pred, wn = wn) St_ss = self.St(hp, x_t_pred, x_t_pred, wn = wn) # Calculate K^-1 R K_inv_R = K_inv_vec(R, stored_values) # Calculates the GP mean including the deterministic mean function at the prediction locations gp_mean = M_s + kron_prod(Kl_s.T, Kt_s.T, K_inv_R) + kron_prod(Sl_s.T, St_s.T, K_inv_R) # Prepare matrices for calculating the predictive covariance KW_l = Kl_s.T @ stored_values["W_l"] KW_t = Kt_s.T @ stored_values["W_t"] SW_l = Sl_s.T @ stored_values["W_l"] SW_t = St_s.T @ stored_values["W_t"] if return_std_dev: # Efficiently solves for the diagonal of the predictive covariance pred_err = jnp.outer(jnp.diag(Kl_ss), jnp.diag(Kt_ss)) pred_err += jnp.outer(jnp.diag(Sl_ss), jnp.diag(St_ss)) # K_s.T K^-1 K_s term can be broken into these three terms pred_err -= kron_prod(KW_l**2, KW_t**2, stored_values["D_inv"]) pred_err -= kron_prod(SW_l**2, SW_t**2, stored_values["D_inv"]) pred_err -= 2*kron_prod(KW_l * SW_l, KW_t * SW_t, stored_values["D_inv"]) # Take the sqrt of the diagonal to get the std dev pred_err = jnp.sqrt(pred_err) else: # Note very memory intensive! K_W = jnp.kron(KW_l, KW_t) + jnp.kron(SW_l, SW_t) pred_err = -K_W @ jnp.diag(stored_values["D_inv"].ravel()) @ K_W.T # Add the K_ss term pred_err += jnp.kron(Kl_ss, Kt_ss) + jnp.kron(Sl_ss, St_ss) return gp_mean, pred_err
[docs] def generate_noise( self, hp: PyTree, x_l: JAXArray, x_t: JAXArray, size: Optional[int] = 1, wn: Optional[bool] = True, ) -> JAXArray: r"""Generate noise with the covariance matrix returned by this kernel using the input hyperparameters ``hp``. Solves for the matrix square root of K and then multiplies this by a random normal vector. Doing it this way has numerical stability advantages over generating noise separately for each of the two kronecker products of K as they might not both be well-conditioned matrices. Args: hp (Pytree): Hyperparameters needed to build the covariance matrices ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function parameters are also included. x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l`` different wavelength/vertical regression variables. x_t (JAXArray): Array containing time/horizontal dimension regression variable(s) for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t`` different time/horizontal regression variables. size (int, optional): The number of different draws of noise to generate. Defaults to 1. wn (bool, optional): Whether to include white noise when generating noise. Must have a `wn` keyword argument in all kernel functions ``Kl``, ``Kt``, ``Sl``, ``St``. Returns: JAXArray: If ``size = 1`` will generate noise of shape ``(N_l, N_t)``, otherwise if ``size > 1`` then generated noise will be of shape ``(N_l, N_t, size)``. 
""" N_l = x_l.shape[-1] N_t = x_t.shape[-1] # Solve for the matrix sqrt and matrix inv sqrt for Sl and St Sl = self.Sl(hp, x_l, x_l, wn = wn) lam_Sl, Q_Sl = self.Sl.decomp(Sl) Sl_sqrt = Q_Sl @ jnp.diag(jnp.sqrt(lam_Sl)) Sl_inv_sqrt = Q_Sl @ jnp.diag(jnp.sqrt(jnp.reciprocal(lam_Sl))) St = self.St(hp, x_t, x_t) lam_St, Q_St = self.St.decomp(St) St_sqrt = Q_St @ jnp.diag(jnp.sqrt(lam_St)) St_inv_sqrt = Q_St @ jnp.diag(jnp.sqrt(jnp.reciprocal(lam_St))) # Solve for the eigenvalues and eigenvectors of Kl_tilde, Kt_tilde Kl = self.Kl(hp, x_l, x_l) Kl_tilde = Sl_inv_sqrt.T @ Kl @ Sl_inv_sqrt lam_Kl_tilde, Q_Kl_tilde = self.Kl.decomp(Kl_tilde) Kt = self.Kt(hp, x_t, x_t) Kt_tilde = St_inv_sqrt.T @ Kt @ St_inv_sqrt lam_Kt_tilde, Q_Kt_tilde = self.Kt.decomp(Kt_tilde) # Computes the sqrt of the diagonal matrix D D_half = jnp.sqrt(jnp.outer(lam_Kl_tilde, lam_Kt_tilde) + 1.) D_half = D_half.reshape((N_l, N_t, 1)) # vmap kron_prod so that it will work for z of shape (N_l, N_t, size) kron_prod_vmap = jax.vmap(kron_prod, in_axes = (None, None, 2), out_axes = 2) # Generate random normal vector z = np.random.normal(size = (N_l, N_t, size)) # Multiply by the matrix sqrt of K z = jnp.multiply(D_half, z) z = kron_prod_vmap(Q_Kl_tilde, Q_Kt_tilde, z) R = kron_prod_vmap(Sl_sqrt, St_sqrt, z) # If size = 1 then return as shape (N_l, N_t) instead of (N_l, N_t, 1) if size == 1: R = R.reshape((N_l, N_t)) return R
[docs] def K( self, hp: PyTree, x_l1: JAXArray, x_l2: JAXArray, x_t1: JAXArray, x_t2: JAXArray, **kwargs, ) -> JAXArray: r"""Generates the full covariance matrix K formed from the sum of two kronecker products: .. math:: K = K_l \otimes K_t + S_l \otimes S_t Not needed for any calculations with the ``LuasKernel`` but useful for creating a :class:`GeneralKernel` object with the same kernel function as a :class:`LuasKernel`. Args: hp (Pytree): Hyperparameters needed to build the covariance matrices ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function parameters are also included. x_l1 (JAXArray): The first array containing wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l`` different wavelength/vertical regression variables. x_l2 (JAXArray): Second array containing wavelength/vertical dimension regression variable(s). x_t1 (JAXArray): The first array containing time/horizontal dimension regression variable(s) for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t`` different time/horizontal regression variables. x_t2 (JAXArray): Second array containing time/horizontal dimension regression variable(s). Returns: JAXArray: The full covariance matrix K of shape ``(N_l*N_t, N_l*N_t)``. """ # Build 4 component matrices Kl = self.Kl(hp, x_l1, x_l2, **kwargs) Kt = self.Kt(hp, x_t1, x_t2, **kwargs) Sl = self.Sl(hp, x_l1, x_l2, **kwargs) St = self.St(hp, x_t1, x_t2, **kwargs) K = jnp.kron(Kl, Kt) + jnp.kron(Sl, St) return K
def visualise_covariance_in_data(
    self,
    hp: PyTree,
    x_l: JAXArray,
    x_t: JAXArray,
    i: int,
    j: int,
    corr: Optional[bool] = False,
    wn: Optional[bool] = True,
    x_l_plot: Optional[JAXArray] = None,
    x_t_plot: Optional[JAXArray] = None,
    **kwargs,
) -> plt.Figure:
    r"""Creates a plot to aid in visualising how the kernel function is defining the
    covariance between different points in the observed data.

    Calculates the covariance of each point in the observed data with a point located at ``(i, j)``
    in the observed data. The plot then displays this covariance using ``plt.pcolormesh``
    with every other point in the observed data. If ``corr = True`` this will display the
    correlation instead of the covariance. Also if ``wn = False`` then white noise will be
    excluded from the calculation of the covariance/correlation between each point.
    This can be helpful if the white noise has a much larger amplitude than correlated noise
    which can make it difficult to visualise how points are correlated.

    Args:
        hp (Pytree): Hyperparameters needed to build the covariance matrices
            ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function
            parameters are also included.
        x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s)
            for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l``
            different wavelength/vertical regression variables.
        x_t (JAXArray): Array containing time/horizontal dimension regression variable(s)
            for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t``
            different time/horizontal regression variables.
        i (int): The wavelength/vertical location of the point to visualise covariance with.
        j (int): The time/horizontal location of the point to visualise covariance with.
        corr (bool, optional): If ``True`` will plot the correlation between points instead
            of the covariance. Defaults to ``False``.
        wn (bool, optional): Whether to include white noise in the calculation of covariance.
            Passed as the ``wn`` keyword to each kernel function. Defaults to ``True``.
        x_l_plot (JAXArray, optional): The values on the y-axis used by ``plt.pcolormesh`` for the plot.
            If not included will default to ``x_l`` if ``x_l`` is of shape ``(N_l,)``
            or to ``x_l[0, :]`` if ``x_l`` is of shape ``(d_l, N_l)``.
        x_t_plot (JAXArray, optional): The values on the x-axis used by ``plt.pcolormesh`` for the plot.
            If not included will default to ``x_t`` if ``x_t`` is of shape ``(N_t,)``
            or to ``x_t[0, :]`` if ``x_t`` is of shape ``(d_t, N_t)``.
        **kwargs: Additional keyword arguments forwarded to ``plt.pcolormesh``.

    Returns:
        plt.Figure: The plot displaying the covariance of each point in the observed data with
        the selected point located at ``(i, j)`` in the observed data ``Y``.
        NOTE(review): the object actually returned is whatever ``plt.pcolormesh`` returns
        (kept for backward compatibility), not the enclosing figure.

    """

    # If no x and y axes for the plots specified, defaults to x_l, x_t
    # If x_l or x_t contain multiple rows then pick the first row
    if x_l_plot is None:
        x_l_plot = x_l if x_l.ndim == 1 else x_l[0, :]

    if x_t_plot is None:
        x_t_plot = x_t if x_t.ndim == 1 else x_t[0, :]

    # Build the 4 component matrices at the observed locations.
    # The previous code referenced x_l1/x_l2/x_t1/x_t2, which do not exist in this
    # method and raised a NameError; the observed locations are used for both inputs.
    Kl = self.Kl(hp, x_l, x_l, wn = wn)
    Kt = self.Kt(hp, x_t, x_t, wn = wn)
    Sl = self.Sl(hp, x_l, x_l, wn = wn)
    St = self.St(hp, x_t, x_t, wn = wn)

    # Covariance of every point with the point at (i, j): only row i of the
    # wavelength matrices and row j of the time matrices are needed, which gives
    # an array with the same shape as the observed data Y
    cov = jnp.outer(Kl[i, :], Kt[j, :]) + jnp.outer(Sl[i, :], St[j, :])

    if corr:
        # Variance of every point, taken from the diagonals of the component matrices
        variances = (
            jnp.outer(jnp.diag(Kl), jnp.diag(Kt))
            + jnp.outer(jnp.diag(Sl), jnp.diag(St))
        )

        # Correlation = covariance / sqrt(variance at (i, j) * variance at each point)
        cov = cov / jnp.sqrt(cov[i, j] * variances)

    # Generate plot as a pcolormesh
    fig = plt.pcolormesh(x_t_plot, x_l_plot, cov, **kwargs)
    plt.gca().invert_yaxis()
    plt.xlabel(r"$x_t$")
    plt.ylabel(r"$x_l$")
    plt.colorbar()

    return fig
def visualise_covariance_matrix(
    self,
    hp: PyTree,
    x_l: JAXArray,
    x_t: JAXArray,
    corr: Optional[bool] = False,
    wn: Optional[bool] = True,
    x_l_plot: Optional[JAXArray] = None,
    x_t_plot: Optional[JAXArray] = None,
    full: Optional[bool] = False,
) -> plt.Figure:
    r"""Visualise the covariance matrix/matrices generated by the input hyperparameters.

    Note:
        Default behaviour is to separately visualise each of the 4 component covariance matrices
        ``Kl``, ``Kt``, ``Sl``, ``St`` which are used to calculate the full covariance matrix ``K``.
        If ``full = True`` then will instead build the full covariance matrix ``K`` but this is very
        memory intensive as it requires creating a JAXArray with ``(N_l*N_t, N_l*N_t)`` entries.

    Args:
        hp (Pytree): Hyperparameters needed to build the covariance matrices
            ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function
            parameters are also included.
        x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s)
            for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l``
            different wavelength/vertical regression variables.
        x_t (JAXArray): Array containing time/horizontal dimension regression variable(s)
            for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t``
            different time/horizontal regression variables.
        corr (bool, optional): If ``True`` will plot the correlation between points instead
            of the covariance. Defaults to ``False``.
        wn (bool, optional): Whether to include white noise in the calculation of covariance.
            Passed as the ``wn`` keyword to each kernel function. Defaults to ``True``.
        x_l_plot (JAXArray, optional): The axis values used by ``plt.pcolormesh`` for the ``Kl``
            and ``Sl`` panels. If not included will default to ``x_l`` if ``x_l`` is of shape
            ``(N_l,)`` or to ``x_l[0, :]`` if ``x_l`` is of shape ``(d_l, N_l)``.
            Not used when ``full = True``.
        x_t_plot (JAXArray, optional): The axis values used by ``plt.pcolormesh`` for the ``Kt``
            and ``St`` panels. If not included will default to ``x_t`` if ``x_t`` is of shape
            ``(N_t,)`` or to ``x_t[0, :]`` if ``x_t`` is of shape ``(d_t, N_t)``.
            Not used when ``full = True``.
        full (bool, optional): If ``True`` will build and visualise the full constructed
            covariance matrix ``K`` instead of the 4 component matrices. Defaults to ``False``.

    Returns:
        plt.Figure: When ``full = False``, the figure holding a 2x2 grid of panels showing
        ``Kl``, ``Kt``, ``Sl`` and ``St``. When ``full = True``, the image object returned by
        ``imshow`` for the full matrix ``K`` is returned instead (kept for backward
        compatibility).

    """

    # If no x and y axes for the plots specified, defaults to x_l, x_t
    # If x_l or x_t contain multiple rows then pick the first row
    if x_l_plot is None:
        x_l_plot = x_l if x_l.ndim == 1 else x_l[0, :]

    if x_t_plot is None:
        x_t_plot = x_t if x_t.ndim == 1 else x_t[0, :]

    # Build each component matrix
    Kl = self.Kl(hp, x_l, x_l, wn = wn)
    Kt = self.Kt(hp, x_t, x_t, wn = wn)
    Sl = self.Sl(hp, x_l, x_l, wn = wn)
    St = self.St(hp, x_t, x_t, wn = wn)

    if full:
        # Plot full covariance matrix K
        # Warning: Can be very memory intensive as it builds an array with (N_l*N_t)**2 entries
        K = jnp.kron(Kl, Kt) + jnp.kron(Sl, St)

        if corr:
            K = get_corr_mat(K)

        # Use a dedicated single-panel figure. Previously a 2x2 grid was always created
        # and the full matrix was drawn into one panel, leaving three empty axes.
        _, full_ax = plt.subplots(figsize = (10, 10))
        fig = full_ax.imshow(K)

    else:
        # Individually plot each of the 4 component covariance matrices
        if corr:
            # Convert to correlation matrices if corr = True
            Kl = get_corr_mat(Kl)
            Kt = get_corr_mat(Kt)
            Sl = get_corr_mat(Sl)
            St = get_corr_mat(St)

        fig, ax = plt.subplots(2, 2, figsize = (10, 10))

        # (axes, title, axis label, axis values, matrix) for each of the 4 panels
        panels = [
            (ax[0][0], r"K$_l$", r"$x_l$", x_l_plot, Kl),
            (ax[0][1], r"K$_t$", r"$x_t$", x_t_plot, Kt),
            (ax[1][0], r"$\Sigma_l$", r"$x_l$", x_l_plot, Sl),
            (ax[1][1], r"$\Sigma_t$", r"$x_t$", x_t_plot, St),
        ]

        for panel_ax, title, label, axis_vals, mat in panels:
            panel_ax.set_title(title)
            panel_ax.set_ylabel(label)
            panel_ax.set_xlabel(label)
            img = panel_ax.pcolormesh(axis_vals, axis_vals, mat)

            # Flip the y-axis so the first row of the matrix is drawn at the top
            panel_ax.invert_yaxis()
            plt.colorbar(mappable = img, ax = panel_ax)

    plt.tight_layout()

    return fig
def K_inv_by_vec(
    self,
    hp: PyTree,
    x_l: JAXArray,
    x_t: JAXArray,
    R: JAXArray,
) -> JAXArray:
    r"""Multiplies the inverse of the covariance matrix ``K`` by a vector ``R``, where the
    vector is stored as a JAXArray of shape ``(N_l, N_t)``. Useful when checking numerical
    stability.

    Args:
        hp (Pytree): Hyperparameters needed to build the covariance matrices
            ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function
            parameters are also included.
        x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s)
            for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l``
            different wavelength/vertical regression variables.
        x_t (JAXArray): Array containing time/horizontal dimension regression variable(s)
            for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t``
            different time/horizontal regression variables.
        R (JAXArray): JAXArray of shape ``(N_l, N_t)`` representing the vector which the
            inverse of the covariance matrix ``K`` is applied to.

    Returns:
        JAXArray: The product of the inverse of the covariance matrix ``K`` with the vector ``R``.

    """

    # Decompose K using this kernel's decomposition function. An empty dict is passed
    # as stored_values (presumably so no previously stored decomposition is reused).
    decomposition = self.decomp_fn(hp, x_l, x_t, stored_values = {})

    # Apply K^-1 to R using the decomposition
    K_inv_R = K_inv_vec(R, decomposition)
    return K_inv_R
def K_by_vec(
    self,
    hp: PyTree,
    x_l: JAXArray,
    x_t: JAXArray,
    R: JAXArray,
) -> JAXArray:
    r"""Multiplies the covariance matrix ``K`` by a vector ``R``, where the vector is stored
    as a JAXArray of shape ``(N_l, N_t)``. Useful when checking numerical stability.

    The four component matrices are built at the observed locations and the result is the
    sum of one ``kron_prod`` call for the ``Kl``, ``Kt`` pair and one for the ``Sl``, ``St``
    pair, so the full matrix ``K`` is not built in this method.

    Args:
        hp (Pytree): Hyperparameters needed to build the covariance matrices
            ``Kl``, ``Kt``, ``Sl``, ``St``. Will be unaffected if additional mean function
            parameters are also included.
        x_l (JAXArray): Array containing wavelength/vertical dimension regression variable(s)
            for the observed locations. May be of shape ``(N_l,)`` or ``(d_l,N_l)`` for ``d_l``
            different wavelength/vertical regression variables.
        x_t (JAXArray): Array containing time/horizontal dimension regression variable(s)
            for the observed locations. May be of shape ``(N_t,)`` or ``(d_t,N_t)`` for ``d_t``
            different time/horizontal regression variables.
        R (JAXArray): JAXArray of shape ``(N_l, N_t)`` representing the vector which the
            covariance matrix ``K`` is applied to.

    Returns:
        JAXArray: The product of the covariance matrix ``K`` with the vector ``R``.

    """

    # Contribution of the Kl, Kt pair
    K_term = kron_prod(self.Kl(hp, x_l, x_l), self.Kt(hp, x_t, x_t), R)

    # Contribution of the Sl, St pair
    S_term = kron_prod(self.Sl(hp, x_l, x_l), self.St(hp, x_t, x_t), R)

    return K_term + S_term