pytree_to_array_2D


luas.jax_convenience_fns.pytree_to_array_2D(p: Any, pytree_2D: Array, param_order: list | None = None) → Any

Inverse of array_to_pytree_2D. Takes a nested PyTree (e.g. one describing a covariance matrix) whose nested keys give the row and column parameters of a 2D array, so that a value is defined for every pair of parameters. The nested PyTree is sorted into this 2D array according to param_order, which defaults to the order in which jax.flatten_util.ravel_pytree sorts dictionary keys when flattening the parameters into an array.
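The default ordering can be checked directly: ravel_pytree flattens a dictionary in sorted-key order, so each row and column of the 2D array maps to one entry of the flattened parameter vector. A minimal example, using hypothetical parameters a (a scalar) and b (a length-2 array):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters, deliberately written with "b" before "a".
p = {"b": jnp.array([1.0, 2.0]), "a": jnp.array(0.5)}

# ravel_pytree visits dict keys in sorted order, so "a" comes first:
# flat holds 0.5 (from a), then 1.0 and 2.0 (from b).
flat, unravel = ravel_pytree(p)

# unravel maps a flat vector back to the original dict structure.
restored = unravel(flat)
```

In the resulting 2D array, index 0 would therefore belong to a, and indices 1 and 2 to b.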

Parameters:
  • p (PyTree) – Parameters used for the log likelihood calculations. Determines the ordering of the array's rows and columns, following the order in which jax.flatten_util.ravel_pytree flattens p into an array.

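  • pytree_2D (PyTree) – A nested PyTree describing a 2D array, keyed by the pair of parameters indexing each row and column.

  • param_order (list, optional) – Order in which the parameters are placed along the rows and columns. Defaults to None, in which case the order jax.flatten_util.ravel_pytree uses for the keys of p is applied.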

Returns:

A 2D array into which the nested PyTree has been sorted.

Return type:

JAXArray
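As a rough illustration of the conversion described above, here is a minimal sketch; it is not the luas source. It assumes p is a flat dictionary of scalars or arrays and that pytree_2D[k1][k2] holds a block that can be reshaped to (size of p[k1], size of p[k2]). The function name pytree_to_array_2D_sketch and the example values are hypothetical. The blocks are stitched together with jnp.block.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def pytree_to_array_2D_sketch(p, pytree_2D, param_order=None):
    """Hypothetical sketch: assemble a nested dict of blocks into a 2D array."""
    if param_order is None:
        # JAX flattens dicts in sorted-key order, so for a flat dict of
        # parameters this matches the order used by ravel_pytree.
        param_order = sorted(p.keys())
    rows = []
    for k1 in param_order:
        n1 = jnp.size(p[k1])  # number of entries parameter k1 contributes
        row = []
        for k2 in param_order:
            n2 = jnp.size(p[k2])
            # Assumed layout: pytree_2D[k1][k2] reshapes to an (n1, n2) block.
            row.append(jnp.reshape(jnp.asarray(pytree_2D[k1][k2]), (n1, n2)))
        rows.append(row)
    # jnp.block joins the nested list of 2D blocks into one array.
    return jnp.block(rows)


# Hypothetical example: scalar "a" and length-2 array "b".
p = {"b": jnp.array([1.0, 2.0]), "a": jnp.array(0.5)}
cov = {
    "a": {"a": 1.0, "b": jnp.array([0.1, 0.2])},
    "b": {"a": jnp.array([0.1, 0.2]), "b": jnp.array([[2.0, 0.3], [0.3, 3.0]])},
}

flat, _ = ravel_pytree(p)  # entries ordered as a, b[0], b[1]
cov_arr = pytree_to_array_2D_sketch(p, cov)
# Row/column i of cov_arr corresponds to flat[i]:
# [[1.0, 0.1, 0.2],
#  [0.1, 2.0, 0.3],
#  [0.2, 0.3, 3.0]]
```

Passing param_order=["b", "a"] instead places the b block first. Building the result with jnp.block keeps each parameter's entries contiguous, mirroring how ravel_pytree concatenates the entries of each leaf.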