large_hessian_calc
- luas.jax_convenience_fns.large_hessian_calc(fn: Callable, p: Any, *args, block_size: int | None = 50, return_array: bool | None = True, jit: bool | None = True, **kwargs) → Array | Any
Breaks the calculation of a large Hessian matrix into groups of rows to reduce the memory cost. Useful for large data sets, where applying jax.hessian to a log-likelihood function can cause the code to crash. However, this function should work for any function for which jax can calculate second-order derivatives and whose first argument is the PyTree with respect to which the derivatives are taken.
- Parameters:
fn (Callable) – Function to calculate the Hessian of. Should be of the form f(p, *args, **kwargs), where p is a PyTree.
p (PyTree) – Parameters to differentiate with respect to. Derivatives are taken with respect to every parameter within p.
block_size (int, optional) – The number of groups of rows to calculate the second derivatives for at once. Larger values have a higher memory cost but may result in a shorter runtime. Defaults to 50.
return_array (bool, optional) – Whether to return the Hessian as an array of shape (N_{par}, N_{par}) for N_{par} total parameters in p, or as a nested PyTree where the Hessian values for parameters named par1 and par2 would be given by hessian_pytree[par1][par2] and hessian_pytree[par2][par1]. Defaults to True.
jit (bool, optional) – Whether to JIT compile the Hessian function. This can speed up the calculation, assuming the function can be JIT compiled. Defaults to True.
- Returns:
Depending on whether return_array is set to True or False, returns either a JAXArray or a PyTree giving the Hessian with respect to each parameter in p. If a JAXArray is returned, the parameters are ordered in the same order that jax.flatten_util.ravel_pytree orders the input PyTree p.
- Return type:
JAXArray or PyTree
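
The row-block idea described above can be sketched in plain JAX as follows. This is a minimal illustration under stated assumptions, not the luas implementation: the names blockwise_hessian and neg_log_like are hypothetical, block_size is taken here to mean the number of Hessian rows evaluated per block, and the sketch only covers the array output (the return_array=True case), with parameters ordered as by jax.flatten_util.ravel_pytree.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def blockwise_hessian(fn, p, *args, block_size=50, **kwargs):
    """Hypothetical sketch (not the luas code): Hessian of
    fn(p, *args, **kwargs) with respect to the PyTree p, built one
    block of rows at a time to limit peak memory use."""
    # Flatten the PyTree to a 1-D array; `unravel` maps it back to a PyTree.
    p_flat, unravel = ravel_pytree(p)
    n_par = p_flat.size

    def flat_fn(x):
        return fn(unravel(x), *args, **kwargs)

    grad_fn = jax.grad(flat_fn)

    def hvp(x, v):
        # Hessian-vector product H @ v (forward-mode over reverse-mode).
        # With v a unit basis vector e_i this is column i of H, which
        # equals row i because the Hessian is symmetric.
        return jax.jvp(grad_fn, (x,), (v,))[1]

    # Vectorise over a batch of basis vectors, i.e. one block of rows.
    rows_fn = jax.jit(jax.vmap(hvp, in_axes=(None, 0)))

    blocks = []
    for start in range(0, n_par, block_size):
        stop = min(start + block_size, n_par)
        # Build only the basis vectors needed for this block.
        basis = jax.nn.one_hot(jnp.arange(start, stop), n_par, dtype=p_flat.dtype)
        blocks.append(rows_fn(p_flat, basis))
    return jnp.concatenate(blocks, axis=0)


# Hypothetical example function of the form f(p, *args, **kwargs).
def neg_log_like(p, x):
    return jnp.sum(jnp.sin(p["a"] * x) ** 2) + p["b"] ** 3 * p["a"][0]


p = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
H = blockwise_hessian(neg_log_like, p, 2.0, block_size=2)
```

In this sketch a smaller block_size lowers the peak memory of each vectorised call at the cost of more calls, which mirrors the trade-off described for the block_size parameter above. The final block may be smaller than the others, which triggers one extra JIT compilation.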