array_to_pytree_2D

luas.jax_convenience_fns.array_to_pytree_2D(p: Any, arr_2D: Array) → Any

Takes a 2D JAXArray (e.g. a covariance matrix) whose rows and columns are ordered in the same way that jax.flatten_util.ravel_pytree flattens the input PyTree of parameters p, and returns a nested PyTree in which the entries of the array are arranged according to the parameters they correspond to.
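To make the required ordering concrete, here is a small illustration (not taken from the luas documentation) of how jax.flatten_util.ravel_pytree flattens a parameter dictionary. The keys "a" and "b" and their values are hypothetical. A 2D array passed as arr_2D would need its rows and columns in this same order.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters: a scalar "a" and a length-2 vector "b".
p = {"b": jnp.array([1.0, 2.0]), "a": jnp.array(0.5)}

# ravel_pytree concatenates the raveled leaves in the order jax.tree_util
# flattens the PyTree. For a dict that is sorted key order, so "a" comes
# first even though "b" was inserted first.
p_flat, unravel = ravel_pytree(p)
print(p_flat.tolist())  # [0.5, 1.0, 2.0]

# Index 0 of a compatible covariance matrix therefore refers to "a",
# and indices 1 and 2 refer to the two entries of "b".
print(unravel(p_flat)["b"].tolist())  # [1.0, 2.0]
```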

Parameters:
  • p (PyTree) – Parameters used for log-likelihood calculations. Here they define the ordering of the array's rows and columns, i.e. the order in which jax.flatten_util.ravel_pytree flattens p into an array.

  • arr_2D (JAXArray) – A 2D array where the rows and columns are both sorted in the order in which jax.flatten_util.ravel_pytree flattens the parameter PyTree p.

Returns:

A nested PyTree in which the entries of the input 2D array arr_2D have been rearranged according to the parameters they correspond to.

Return type:

PyTree
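The behaviour described above can be sketched with ravel_pytree and jax.vmap. This is a hypothetical re-implementation written for illustration, not the luas source. The function name array_to_pytree_2D_sketch and the output layout are assumptions: here result[i][j] holds the block of arr_2D whose rows belong to parameter i and whose columns belong to parameter j, with shape leaf_i.shape + leaf_j.shape. The real function may arrange its output differently.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def array_to_pytree_2D_sketch(p, arr_2D):
    """Hypothetical sketch, NOT the luas implementation.

    Returns result[i][j] = block of arr_2D with rows from parameter i and
    columns from parameter j, shaped leaf_i.shape + leaf_j.shape.
    """
    _, unravel = ravel_pytree(p)

    # Unravel each column of arr_2D into the structure of p. vmap stacks the
    # results along a new leading axis indexed by column, so leaf i has
    # shape (N, *shape_i), where N is the total number of scalar parameters.
    by_row_param = jax.vmap(unravel, in_axes=1)(arr_2D)

    def split_columns(leaf):
        shape_i = leaf.shape[1:]
        # (N, *shape_i) -> (size_i, N): one row per scalar entry of param i.
        rows = leaf.reshape(leaf.shape[0], -1).T
        # Unravel each row over the columns, giving leaves (size_i, *shape_j),
        # then restore the shape of parameter i in front.
        blocks = jax.vmap(unravel)(rows)
        return jax.tree_util.tree_map(
            lambda b: b.reshape(shape_i + b.shape[1:]), blocks
        )

    # Replace every leaf i with a PyTree of blocks indexed by parameter j.
    return jax.tree_util.tree_map(split_columns, by_row_param)


# Usage with hypothetical parameters: flattened order is ["a", "b"[0], "b"[1]].
p = {"a": jnp.array(0.5), "b": jnp.array([1.0, 2.0])}
arr = jnp.arange(9.0).reshape(3, 3)
out = array_to_pytree_2D_sketch(p, arr)
print(out["a"]["b"].tolist())  # [1.0, 2.0]
print(out["b"]["a"].tolist())  # [3.0, 6.0]
print(out["b"]["b"].tolist())  # [[4.0, 5.0], [7.0, 8.0]]
```

Reusing the unravel function returned by ravel_pytree means the sketch inherits exactly the ordering ravel_pytree uses instead of recomputing index offsets by hand, and in principle it applies to nested PyTrees as well as flat dictionaries, provided all leaves share one dtype.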