LuasKernel.predict


LuasKernel.predict(hp: Any, x_l: Array, x_l_pred: Array, x_t: Array, x_t_pred: Array, R: Array, M_s: Array, wn=True, return_std_dev=True) → Tuple[Array, Array]

Performs GP regression and computes the GP predictive mean and the GP predictive uncertainty at each prediction location. The uncertainty is returned either as the standard deviation at each location or as the full predictive covariance matrix. Requires the input kernel function K to have a wn keyword argument that defines the kernel when white noise is included (wn = True) and when white noise isn’t included (wn = False).

Currently assumes the same input hyperparameters for both the observed and predicted locations. However, the prediction locations x_l_pred and x_t_pred may differ from the observed locations x_l and x_t.
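As an illustration of the wn requirement, here is a minimal sketch of a kernel function with a wn keyword. It is a hypothetical one-dimensional squared-exponential kernel written with NumPy, not a kernel shipped with luas, and the hyperparameter names "h", "l" and "sigma" are assumptions. White noise is only added when wn = True, which is only meaningful for the covariance of a set of locations with itself; the same hp is reused for observed and predicted locations, in line with the assumption stated above.

```python
import numpy as np

def rbf_kernel(hp, x1, x2, wn=True):
    """Hypothetical 1D squared-exponential kernel with a `wn` switch.

    hp is assumed to be a dict with illustrative keys "h" (amplitude),
    "l" (length scale) and "sigma" (white-noise standard deviation).
    """
    # Pairwise squared distances between the two sets of locations
    d2 = (x1[:, None] - x2[None, :]) ** 2
    K = hp["h"] ** 2 * np.exp(-0.5 * d2 / hp["l"] ** 2)
    if wn:
        # White noise belongs on the diagonal of K(x, x); this only checks
        # shapes, so the caller must pass the same locations for x1 and x2
        assert x1.shape == x2.shape, "wn=True expects K(x, x)"
        K = K + hp["sigma"] ** 2 * np.eye(x1.shape[0])
    return K

# Same hyperparameters for observed and predicted locations
hp = {"h": 1.0, "l": 0.5, "sigma": 0.1}
x_obs = np.linspace(0.0, 1.0, 5)
x_pred = np.linspace(0.0, 1.0, 8)
K = rbf_kernel(hp, x_obs, x_obs, wn=True)      # observed covariance, with white noise
K_s = rbf_kernel(hp, x_obs, x_pred, wn=False)  # cross-covariance, no white noise
```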

The GP predictive mean is defined as:

\[\mathbb{E}[\vec{y}_*] = \vec{\mu}_* + \mathbf{K}_*^T \mathbf{K}^{-1} \vec{r}\]

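The GP predictive covariance is given by:

\[\mathrm{Var}[\vec{y}_*] = \mathbf{K}_{**} - \mathbf{K}_*^T \mathbf{K}^{-1} \mathbf{K}_*\]

Here \(\mathbf{K}\) is the covariance matrix of the observed locations, \(\mathbf{K}_*\) is the covariance between the observed and predicted locations, \(\mathbf{K}_{**}\) is the covariance matrix of the predicted locations, \(\vec{\mu}_*\) is the mean function at the predicted locations (M_s) and \(\vec{r}\) is the vector of residuals (R).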
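The two formulas can be checked on small problems with a dense reference implementation. The sketch below uses NumPy and a Cholesky factorisation rather than an explicit inverse. The function name dense_gp_predict is hypothetical, it assumes the residuals and mean function have already been flattened to vectors, and it does not reflect how luas evaluates these expressions internally.

```python
import numpy as np

def dense_gp_predict(K, K_s, K_ss, r, mu_s):
    """Hypothetical dense GP prediction following the mean and covariance formulas.

    K    : (N, N) covariance of the observed points
    K_s  : (N, N_s) covariance between observed and predicted points
    K_ss : (N_s, N_s) covariance of the predicted points
    r    : (N,) residuals, i.e. data minus mean function
    mu_s : (N_s,) mean function at the predicted points
    """
    # Factorise K = L L^T once instead of forming K^{-1} explicitly
    L = np.linalg.cholesky(K)
    # alpha = K^{-1} r via two solves against the triangular factor
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, r))
    mean = mu_s + K_s.T @ alpha
    # With V = L^{-1} K_*, the term K_*^T K^{-1} K_* equals V^T V
    V = np.linalg.solve(L, K_s)
    cov = K_ss - V.T @ V
    return mean, cov
```

Working with the Cholesky factor avoids forming \(\mathbf{K}^{-1}\), which is generally less accurate numerically.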

Note

The calculation of the full predictive covariance matrix when return_std_dev = False is still experimental and may suffer from numerical stability issues. The full matrix has one row and one column per predicted point, so its memory use grows quadratically with the number of predicted points N_l_pred × N_t_pred and may cause code to crash. Future updates to luas may improve this.
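To give a sense of scale, the snippet below estimates the memory needed to hold the full predictive covariance, assuming it is stored as a single dense float64 array with one row and one column per predicted point. The helper name full_cov_bytes is hypothetical.

```python
def full_cov_bytes(N_l_pred, N_t_pred, itemsize=8):
    """Hypothetical helper: bytes needed for a dense predictive covariance.

    The matrix covers all N_l_pred * N_t_pred predicted points, so it has
    (N_l_pred * N_t_pred) ** 2 entries; itemsize=8 assumes float64.
    """
    n = N_l_pred * N_t_pred
    return n * n * itemsize

print(full_cov_bytes(100, 1000) / 1e9)  # → 80.0 (GB)
```

For example, 100 wavelength × 1000 time predictions would need about 80 GB for this one matrix under that assumption, whereas the standard deviations need only N_l_pred × N_t_pred values.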

Parameters:
  • hp (Pytree) – Hyperparameters needed to build the covariance matrices Kl, Kt, Sl, St. Additional entries, such as mean function parameters, may also be included and do not affect the result.

  • x_l (JAXArray) – Array containing wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape (N_l,) or (d_l,N_l) for d_l different wavelength/vertical regression variables.

  • x_l_pred (JAXArray) – Array containing wavelength/vertical dimension regression variable(s) for the prediction locations (which may be the same as the observed locations). May be of shape (N_l_pred,) or (d_l,N_l_pred) for d_l different wavelength/vertical regression variables.

  • x_t (JAXArray) – Array containing time/horizontal dimension regression variable(s) for the observed locations. May be of shape (N_t,) or (d_t,N_t) for d_t different time/horizontal regression variables.

  • x_t_pred (JAXArray) – Array containing time/horizontal dimension regression variable(s) for the prediction locations (which may be the same as the observed locations). May be of shape (N_t_pred,) or (d_t,N_t_pred) for d_t different time/horizontal regression variables.

  • R (JAXArray) – Residuals to be fit, equal to the observed data minus the deterministic mean function evaluated at the observed locations. Must have the same shape as the observed data (N_l, N_t).

  • M_s (JAXArray) – Mean function evaluated at the locations of the predictions x_l_pred, x_t_pred. Must have shape (N_l_pred, N_t_pred) where N_l_pred is the number of wavelength/vertical dimension predictions and N_t_pred the number of time/horizontal dimension predictions.

  • wn (bool, optional) – Whether to include white noise in the uncertainty at the predicted locations. Defaults to True.

  • return_std_dev (bool, optional) – If True, returns the predictive standard deviation at each predicted location. Otherwise returns the full predictive covariance matrix. Defaults to True.
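The shape conventions in the parameter list can be summarised as a small check. The helper check_predict_shapes is hypothetical and not part of luas; it treats the last axis of each x array as the number of locations and an optional leading axis as the d_l or d_t regression variables.

```python
import numpy as np

def check_predict_shapes(x_l, x_l_pred, x_t, x_t_pred, R, M_s):
    """Hypothetical helper checking the documented input shapes."""
    # The last axis counts locations; any leading axis is d_l or d_t
    N_l, N_l_pred = x_l.shape[-1], x_l_pred.shape[-1]
    N_t, N_t_pred = x_t.shape[-1], x_t_pred.shape[-1]
    # Observed and predicted inputs share the same regression variables
    assert x_l.shape[:-1] == x_l_pred.shape[:-1], "d_l mismatch"
    assert x_t.shape[:-1] == x_t_pred.shape[:-1], "d_t mismatch"
    assert R.shape == (N_l, N_t), f"R should be {(N_l, N_t)}, got {R.shape}"
    assert M_s.shape == (N_l_pred, N_t_pred), (
        f"M_s should be {(N_l_pred, N_t_pred)}, got {M_s.shape}"
    )
    return (N_l, N_t), (N_l_pred, N_t_pred)

# One wavelength variable, two time variables (d_t = 2)
x_l = np.linspace(0.0, 1.0, 20)                                         # (N_l,)
x_l_pred = np.linspace(0.0, 1.0, 40)                                    # (N_l_pred,)
x_t = np.stack([np.linspace(-1.0, 1.0, 50), np.linspace(0.0, 2.0, 50)])  # (d_t, N_t)
x_t_pred = np.stack([np.linspace(-1.0, 1.0, 60), np.linspace(0.0, 2.0, 60)])
R = np.zeros((20, 50))     # (N_l, N_t)
M_s = np.zeros((40, 60))   # (N_l_pred, N_t_pred)
print(check_predict_shapes(x_l, x_l_pred, x_t, x_t_pred, R, M_s))  # → ((20, 50), (40, 60))
```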

Returns:

A tuple of two elements. The first is the GP predictive mean at the prediction locations. The second is the standard deviation of the predictions if return_std_dev = True, or otherwise the full covariance matrix of the predicted values.

Return type:

(JAXArray, JAXArray)
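When return_std_dev = False, per-point standard deviations can still be recovered from the diagonal of the returned covariance matrix. The sketch below assumes, without confirmation from this page, that the covariance is indexed by predicted points flattened in row-major order, matching an (N_l_pred, N_t_pred) array such as M_s; std_from_cov is a hypothetical helper.

```python
import numpy as np

def std_from_cov(cov, N_l_pred, N_t_pred):
    """Hypothetical helper: per-point standard deviations from a full covariance.

    Assumes cov has shape (N_l_pred * N_t_pred, N_l_pred * N_t_pred) with
    points ordered like a C-order flatten of an (N_l_pred, N_t_pred) array.
    """
    # Marginal variances sit on the diagonal of the covariance matrix
    var = np.diag(cov)
    return np.sqrt(var).reshape(N_l_pred, N_t_pred)
```

If that ordering holds and the predictive mean is returned with the same (N_l_pred, N_t_pred) shape as M_s, a 2σ band is simply mean - 2 * std to mean + 2 * std.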