array_to_pytree_2D(p, arr_2D)
Takes a 2D JAXArray (e.g. a covariance matrix) whose rows and columns are ordered the same way that jax.flatten_util.ravel_pytree would flatten the input PyTree of parameters p. Returns a nested PyTree that splits the array into blocks indexed by pairs of parameters.
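The sketch below shows the row/column bookkeeping this implies. It assumes that p is a flat dict of scalars or arrays, so that ravel_pytree visits its keys in sorted order and ravels each value in C order. It also assumes that each block is reshaped to `p[k1].shape + p[k2].shape`. `array_to_pytree_2d_sketch` is a hypothetical name, not the library's implementation, and nested PyTrees are not handled.

```python
import math

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def array_to_pytree_2d_sketch(p, arr_2d):
    """Hypothetical helper: split a 2D array into blocks out[k1][k2].

    Assumes p is a flat dict, so ravel_pytree orders its keys alphabetically
    and ravels each value in C order.
    """
    keys = sorted(p)
    shapes = {k: jnp.shape(p[k]) for k in keys}
    # Slice of the flattened parameter vector occupied by each parameter.
    slices, start = {}, 0
    for k in keys:
        size = math.prod(shapes[k])
        slices[k] = slice(start, start + size)
        start += size
    # Block (k1, k2) holds the rows of k1 and the columns of k2.
    return {
        k1: {
            k2: arr_2d[slices[k1], slices[k2]].reshape(shapes[k1] + shapes[k2])
            for k2 in keys
        }
        for k1 in keys
    }


p = {"b": jnp.zeros(2), "a": 1.0}
flat, _ = ravel_pytree(p)  # flattened order: a, b[0], b[1]
cov = jnp.arange(9.0).reshape(3, 3)
blocks = array_to_pytree_2d_sketch(p, cov)
```

In this example, `blocks["b"]["b"]` is the 2x2 lower-right corner of `cov`, and `blocks["a"]["b"]` is the rest of the first row.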
pytree_to_array_2D(p, pytree_2D[, param_order])
Inverse of array_to_pytree_2D. Takes a nested PyTree (e.g. describing a covariance matrix) whose first and second keys correspond to the row and column parameters of a 2D array, with a value defined for every pair of parameters, and returns the corresponding 2D array.
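The inverse direction can be sketched by concatenating the blocks of each row parameter horizontally and then stacking those rows vertically. This sketch makes the same flat-dict, sorted-key assumption as ravel_pytree. `pytree_to_array_2d_sketch` is hypothetical, and it does not model the optional param_order argument of the documented function.

```python
import math

import jax.numpy as jnp


def pytree_to_array_2d_sketch(p, pytree_2d):
    """Hypothetical inverse: assemble blocks pytree_2d[k1][k2] into one 2D array.

    Uses sorted keys (the order ravel_pytree uses for a flat dict);
    param_order is not modelled.
    """
    keys = sorted(p)
    sizes = {k: math.prod(jnp.shape(p[k])) for k in keys}
    rows = [
        # Each block is flattened to (size of row param, size of column param).
        jnp.concatenate(
            [jnp.reshape(pytree_2d[k1][k2], (sizes[k1], sizes[k2])) for k2 in keys],
            axis=1,
        )
        for k1 in keys
    ]
    return jnp.concatenate(rows, axis=0)


p = {"b": jnp.zeros(2), "a": 1.0}
cross = jnp.array([0.5, 0.2])
blocks = {
    "a": {"a": 1.0, "b": cross},
    "b": {"a": cross, "b": jnp.eye(2)},
}
cov = pytree_to_array_2d_sketch(p, blocks)
```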
varying_params_wrapper(p[, vars, ...])
Takes a PyTree of parameters p and returns the subset of (key, value) pairs that are to be varied. This is often useful when only some of the parameters are being fitted.
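Here is one way such a subset can be used. The varying entries are selected, and the remaining entries are closed over so that JAX only differentiates with respect to the subset. Both helpers are hypothetical illustrations of the idea, not the signature or behaviour of varying_params_wrapper itself.

```python
import jax


def split_varying_sketch(p, var_names):
    """Hypothetical helper: keep only the (key, value) pairs named in var_names."""
    return {k: p[k] for k in var_names}


def fix_other_params_sketch(fn, p):
    """Hypothetical helper: view fn as a function of the varying subset only.

    Entries of p_vary override those in p; every other entry of p stays fixed.
    """
    def wrapped(p_vary):
        return fn({**p, **p_vary})
    return wrapped


def loss(p):
    return p["amp"] ** 2 + p["length"] * p["noise"]


p = {"amp": 2.0, "length": 0.5, "noise": 0.1}
p_vary = split_varying_sketch(p, ["amp", "length"])
# Gradients are taken only with respect to "amp" and "length"; "noise" is fixed.
grads = jax.grad(fix_other_params_sketch(loss, p))(p_vary)
```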
large_hessian_calc(fn, p, *args[, ...])
Breaks up the calculation of a large Hessian matrix (of fn with respect to the parameters p) into groups of rows to reduce the memory cost.
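The row-grouping idea can be sketched with forward-over-reverse Hessian-vector products. Row i of a symmetric Hessian equals H @ e_i, so vmapping the product over only `block_size` unit vectors at a time bounds how many rows are in flight at once. `chunked_hessian_sketch` and its `block_size` argument are hypothetical; the documented function's remaining options are elided in its signature and are not modelled here.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def chunked_hessian_sketch(fn, p, *args, block_size=2):
    """Hypothetical helper: Hessian of fn(p, *args) w.r.t. p, block_size rows at a time."""
    x0, unravel = ravel_pytree(p)
    n = x0.size

    def f_flat(x):
        return fn(unravel(x), *args)

    grad_f = jax.grad(f_flat)

    def hvp(v):
        # Forward-over-reverse Hessian-vector product: returns H @ v.
        return jax.jvp(grad_f, (x0,), (v,))[1]

    blocks = []
    for start in range(0, n, block_size):
        idx = jnp.arange(start, min(start + block_size, n))
        basis = jax.nn.one_hot(idx, n, dtype=x0.dtype)  # unit vectors e_i
        # H @ e_i is column i, which equals row i because H is symmetric.
        blocks.append(jax.vmap(hvp)(basis))
    return jnp.concatenate(blocks, axis=0)


def fn(p, scale):
    return scale * (p["a"] ** 2 * p["b"] + jnp.sum(p["c"] ** 3))


p = {"a": 1.0, "b": 2.0, "c": jnp.array([1.0, 3.0])}
H = chunked_hessian_sketch(fn, p, 2.0, block_size=3)
```

The trade-off is time for memory: smaller `block_size` values mean more sequential passes, but fewer Hessian rows are held in memory during each pass.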
transf_from_unbounded_params(p_unbounded, ...)
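Inverse of the transformation in transf_to_unbounded_params. Converts a PyTree of unbounded parameters p_unbounded back into parameters that lie within their bounds.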
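This sketch assumes the bounded-to-unbounded map is the interval (log-odds) transform, y = log((x - lower) / (upper - x)). In that case the inverse is a scaled logistic sigmoid. `from_unbounded_sketch` is hypothetical. It also assumes that param_bounds maps a key to a finite (lower, upper) pair, and that unlisted keys pass through unchanged; the documented function's remaining arguments are elided and may differ.

```python
import jax


def from_unbounded_sketch(p_unbounded, param_bounds):
    """Hypothetical helper: map unbounded values back into (lower, upper).

    param_bounds is assumed to map a key to a finite (lower, upper) pair;
    keys without bounds are returned unchanged.
    """
    out = {}
    for k, y in p_unbounded.items():
        if k in param_bounds:
            lower, upper = param_bounds[k]
            # Inverse of y = log((x - lower) / (upper - x)).
            out[k] = lower + (upper - lower) * jax.nn.sigmoid(y)
        else:
            out[k] = y
    return out


bounds = {"ecc": (0.0, 1.0)}
out = from_unbounded_sketch({"ecc": 0.0, "amp": 3.0}, bounds)
```

An unbounded value of 0 maps to the midpoint of the interval, so `out["ecc"]` is 0.5, while `out["amp"]` is left unchanged.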
transf_to_unbounded_params(p, param_bounds)
Replicates the transformation used by PyMC and NumPyro to convert a PyTree of parameters p, some or all of which lie within the bounds given by param_bounds, to unbounded parameters using the transformation given in transf_to_unbounded.
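Below is a minimal sketch of the forward direction. It assumes the interval (log-odds) transform that PyMC and NumPyro apply to variables with finite lower and upper bounds. `to_unbounded_sketch` is a hypothetical name, and it assumes that param_bounds maps a key to a (lower, upper) pair. One-sided bounds, which would need a different (e.g. log) transform, are not covered.

```python
import jax.numpy as jnp


def to_unbounded_sketch(p, param_bounds):
    """Hypothetical helper: map bounded parameters onto the real line.

    param_bounds is assumed to map a key to a finite (lower, upper) pair with
    lower < upper; keys without bounds are returned unchanged.
    """
    out = {}
    for k, x in p.items():
        if k in param_bounds:
            lower, upper = param_bounds[k]
            # Log-odds of x's position in (lower, upper): log((x - lower) / (upper - x)).
            out[k] = jnp.log(x - lower) - jnp.log(upper - x)
        else:
            out[k] = x
    return out


bounds = {"ecc": (0.0, 1.0)}
out = to_unbounded_sketch({"ecc": 0.25, "amp": 3.0}, bounds)
```

Values near either bound are sent towards plus or minus infinity. This lets an unconstrained optimiser or sampler work in the transformed space without ever leaving the interval.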