GeneralKernel.predict
- GeneralKernel.predict(hp: Any, x_l: Array, x_l_pred: Array, x_t: Array, x_t_pred: Array, R: Array, M_s: Array, wn: bool | None = True, return_std_dev: bool | None = True) Tuple[Array, Array, Array]
Performs GP regression and computes the GP predictive mean and the GP predictive uncertainty, either as the standard deviation at each prediction location or as the full predictive covariance matrix. Requires the input kernel function K to have a wn keyword argument that defines the kernel when white noise is included (wn = True) and when it is not (wn = False).

Currently assumes the same input hyperparameters for both the observed and predicted locations. The predicted locations x_l_pred and x_t_pred may, however, differ from the observed locations x_l and x_t.

The GP predictive mean is defined as:
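\[\mathbb{E}[\vec{y}_*] = \vec{\mu}_* + \mathbf{K}_*^T \mathbf{K}^{-1} \vec{r}\]
The GP predictive covariance is given by:
\[\mathrm{Var}[\vec{y}_*] = \mathbf{K}_{**} - \mathbf{K}_*^T \mathbf{K}^{-1} \mathbf{K}_*\]
where \(\vec{\mu}_*\) is the mean function evaluated at the prediction locations (M_s), \(\vec{r}\) is the vector of residuals R, \(\mathbf{K}\) is the covariance matrix of the observed locations, \(\mathbf{K}_*\) is the covariance between the observed and prediction locations, and \(\mathbf{K}_{**}\) is the covariance matrix of the prediction locations.
- Parameters: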
hp (Pytree) – Hyperparameters needed to build the covariance matrix K. May also contain mean function parameters, which do not affect K.
x_l (JAXArray) – Array containing the wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape (N_l,) or (d_l, N_l) for d_l different wavelength/vertical regression variables.
x_l_pred (JAXArray) – Array containing the wavelength/vertical dimension regression variable(s) for the prediction locations (which may be the same as the observed locations). May be of shape (N_l_pred,) or (d_l, N_l_pred) for d_l different wavelength/vertical regression variables.
x_t (JAXArray) – Array containing the time/horizontal dimension regression variable(s) for the observed locations. May be of shape (N_t,) or (d_t, N_t) for d_t different time/horizontal regression variables.
x_t_pred (JAXArray) – Array containing the time/horizontal dimension regression variable(s) for the prediction locations (which may be the same as the observed locations). May be of shape (N_t_pred,) or (d_t, N_t_pred) for d_t different time/horizontal regression variables.
R (JAXArray) – Residuals to be fit, equal to the observed data minus the deterministic mean function. Must have the same shape as the observed data, (N_l, N_t).
M_s (JAXArray) – Mean function evaluated at the prediction locations x_l_pred, x_t_pred. Must have shape (N_l_pred, N_t_pred), where N_l_pred is the number of wavelength/vertical dimension predictions and N_t_pred is the number of time/horizontal dimension predictions.
wn (bool, optional) – Whether to include white noise in the uncertainty at the predicted locations. Defaults to True.
return_std_dev (bool, optional) – If True, returns the standard deviation of the uncertainty at each predicted location; otherwise returns the full predictive covariance matrix. Defaults to True.
- Returns:
A tuple of two elements: the first is the GP predictive mean at the prediction locations; the second is either the standard deviation of the predictions if return_std_dev = True, or otherwise the full covariance matrix of the predicted values.
- Return type:
(JAXArray, JAXArray)
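The predictive mean and covariance equations above can be illustrated with dense matrices. This is a minimal sketch, not the library's implementation. It assumes a hypothetical kernel function toy_kernel(hp, x_l1, x_l2, x_t1, x_t2, wn=True) that returns the full covariance over the flattened (wavelength, time) grid, 1D x_l and x_t only, and that white noise is included in K at the observed locations, excluded from the cross-covariance K_*, and included in K_** only when wn = True. The helper dense_gp_predict and the hyperparameter names h, l_l, l_t and sigma are likewise hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)  # double precision for stable linear algebra


def toy_kernel(hp, x_l1, x_l2, x_t1, x_t2, wn=True):
    """Hypothetical separable squared-exponential kernel on the flattened
    (wavelength, time) grid, ordered like R.ravel() for R of shape (N_l, N_t)."""
    K_l = jnp.exp(-0.5 * (x_l1[:, None] - x_l2[None, :]) ** 2 / hp["l_l"] ** 2)
    K_t = jnp.exp(-0.5 * (x_t1[:, None] - x_t2[None, :]) ** 2 / hp["l_t"] ** 2)
    K = hp["h"] ** 2 * jnp.kron(K_l, K_t)
    if wn:
        # Assumes wn=True is only requested when both location sets are identical.
        K = K + hp["sigma"] ** 2 * jnp.eye(K.shape[0])
    return K


def dense_gp_predict(kernel, hp, x_l, x_l_pred, x_t, x_t_pred, R, M_s,
                     wn=True, return_std_dev=True):
    """Hypothetical dense implementation of the predictive mean and covariance."""
    N_l_pred, N_t_pred = M_s.shape
    K = kernel(hp, x_l, x_l, x_t, x_t, wn=True)                       # K, with noise
    K_s = kernel(hp, x_l, x_l_pred, x_t, x_t_pred, wn=False)          # K_*
    K_ss = kernel(hp, x_l_pred, x_l_pred, x_t_pred, x_t_pred, wn=wn)  # K_**

    cho = cho_factor(K)                  # Cholesky factor of K, avoids an explicit inverse
    alpha = cho_solve(cho, R.ravel())    # K^{-1} r
    mean = M_s + (K_s.T @ alpha).reshape(N_l_pred, N_t_pred)

    cov = K_ss - K_s.T @ cho_solve(cho, K_s)  # K_** - K_*^T K^{-1} K_*
    if return_std_dev:
        # Guard against tiny negative variances caused by round-off.
        var = jnp.maximum(jnp.diag(cov), 0.0)
        return mean, jnp.sqrt(var).reshape(N_l_pred, N_t_pred)
    return mean, cov


# Usage: 5 x 8 observed grid, 7 x 10 prediction grid.
x_l, x_t = jnp.linspace(0.0, 1.0, 5), jnp.linspace(0.0, 2.0, 8)
x_l_pred, x_t_pred = jnp.linspace(0.0, 1.0, 7), jnp.linspace(0.0, 2.0, 10)
hp = {"h": 1.0, "l_l": 0.5, "l_t": 0.7, "sigma": 0.1}
R = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (5, 8))
M_s = jnp.zeros((7, 10))

mean, std = dense_gp_predict(toy_kernel, hp, x_l, x_l_pred, x_t, x_t_pred, R, M_s)
print(mean.shape, std.shape)  # (7, 10) (7, 10)
```

A Cholesky solve is used instead of forming K^{-1} explicitly, which is the usual numerically stable choice. The dense approach costs O((N_l N_t)^3), so it is only practical for small grids.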