GeneralKernel.predict
- GeneralKernel.predict(hp: Any, x_l: Array, x_l_pred: Array, x_t: Array, x_t_pred: Array, R: Array, M_s: Array, wn: bool | None = True, return_std_dev: bool | None = True) -> Tuple[Array, Array, Array]
Performs GP regression, computing the GP predictive mean and the GP predictive uncertainty, returned either as the standard deviation at each prediction location or as the full predictive covariance matrix. Requires the input kernel function K to have a wn keyword argument that defines the kernel with white noise included (wn = True) and without white noise (wn = False).
Currently assumes the same input hyperparameters for both the observed and predicted locations. The predicted locations x_l_pred and x_t_pred may, however, differ from the observed locations x_l and x_t.
The GP predictive mean is defined as:
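\[\mathbb{E}[\vec{y}_*] = \vec{\mu}_* + \mathbf{K}_*^T \mathbf{K}^{-1} \vec{r}\]
And the GP predictive covariance is given by:
\[\mathrm{Var}[\vec{y}_*] = \mathbf{K}_{**} - \mathbf{K}_*^T \mathbf{K}^{-1} \mathbf{K}_*\]
where \(\vec{\mu}_*\) is the mean function at the prediction locations (M_s), \(\vec{r}\) the residuals (R), \(\mathbf{K}\) the covariance matrix of the observed locations, \(\mathbf{K}_*\) the covariance between the observed and predicted locations, and \(\mathbf{K}_{**}\) the covariance matrix of the predicted locations.
- Parameters: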
  - hp (Pytree) – Hyperparameters needed to build the covariance matrix K. K is unaffected if additional mean function parameters are also included.
  - x_l (JAXArray) – Array containing the wavelength/vertical dimension regression variable(s) for the observed locations. May be of shape (N_l,) or (d_l, N_l) for d_l different wavelength/vertical regression variables.
  - x_l_pred (JAXArray) – Array containing the wavelength/vertical dimension regression variable(s) for the prediction locations (which may be the same as the observed locations). May be of shape (N_l_pred,) or (d_l, N_l_pred) for d_l different wavelength/vertical regression variables.
  - x_t (JAXArray) – Array containing the time/horizontal dimension regression variable(s) for the observed locations. May be of shape (N_t,) or (d_t, N_t) for d_t different time/horizontal regression variables.
  - x_t_pred (JAXArray) – Array containing the time/horizontal dimension regression variable(s) for the prediction locations (which may be the same as the observed locations). May be of shape (N_t_pred,) or (d_t, N_t_pred) for d_t different time/horizontal regression variables.
  - R (JAXArray) – Residuals to be fit, equal to the observed data minus the deterministic mean function. Must have the same shape as the observed data, (N_l, N_t).
  - M_s (JAXArray) – Mean function evaluated at the prediction locations x_l_pred, x_t_pred. Must have shape (N_l_pred, N_t_pred), where N_l_pred is the number of wavelength/vertical dimension predictions and N_t_pred the number of time/horizontal dimension predictions.
  - wn (bool, optional) – Whether to include white noise in the uncertainty at the predicted locations. Defaults to True.
  - return_std_dev (bool, optional) – If True, returns the standard deviation of the predictive uncertainty at each predicted location; otherwise returns the full predictive covariance matrix. Defaults to True.
- Returns:
  Returns a tuple of two elements: the first is the GP predictive mean at the prediction locations; the second is either the standard deviation of the predictions if return_std_dev = True, or otherwise the full covariance matrix of the predicted values.
- Return type:
  (JAXArray, JAXArray)
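For intuition, the predictive mean and covariance equations above can be evaluated directly with dense linear algebra. The following is a minimal sketch, not the GeneralKernel implementation: it assumes 1D x_l and x_t, a hypothetical separable squared-exponential kernel K(hp, x_l1, x_l2, x_t1, x_t2, wn=True) whose hyperparameter names (h, l_l, l_t, sigma) are made up for illustration, and a row-major flattening of the (N_l, N_t) arrays. It shows where the wn keyword enters in this sketch: white noise is included in K for the observed data, never in the cross-covariance K_*, and in K_** only when wn=True.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def squared_exp(x1, x2, length_scale):
    # Squared-exponential correlation between two 1D arrays of locations
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * (d / length_scale) ** 2)


def K(hp, x_l1, x_l2, x_t1, x_t2, wn=True):
    # Hypothetical separable kernel h^2 * (K_l kron K_t), plus sigma^2 on the
    # diagonal when wn=True. Only use wn=True when both sets of locations are
    # identical, so that the matrix is square and the diagonal is meaningful.
    cov = hp["h"] ** 2 * jnp.kron(
        squared_exp(x_l1, x_l2, hp["l_l"]), squared_exp(x_t1, x_t2, hp["l_t"])
    )
    if wn:
        cov = cov + hp["sigma"] ** 2 * jnp.eye(cov.shape[0])
    return cov


def dense_predict(hp, x_l, x_l_pred, x_t, x_t_pred, R, M_s,
                  wn=True, return_std_dev=True):
    # Row-major flattening of (N_l, N_t) arrays matches the kron ordering above
    r = R.ravel()
    mu_s = M_s.ravel()
    K_obs = K(hp, x_l, x_l, x_t, x_t, wn=True)                   # K
    K_s = K(hp, x_l, x_l_pred, x_t, x_t_pred, wn=False)          # K_*
    K_ss = K(hp, x_l_pred, x_l_pred, x_t_pred, x_t_pred, wn=wn)  # K_**
    # Solve with a Cholesky factor instead of forming K^{-1} explicitly
    c = cho_factor(K_obs, lower=True)
    mean = mu_s + K_s.T @ cho_solve(c, r)
    cov = K_ss - K_s.T @ cho_solve(c, K_s)
    if return_std_dev:
        # Clip tiny negative variances caused by round-off before the sqrt
        std = jnp.sqrt(jnp.maximum(jnp.diag(cov), 0.0))
        return mean.reshape(M_s.shape), std.reshape(M_s.shape)
    return mean.reshape(M_s.shape), cov


x_l = jnp.linspace(0.0, 1.0, 4)          # N_l = 4
x_t = jnp.linspace(0.0, 1.0, 5)          # N_t = 5
hp = {"h": 1.0, "l_l": 0.5, "l_t": 0.5, "sigma": 0.1}
R = jnp.sin(2 * jnp.pi * x_l)[:, None] * jnp.cos(jnp.pi * x_t)[None, :]  # (4, 5)
x_l_pred = jnp.array([0.5, 50.0])        # second point is far from the data
x_t_pred = jnp.linspace(0.0, 1.0, 3)     # N_t_pred = 3
M_s = jnp.zeros((2, 3))                  # mean function at the predicted points
mean, std = dense_predict(hp, x_l, x_l_pred, x_t, x_t_pred, R, M_s)
```

Far from the data (x_l_pred = 50.0) the cross-covariance vanishes, so the prediction reverts to the mean function with standard deviation sqrt(h^2 + sigma^2). Factorising the dense K costs O((N_l N_t)^3), so this sketch favours clarity over speed and is only practical for small grids.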