GP.predict


GP.predict(p: Any, Y: Array, x_l_pred: Array | None = None, x_t_pred: Array | None = None, return_std_dev: bool | None = True, **kwargs) → Tuple[Array, Array, Array]

Performs GP regression and computes the GP predictive mean together with the GP predictive uncertainty, returned either as the standard deviation at each prediction location or as the full covariance matrix. The input kernel function(s) must accept a wn keyword argument that selects whether white noise is included in the kernel (wn = True) or excluded (wn = False).
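
As an illustration of the wn requirement, here is a minimal sketch of a kernel function with such a keyword. The function name, its 1D signature, and the hyperparameter keys h, l and sigma are hypothetical choices for this example and are not taken from the library, whose kernels act on both the x_l and x_t dimensions.

```python
import jax.numpy as jnp

def squared_exp_kernel(p, x1, x2, wn=True):
    """Hypothetical 1D squared-exponential kernel with an optional white noise term."""
    # Pairwise squared distances between the two sets of input locations
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    K = p["h"] ** 2 * jnp.exp(-0.5 * sq_dist / p["l"] ** 2)
    if wn:
        # White noise is added on the diagonal, so x1 and x2 must be the same locations here
        K = K + p["sigma"] ** 2 * jnp.eye(x1.shape[0])
    return K

# Assumed hyperparameter values, chosen only for illustration
p = {"h": 1.0, "l": 0.5, "sigma": 0.1}
x = jnp.linspace(0.0, 1.0, 5)

K_obs = squared_exp_kernel(p, x, x, wn=True)    # covariance including white noise
K_pred = squared_exp_kernel(p, x, x, wn=False)  # same kernel with white noise excluded
```

The two matrices differ only on the diagonal, by sigma squared.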

Currently assumes the same input hyperparameters for both the observed and predicted locations. The prediction locations x_l_pred and x_t_pred may, however, differ from the observed locations x_l and x_t.

The GP predictive mean is defined as:

\[\mathbb{E}[\vec{y}_*] = \vec{\mu}_* + \mathbf{K}_*^T \mathbf{K}^{-1} \vec{r}\]

And the GP predictive covariance is given by:

\[\mathrm{Var}[\vec{y}_*] = \mathbf{K}_{**} - \mathbf{K}_*^T \mathbf{K}^{-1} \mathbf{K}_*\]
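
The two expressions above can be evaluated directly with a Cholesky factorisation. The following is a dense, 1D sketch and not the library's implementation: it assumes that r denotes the residuals (observed data minus the mean function at the observed locations), uses a zero mean function, and the helper names sq_exp and predict_dense are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def sq_exp(x1, x2, h=1.0, l=0.3):
    # Hypothetical squared-exponential kernel without white noise
    return h**2 * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / l**2)

def predict_dense(K, K_s, K_ss, r, mu_s):
    # K: (N, N) observed covariance, K_s: (N, M) cross-covariance,
    # K_ss: (M, M) prediction covariance, r: (N,) residuals, mu_s: (M,) mean function
    factor = cho_factor(K, lower=True)
    mean = mu_s + K_s.T @ cho_solve(factor, r)     # mu_* + K_*^T K^{-1} r
    cov = K_ss - K_s.T @ cho_solve(factor, K_s)    # K_** - K_*^T K^{-1} K_*
    # Standard deviation from the diagonal, guarding against tiny negative round-off
    std = jnp.sqrt(jnp.maximum(jnp.diag(cov), 0.0))
    return mean, cov, std

x = jnp.linspace(0.0, 1.0, 6)            # observed locations
x_pred = jnp.array([0.1, 0.5, 0.9])      # prediction locations
y = jnp.sin(2.0 * jnp.pi * x)            # synthetic observations
mu, mu_s = jnp.zeros_like(x), jnp.zeros_like(x_pred)  # zero mean function (assumption)
sigma = 0.1                              # assumed white noise amplitude

K = sq_exp(x, x) + sigma**2 * jnp.eye(x.shape[0])  # white noise included
K_s = sq_exp(x, x_pred)
K_ss = sq_exp(x_pred, x_pred)
mean, cov, std = predict_dense(K, K_s, K_ss, y - mu, mu_s)
```

Here std plays the role of the output returned when return_std_dev = True, and cov the role of the full covariance matrix returned otherwise.
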
Parameters:
  • p (PyTree) – Pytree of hyperparameters used to calculate the covariance matrix, together with any parameters needed to evaluate the mean function.

  • Y (JAXArray) – Observed data to fit. Must be of shape (N_l, N_t).

  • x_l_pred (JAXArray, optional) – Prediction locations along the row dimension, defaults to observed input locations.

  • x_t_pred (JAXArray, optional) – Prediction locations along the column dimension, defaults to observed input locations.

  • return_std_dev (bool, optional) – If True, returns the standard deviation of the uncertainty at the predicted locations. Otherwise, returns the full predictive covariance matrix. Defaults to True.

Returns:

Returns a tuple of three elements. The first element is the GP predictive mean at the prediction locations. The second element is the standard deviation of the predictions if return_std_dev = True, and the full covariance matrix of the predicted values otherwise. The third element is the mean function evaluated at the prediction locations.

Return type:

(JAXArray, JAXArray, JAXArray)