GP.laplace_approx


GP.laplace_approx(p: Any, Y: Any, regularise: bool | None = True, regularise_const: Any | None = 100.0, vars: list | None = None, fixed_vars: list | None = None, return_array: bool | None = False, large: bool | None = False, large_block_size: int | None = 50, large_jit: bool | None = True, logP_fn: Callable | None = None, hessian_mat: Array | None = None) → Tuple[Any | Array, list]

Computes the Laplace approximation at the location of p, with options to regularise parameters that are poorly constrained. The values in p should be the best-fit values of the posterior.

The Laplace approximation is a Gaussian approximation to the posterior distribution centred on the best-fit location. It takes the best-fit location as the mean of the Gaussian and calculates the covariance matrix from the Hessian of the log posterior evaluated at that location. If the posterior is close to Gaussian, the negative inverse of this Hessian approximates the covariance matrix. This is equivalent to a second-order Taylor series approximation of the log posterior about the best-fit location.
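The following standalone sketch illustrates the idea and does not reproduce the internals of laplace_approx. Expand the log posterior to second order about the best fit θ̂. The first-order term vanishes at the maximum, which leaves log p(θ) ≈ log p(θ̂) + ½ (θ − θ̂)ᵀ H (θ − θ̂), where H is the Hessian of the log posterior at θ̂. The approximating Gaussian therefore has covariance −H⁻¹. The toy posterior below is exactly Gaussian, so the approximation should recover its covariance. The names true_mean, true_cov and log_post are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy posterior: an exact 2D Gaussian, so the Laplace
# approximation should recover its covariance exactly.
true_mean = jnp.array([1.0, -2.0])
true_cov = jnp.array([[2.0, 0.6],
                      [0.6, 1.0]])
true_prec = jnp.linalg.inv(true_cov)

def log_post(theta):
    # Unnormalised log posterior density
    d = theta - true_mean
    return -0.5 * d @ true_prec @ d

# Best-fit location: known analytically here, normally found by optimisation
theta_hat = true_mean

# Hessian of the log posterior at the best fit
hess = jax.hessian(log_post)(theta_hat)

# Laplace approximation: covariance is the negative inverse Hessian
laplace_cov = -jnp.linalg.inv(hess)
print(laplace_cov)
```

For a posterior that is not Gaussian, the same recipe gives only a local approximation around θ̂.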

The Laplace approximation is useful for obtaining a quick approximation of the posterior without running an expensive MCMC calculation. It can also provide a good tuning matrix for initialising MCMC inference when sampling large numbers of parameters that may be strongly correlated.

Note

This calculation can be memory intensive for large data sets with many free parameters. Setting large = True and choosing a small integer for large_block_size can reduce memory costs by breaking the Hessian calculation into blocks of rows.
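The internals of large = True are not described here. The sketch below shows one way to build a Hessian from blocks of rows in JAX, as an assumed illustration only. Each column of the Hessian is a Hessian-vector product with a unit vector, computed by forward-mode differentiation of the gradient. Because the Hessian is symmetric, each column equals the corresponding row. The names blocked_hessian, block_size and use_jit are hypothetical and only loosely mirror large_block_size and large_jit.

```python
import jax
import jax.numpy as jnp

def blocked_hessian(f, x, block_size=50, use_jit=True):
    """Hypothetical sketch: Hessian of scalar f at 1D x, a block of rows at a time."""
    grad_f = jax.grad(f)

    def hvp(v):
        # Forward-mode derivative of the gradient along v, i.e. H @ v
        return jax.jvp(grad_f, (x,), (v,))[1]

    rows_fn = jax.vmap(hvp)  # evaluates one block of rows per call
    if use_jit:
        rows_fn = jax.jit(rows_fn)

    n = x.shape[0]
    eye = jnp.eye(n, dtype=x.dtype)
    blocks = []
    for start in range(0, n, block_size):
        # Only block_size unit vectors are processed per call, which bounds
        # peak memory at the cost of more passes through the function
        blocks.append(rows_fn(eye[start:start + block_size]))
    return jnp.concatenate(blocks, axis=0)

def toy_log_post(x):
    # Hypothetical non-separable function, so the Hessian has off-diagonal terms
    return -jnp.sum(x ** 2) * (1.0 + 0.1 * jnp.sum(jnp.cos(x)))

x0 = jnp.linspace(-1.0, 1.0, 7)
H_blocked = blocked_hessian(toy_log_post, x0, block_size=3)
H_full = jax.hessian(toy_log_post)(x0)
print(jnp.max(jnp.abs(H_blocked - H_full)))
```

No approximation is involved, so the blocked result should match the full Hessian up to floating-point rounding. Smaller blocks trade speed for memory.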

Parameters:
  • p (PyTree) – PyTree of hyperparameters used to calculate the covariance matrix, together with any mean function parameters needed to calculate the mean function. Also passed to the logPrior function to calculate the log priors.

  • Y (JAXArray) – Observed data to fit, must be of shape (N_l, N_t).

  • regularise (bool, optional) – Whether to add a regularisation constant to the diagonal elements of the Hessian matrix for parameters whose variances come out negative. These variances are the diagonal elements of the resulting covariance matrix. Defaults to True.

  • regularise_const (float, optional) – The constant added to those diagonal elements of the Hessian matrix to regularise them when regularise is True. Defaults to 100.0.

  • vars (list of str, optional) – The key names of the parameters to calculate the Laplace approximation with respect to. The remaining parameters are held fixed. Specifying both vars and fixed_vars raises an Exception.

  • fixed_vars (list of str, optional) – Alternative to vars. Gives the key names of the parameters to hold fixed, which are not marginalised over in the Laplace approximation. Specifying both vars and fixed_vars raises an Exception.

  • return_array (bool, optional) – Whether to return the approximated covariance matrix as a JAXArray or as a nested PyTree. In the PyTree form, the covariance between parameters named p1 and p2 is given by both cov_mat[p1][p2] and cov_mat[p2][p1]. Defaults to False.

  • large (bool, optional) – Calculating the Hessian matrix for large data sets with many parameters can be very memory intensive. If True, the Hessian is calculated in groups of rows rather than all at once. This reduces the memory cost but can take significantly longer to run. No approximation is made, so the result is otherwise the same. Defaults to False.

  • large_block_size (int, optional) – When large is True, the number of rows of the Hessian calculated simultaneously. Larger values may compute the full Hessian faster but at greater memory cost. Defaults to 50.

  • large_jit (bool, optional) – Whether to JIT compile the Hessian function when large = True. This can speed up the calculation, provided the function can be JIT compiled. Defaults to True.

  • hessian_mat (JAXArray, optional) – Provides the Hessian matrix needed for the Laplace approximation directly, instead of calculating it from the input parameters p and Y. Assumed to be a JAXArray rather than a PyTree. The input parameters p and Y will be ignored.
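The sketch below shows how a covariance matrix might be obtained from a Hessian supplied directly, with the regularisation described for regularise and regularise_const. It is hypothetical and is not the library's code. One assumption needs stating loudly: the Hessian here is that of the log posterior, so the constant is subtracted from its diagonal to make the curvature more negative. That is equivalent to adding it to the Hessian of the negative log posterior. The sign convention laplace_approx uses internally is not stated above. The helper name laplace_cov_from_hessian is also hypothetical.

```python
import jax.numpy as jnp

def laplace_cov_from_hessian(hess, regularise=True, regularise_const=100.0):
    """Hypothetical sketch: covariance from a log-posterior Hessian.

    Assumes the regularisation constant is subtracted from the diagonal of
    the log-posterior Hessian for parameters with negative variance.
    """
    cov = -jnp.linalg.inv(hess)
    if regularise:
        # Parameters whose variance came out negative are poorly constrained
        bad = jnp.diag(cov) < 0.0
        hess = hess - jnp.diag(jnp.where(bad, regularise_const, 0.0))
        cov = -jnp.linalg.inv(hess)
    return cov

# Hypothetical Hessian with positive curvature along the second parameter,
# which gives that parameter a negative "variance"
H = jnp.array([[-1.0, 0.0],
               [0.0, 0.5]])
print(jnp.diag(-jnp.linalg.inv(H)))           # second variance is negative
print(jnp.diag(laplace_cov_from_hessian(H)))  # both variances positive
```

A large regularisation constant gives the affected parameters a small variance. It does not fix a genuinely bad best fit.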
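The effect of vars and fixed_vars can be sketched for a flat dict of parameters. The free parameters are flattened with jax.flatten_util.ravel_pytree, while the fixed ones are held at their best-fit values inside the log posterior. This is an assumed illustration. The helpers split_params and laplace_cov_subset are hypothetical, and raising a ValueError stands in for the unspecified Exception the library raises.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def split_params(p, vars=None, fixed_vars=None):
    """Hypothetical helper: split dict p into free and fixed parts."""
    if vars is not None and fixed_vars is not None:
        raise ValueError("Specify at most one of vars and fixed_vars")
    if vars is None:
        fixed_set = set(fixed_vars or [])
        vars = [k for k in p if k not in fixed_set]
    free = {k: p[k] for k in vars}
    fixed = {k: v for k, v in p.items() if k not in free}
    return free, fixed

def laplace_cov_subset(log_post, p, vars=None, fixed_vars=None):
    free, fixed = split_params(p, vars, fixed_vars)
    flat, unravel = ravel_pytree(free)

    def log_post_flat(x):
        # Rebuild the full parameter dict with the fixed values held constant
        return log_post({**fixed, **unravel(x)})

    hess = jax.hessian(log_post_flat)(flat)
    return -jnp.linalg.inv(hess)

def toy_log_post(p):
    # Hypothetical independent Gaussian posterior with variances 1, 4 and 9
    return -0.5 * (p["a"] ** 2 / 1.0 + p["b"] ** 2 / 4.0 + p["c"] ** 2 / 9.0)

p_best = {"a": jnp.array(0.0), "b": jnp.array(0.0), "c": jnp.array(0.0)}
cov_ac = laplace_cov_subset(toy_log_post, p_best, fixed_vars=["b"])
print(cov_ac)  # covariance over the free parameters "a" and "c"
```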

Returns:

Returns a tuple of two elements. If return_array = True, the first element is the covariance matrix from the Laplace approximation as a JAXArray; otherwise it is a nested PyTree. The second element is a list giving the order of the parameters in the covariance matrix when it is returned as a JAXArray. This list is also returned when return_array = False, for consistency. Its order matches how jax.flatten_util.ravel_pytree orders keys from a PyTree.
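The sketch below relates the array and nested PyTree forms for a flat dict of scalar-valued parameters. JAX flattens dict entries in sorted key order, so in this case the list of parameter names is simply the sorted keys. The helper array_to_nested is hypothetical and not part of the library. For array-valued parameters, each key would span several rows, so the nested entries would be sub-blocks rather than scalars.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical best-fit parameters with scalar values
p_best = {"h": jnp.array(1.0), "a": jnp.array(2.0), "l": jnp.array(3.0)}

# ravel_pytree flattens dict leaves in sorted key order
flat, unravel = ravel_pytree(p_best)
order = sorted(p_best)  # ["a", "h", "l"], the same order as the entries of flat
print(flat, order)

def array_to_nested(cov, names):
    """Hypothetical helper: covariance array -> cov_mat[p1][p2] (scalar params only)."""
    return {p1: {p2: cov[i, j] for j, p2 in enumerate(names)}
            for i, p1 in enumerate(names)}

# Hypothetical covariance matrix with rows and columns in `order`
cov = jnp.array([[1.0, 0.2, 0.0],
                 [0.2, 2.0, 0.3],
                 [0.0, 0.3, 3.0]])
cov_mat = array_to_nested(cov, order)
print(cov_mat["a"]["h"], cov_mat["h"]["a"])
```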

Return type:

(JAXArray, list of str) or (PyTree, list of str)