array_to_pytree_2D
- luas.jax_convenience_fns.array_to_pytree_2D(p: Any, arr_2D: Array) → Any
Takes a 2D JAXArray (e.g. a covariance matrix) whose rows and columns are both ordered the way `jax.flatten_util.ravel_pytree` would order the input PyTree of parameters `p`, and returns a nested PyTree in which the entries of the array have been sorted according to the parameters of `p` they correspond to.
- Parameters:
  - p (PyTree) – Parameters used for log likelihood calculations. They describe the order of the array's rows and columns, namely the order in which `jax.flatten_util.ravel_pytree` flattens `p` into an array.
  - arr_2D (JAXArray) – A 2D array whose rows and columns are both sorted in the order in which `jax.flatten_util.ravel_pytree` flattens the parameter PyTree `p`.
- Returns:
A nested PyTree in which the input 2D array `arr_2D` has been rearranged according to the parameters its rows and columns correspond to.
- Return type:
PyTree
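The function relies on the ordering produced by `jax.flatten_util.ravel_pytree`: leaves are visited in `jax.tree_util` order (for a dict, sorted by key) and each leaf is raveled in row-major order. The short sketch below, using a made-up parameter dict (the names `a` and `b` are hypothetical), shows which row and column of `arr_2D` would belong to which parameter entry.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters: "b" is a length-2 vector, "a" is a scalar.
p = {"b": jnp.array([1.0, 2.0]), "a": jnp.array(3.0)}

flat, unravel = ravel_pytree(p)
# Dict leaves are flattened in sorted-key order, so "a" comes first:
# index 0 -> a, index 1 -> b[0], index 2 -> b[1]  (flat == [3., 1., 2.])

# unravel maps a flat vector of the same length (e.g. one row of arr_2D)
# back onto the structure of p: {"a": 0., "b": [1., 2.]}
row_as_pytree = unravel(jnp.arange(3.0))
```

So for a 3×3 `arr_2D` built for this `p`, entry `[0, 1]` would pair parameter `a` with `b[0]`.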
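A minimal sketch of the kind of rearrangement the return value describes, under stated assumptions: `array_to_pytree_2D_sketch` is a hypothetical helper written for illustration, not the luas implementation, and its `out[row_param][col_param]` nesting (each block shaped `row_shape + col_shape`) is an assumed layout rather than one confirmed by the luas source.

```python
import math

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def array_to_pytree_2D_sketch(p, arr_2D):
    """Hypothetical illustration (NOT the luas source): split an (N, N) array
    whose rows/columns follow ravel_pytree(p) order into nested blocks
    out[row_leaf][col_leaf], each of shape row_shape + col_shape."""
    leaves, treedef = jax.tree_util.tree_flatten(p)
    shapes = [tuple(jnp.shape(leaf)) for leaf in leaves]
    sizes = [math.prod(s) for s in shapes]  # math.prod(()) == 1 for scalars

    # Start offset of each leaf in the flattened vector (same order as ravel_pytree).
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)

    n_total = offsets[-1]
    if arr_2D.shape != (n_total, n_total):
        raise ValueError(f"expected shape {(n_total, n_total)}, got {arr_2D.shape}")

    def block(i, j):
        # Rows belonging to leaf i, columns belonging to leaf j, reshaped row-major.
        sub = arr_2D[offsets[i]:offsets[i + 1], offsets[j]:offsets[j + 1]]
        return sub.reshape(shapes[i] + shapes[j])

    n = len(leaves)
    # Inner level: one PyTree per row leaf, indexed by the column parameter.
    inner = [
        jax.tree_util.tree_unflatten(treedef, [block(i, j) for j in range(n)])
        for i in range(n)
    ]
    # Outer level: indexed by the row parameter.
    return jax.tree_util.tree_unflatten(treedef, inner)


# Usage with hypothetical parameters; ravel_pytree orders them a, b[0], b[1].
p = {"b": jnp.array([1.0, 2.0]), "a": jnp.array(3.0)}
flat, _ = ravel_pytree(p)
cov = jnp.arange(9.0).reshape(3, 3)
nested = array_to_pytree_2D_sketch(p, cov)
# nested["a"]["a"] has shape (), nested["a"]["b"] shape (2,), nested["b"]["b"] shape (2, 2)
```

This assumed nesting mirrors the layout `jax.hessian` returns for a scalar function of a PyTree, which is convenient when the 2D array is, for example, a covariance matrix computed from a flattened Hessian. Whether luas uses exactly this layout should be checked against its source.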