pytree_to_array_2D
- luas.jax_convenience_fns.pytree_to_array_2D(p: Any, pytree_2D: Array, param_order: list | None = None) → Any
Inverse of array_to_pytree_2D. Takes a nested PyTree (e.g. describing a covariance matrix) whose keys correspond to the row and column of a 2D array, with a value defined for each parameter paired with every other parameter. Arranges the nested PyTree into this 2D array, with rows and columns ordered according to param_order, which defaults to the order in which jax.flatten_util.ravel_pytree sorts dictionary keys when forming an array.
- Parameters:
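p (PyTree) – Parameters used for log likelihood calculations. Used to determine the order of the array according to how jax.flatten_util.ravel_pytree flattens them into an array.
pytree_2D (PyTree) – A nested PyTree describing a 2D array, keyed by the pair of parameters along each row and column.
param_order (list, optional) – Order of the parameters along the rows and columns of the array. Defaults to the order in which jax.flatten_util.ravel_pytree sorts dictionary keys.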
- Returns:
A 2D array which the nested PyTree has been sorted into.
- Return type:
JAXArray
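The conversion described above can be illustrated with a minimal sketch. This is not the luas implementation: `nested_dict_to_array_2D` is a hypothetical helper, and it assumes every parameter is a scalar and that the nested PyTree is a plain dict of dicts. With no explicit order it falls back to sorted keys, which is the order JAX uses when flattening dictionaries.

```python
import jax.numpy as jnp


def nested_dict_to_array_2D(pytree_2D, param_order=None):
    """Hypothetical helper: arrange a dict-of-dicts of scalars into a 2D array."""
    # Assumption: default to sorted keys, matching how JAX orders dict keys
    # when flattening a PyTree.
    if param_order is None:
        param_order = sorted(pytree_2D.keys())

    # Entry [i, j] is the value stored under (row parameter i, column parameter j).
    return jnp.array(
        [[pytree_2D[row][col] for col in param_order] for row in param_order]
    )


# Example: a 2x2 covariance matrix keyed by parameter name.
cov = {
    "a": {"a": 1.0, "b": 0.5},
    "b": {"a": 0.5, "b": 4.0},
}

cov_default = nested_dict_to_array_2D(cov)                # rows/columns ordered a, b
cov_swapped = nested_dict_to_array_2D(cov, ["b", "a"])    # rows/columns ordered b, a
```

Passing an explicit order permutes the rows and columns together, so a symmetric input stays symmetric.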