jax_convenience_fns

array_to_pytree_2D(p, arr_2D)

Takes a 2D JAX array (e.g. a covariance matrix) whose rows and columns are ordered the way jax.flatten_util.ravel_pytree would order the input PyTree of parameters p, and returns a nested PyTree in which the array entries are grouped by parameter.
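The block structure described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name array_to_nested_dict is hypothetical, and it assumes p is a flat dict, which JAX flattens in sorted key order with each leaf raveled in row-major order.

```python
import numpy as np
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def array_to_nested_dict(p, arr_2d):
    """Hypothetical sketch: split a 2D array into per-parameter blocks."""
    flat, _ = ravel_pytree(p)
    assert arr_2d.shape == (flat.size, flat.size)

    # Assumption: p is a flat dict, so its leaves are raveled in sorted key order.
    keys = sorted(p)
    offsets, start = {}, 0
    for k in keys:
        size = int(np.size(p[k]))
        offsets[k] = (start, start + size)
        start += size

    nested = {}
    for row in keys:
        r0, r1 = offsets[row]
        nested[row] = {}
        for col in keys:
            c0, c1 = offsets[col]
            block = arr_2d[r0:r1, c0:c1]
            # Each block takes the shape of the row parameter followed by
            # the shape of the column parameter.
            nested[row][col] = block.reshape(np.shape(p[row]) + np.shape(p[col]))
    return nested


p = {"b": jnp.array([1.0, 2.0]), "a": jnp.array(3.0)}
cov = jnp.arange(9.0).reshape(3, 3)
nested = array_to_nested_dict(p, cov)
```

Here nested["b"]["b"] holds the 2x2 block for b against itself, and nested["a"]["b"] holds the entries of a against each element of b.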

pytree_to_array_2D(p, pytree_2D[, param_order])

Inverse of array_to_pytree_2D. Takes a nested PyTree (e.g. describing a covariance matrix) whose outer and inner keys correspond to the rows and columns of a 2D array, with a value defined for every pair of parameters, and returns that 2D array.
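The inverse direction can be sketched by reshaping each block to two dimensions and stitching the blocks together. The helper name nested_dict_to_array is hypothetical, p is again assumed to be a flat dict, and the optional param_order argument from the signature above is not modelled here; sorted key order is used instead.

```python
import numpy as np
import jax.numpy as jnp


def nested_dict_to_array(p, nested):
    """Hypothetical sketch: assemble per-parameter blocks into one 2D array."""
    # Assumption: rows and columns follow the sorted keys of the flat dict p.
    keys = sorted(p)
    sizes = {k: int(np.size(p[k])) for k in keys}
    blocks = [
        [jnp.reshape(nested[row][col], (sizes[row], sizes[col])) for col in keys]
        for row in keys
    ]
    return jnp.block(blocks)


p = {"a": jnp.array(3.0), "b": jnp.array([1.0, 2.0])}
nested = {
    "a": {"a": jnp.array(0.0), "b": jnp.array([1.0, 2.0])},
    "b": {"a": jnp.array([3.0, 6.0]), "b": jnp.array([[4.0, 5.0], [7.0, 8.0]])},
}
arr = nested_dict_to_array(p, nested)
```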

varying_params_wrapper(p[, vars, ...])

Takes a PyTree of parameters and returns the subset of (key, value) pairs that are to be varied, which is often useful when only some parameters should be fitted.
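One way to use such a subset is to close over the fixed parameters so that gradients are taken only with respect to the varied ones. The sketch below illustrates that pattern; the name make_varying_wrapper, its arguments, and its return values are assumptions for illustration and do not reflect the actual signature of varying_params_wrapper.

```python
import jax
import jax.numpy as jnp


def make_varying_wrapper(fn, p, vary_keys):
    """Hypothetical sketch: expose only the varied parameters of fn."""
    p_vary = {k: v for k, v in p.items() if k in vary_keys}
    p_fixed = {k: v for k, v in p.items() if k not in vary_keys}

    def wrapped(p_vary, *args):
        # Merge the fixed values back in before calling the original function.
        return fn({**p_fixed, **p_vary}, *args)

    return p_vary, wrapped


def loss(p):
    return p["a"] ** 2 + 3.0 * p["b"]


p = {"a": jnp.array(2.0), "b": jnp.array(1.0)}
p_vary, wrapped = make_varying_wrapper(loss, p, ["a"])
grads = jax.grad(wrapped)(p_vary)  # gradient PyTree contains only "a"
```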

large_hessian_calc(fn, p, *args[, ...])

Breaks up the calculation of a large Hessian matrix into groups of rows to reduce the memory cost.
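The row-grouping idea can be sketched with Hessian-vector products. This is an illustrative version under stated assumptions, not the implementation of large_hessian_calc: the name hessian_in_row_groups and the group_size argument are hypothetical, and the sketch relies on the Hessian being symmetric so that the product with a basis vector gives the corresponding row.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def hessian_in_row_groups(fn, p, *args, group_size=2):
    """Hypothetical sketch: build the Hessian a few rows at a time."""
    x, unravel = ravel_pytree(p)
    grad_f = jax.grad(lambda flat: fn(unravel(flat), *args))

    def hvp(v):
        # Forward-over-reverse Hessian-vector product: H @ v.
        return jax.jvp(grad_f, (x,), (v,))[1]

    n = x.size
    groups = []
    for start in range(0, n, group_size):
        stop = min(start + group_size, n)
        # One-hot basis vectors for this group of rows only.
        basis = jax.nn.one_hot(jnp.arange(start, stop), n, dtype=x.dtype)
        groups.append(jax.vmap(hvp)(basis))
    return jnp.concatenate(groups, axis=0)


def f(p):
    return jnp.sum(p["a"] ** 3) + p["b"] ** 2 * p["a"][0]


p = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
H = hessian_in_row_groups(f, p, group_size=2)
```

Only group_size rows are differentiated at once, so intermediate memory scales with the group size, although the assembled result is still the full matrix.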

transf_from_unbounded_params(p_unbounded, ...)

Inverse of the transformation in transf_to_unbounded_params.

transf_to_unbounded_params(p, param_bounds)

Replicates the transformation used by PyMC and NumPyro to convert a PyTree of parameters, some or all of which lie within the bounds given by param_bounds, to unbounded parameters using the transformation given in transf_to_unbounded.
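A bounded-to-unbounded mapping of this kind and its inverse can be sketched as below. The sketch assumes a logit transform for parameters with two-sided bounds (lower, upper) and leaves parameters without bounds untouched; the function names, the dict layout of the bounds, and the choice of transform are assumptions for illustration, not a statement of what transf_to_unbounded_params does internally.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit


def to_unbounded(p, bounds):
    """Hypothetical sketch: map interval-bounded values to the real line."""
    out = {}
    for k, v in p.items():
        if k in bounds:
            lo, hi = bounds[k]
            # Rescale to (0, 1), then apply the logit.
            out[k] = logit((v - lo) / (hi - lo))
        else:
            out[k] = v  # no bounds given: leave unchanged
    return out


def from_unbounded(p_unbounded, bounds):
    """Hypothetical sketch: inverse of to_unbounded."""
    out = {}
    for k, v in p_unbounded.items():
        if k in bounds:
            lo, hi = bounds[k]
            # The sigmoid maps back to (0, 1), then rescale to (lo, hi).
            out[k] = lo + (hi - lo) * expit(v)
        else:
            out[k] = v
    return out


p = {"x": jnp.array(0.5), "y": jnp.array(2.0)}
bounds = {"x": (0.0, 1.0)}
u = to_unbounded(p, bounds)
p_back = from_unbounded(u, bounds)
```

In this sketch the midpoint of an interval maps to zero, and applying the two functions in sequence recovers the original parameters.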