large_hessian_calc


luas.jax_convenience_fns.large_hessian_calc(fn: Callable, p: Any, *args, block_size: int | None = 50, return_array: bool | None = True, jit: bool | None = True, **kwargs) → Array | Any

Breaks up the calculation of the Hessian of a function into groups of rows to reduce the memory cost. Useful for large data sets, where applying jax.hessian directly to a log-likelihood function can use enough memory to crash the code.

This function should work for any function for which JAX can calculate second-order derivatives, provided that its first argument is the PyTree with respect to which the derivatives are taken.
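The row-by-row strategy can be illustrated with a minimal sketch. It is not the luas source: it assumes the Hessian is assembled from forward-over-reverse Hessian-vector products, so that row i is obtained as H e_i (valid because the Hessian is symmetric), with block_size unit vectors pushed through jax.vmap at a time. The function name blocked_hessian is hypothetical. For comparison, jax.hessian is implemented as jacfwd(jacrev(fn)), which pushes all N_{par} tangent vectors through together, so the memory used by intermediates grows roughly with N_{par} rather than with block_size.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def blocked_hessian(fn, p, *args, block_size=50, **kwargs):
    """Hypothetical sketch (not the luas implementation): Hessian of
    fn(p, *args, **kwargs) w.r.t. the PyTree p, computed block_size rows at a time."""
    p_flat, unravel = ravel_pytree(p)
    n_par = p_flat.size

    def grad_flat(x):
        # Gradient of fn w.r.t. the flattened parameters, returned as a flat vector.
        g = jax.grad(fn)(unravel(x), *args, **kwargs)
        return ravel_pytree(g)[0]

    def hvp(v):
        # Forward-over-reverse Hessian-vector product H @ v.
        return jax.jvp(grad_flat, (p_flat,), (v,))[1]

    blocks = []
    for start in range(0, n_par, block_size):
        stop = min(start + block_size, n_par)
        # Unit vectors e_start, ..., e_{stop-1}, shape (stop - start, n_par).
        basis = jnp.eye(stop - start, n_par, k=start, dtype=p_flat.dtype)
        # H is symmetric, so H @ e_i is row i; vmap computes this block of rows together.
        blocks.append(jax.vmap(hvp)(basis))
    return jnp.concatenate(blocks, axis=0)
```

Only one block of Hessian-vector products is alive at a time, so lowering block_size trades runtime for memory. Wrapping hvp in jax.jit would be one plausible place to add optional compilation, although how luas itself handles its jit option is not shown here.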

Parameters:
  • fn (Callable) – Function whose Hessian is calculated. Should have the form fn(p, *args, **kwargs), where p is a PyTree.

  • p (PyTree) – Parameters to calculate the derivatives with respect to. Derivatives are taken with respect to every parameter in p.

  • block_size (int, optional) – The number of rows of the Hessian to calculate the second derivatives for at once. Larger values have a higher memory cost but may reduce the runtime. Defaults to 50.

  • return_array (bool, optional) – Whether to return the Hessian as an array of shape (N_{par}, N_{par}), where N_{par} is the total number of parameters in p, or as a nested PyTree in which the Hessian values for parameters named par1 and par2 are given by hessian_pytree[par1][par2] and hessian_pytree[par2][par1]. Defaults to True.

  • jit (bool, optional) – Whether to JIT-compile the Hessian calculation. This can speed up the calculation, provided that fn can be JIT-compiled. Defaults to True.
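To make the two return formats concrete, the following sketch converts an (N_{par}, N_{par}) array into a nested layout of the kind described for return_array=False, where out[par1][par2] has shape p[par1].shape + p[par2].shape. That is the nesting jax.hessian itself produces for a dict input. The helper hessian_array_to_pytree and the toy function f are hypothetical, and the sketch assumes p is a flat dict of arrays; luas may build its PyTree differently.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def hessian_array_to_pytree(H, p):
    """Hypothetical helper: split an (N_par, N_par) Hessian array into nested
    dicts with out[par1][par2].shape == p[par1].shape + p[par2].shape.
    Assumes p is a flat dict of arrays (dict keys are flattened in sorted order)."""
    names = sorted(p)
    sizes = [int(jnp.size(p[k])) for k in names]
    offsets = [sum(sizes[:i]) for i in range(len(names))]
    out = {}
    for k1, o1, n1 in zip(names, offsets, sizes):
        out[k1] = {}
        for k2, o2, n2 in zip(names, offsets, sizes):
            block = H[o1:o1 + n1, o2:o2 + n2]
            out[k1][k2] = block.reshape(jnp.shape(p[k1]) + jnp.shape(p[k2]))
    return out


# Toy example: Hessian of a scalar function of a dict of parameters.
x = jnp.array([0.5, -1.0, 2.0])

def f(p):
    return jnp.sum((p["a"] * x + p["b"]) ** 2) + jnp.sum(p["a"] ** 3)

p = {"b": jnp.array(0.5), "a": jnp.array([1.0, 2.0, 3.0])}
p_flat, unravel = ravel_pytree(p)
H = jax.hessian(lambda q: f(unravel(q)))(p_flat)  # shape (4, 4)
H_tree = hessian_array_to_pytree(H, p)
# H_tree["a"]["b"] has shape (3,); H_tree["b"]["b"] is a scalar.
```

Because the Hessian is symmetric, H_tree[par1][par2] and H_tree[par2][par1] hold the same values, transposed where the parameter shapes differ.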

Returns:

Either a JAXArray (if return_array is True) or a nested PyTree (if return_array is False) giving the Hessian with respect to each parameter in p. When a JAXArray is returned, its rows and columns follow the same parameter order in which jax.flatten_util.ravel_pytree flattens the input PyTree p.
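One way to check which row and column of a returned array belong to which parameter is to flatten p yourself. In this sketch the parameter names and values are made up; it relies on JAX flattening dict entries in sorted key order, so a dict written as {"sigma": ..., "amp": ...} is flattened with amp first. The labelling loop assumes p is a flat dict of arrays.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters; the keys are deliberately not in sorted order.
p = {"sigma": jnp.array(0.1), "amp": jnp.array([1.0, 2.0])}

p_flat, _ = ravel_pytree(p)
# Dict keys are flattened in sorted order, so p_flat holds [1.0, 2.0, 0.1].

# Label each row/column of an (N_par, N_par) Hessian array with its parameter.
labels = []
for name in sorted(p):
    size = int(jnp.size(p[name]))
    labels += [name if size == 1 else f"{name}[{i}]" for i in range(size)]
print(labels)  # ['amp[0]', 'amp[1]', 'sigma']
```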

Return type:

JAXArray or PyTree