varying_params_wrapper

luas.jax_convenience_fns.varying_params_wrapper(p: Any, vars: list | None = None, fixed_vars: list | None = None, to_numpy: bool | None = True) → Tuple[Any, Callable]

Takes a PyTree of parameters and returns the subset of (key, value) pairs which are to be varied. This is often useful because, for example, PyMC requires that the start values given when initialising optimisation or inference include only the parameters being fit.

Also returns a function which can take the subset of parameters being varied and return the full set of parameters with the fixed parameters added back in.

Note: By default, parameter values are returned as NumPy arrays, as this is required for inference with PyMC. This can be turned off by setting to_numpy=False.
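The behaviour described above can be illustrated with a minimal sketch. It is a stand-in written for illustration, not the luas implementation: the helper name split_params, the parameter names, and the use of a flat dict as the PyTree are all assumptions.

```python
def split_params(p, vars):
    """Hypothetical stand-in for the described behaviour (not the luas source)."""
    # Keep only the (key, value) pairs that are being varied.
    p_vary = {k: v for k, v in p.items() if k in vars}
    # Remember the fixed pairs so they can be added back in later.
    p_fixed = {k: v for k, v in p.items() if k not in vars}

    def make_p(p_varied):
        # Fixed values go in first and varied values are applied last,
        # so the values being varied are never overwritten by fixed ones.
        full = dict(p_fixed)
        full.update(p_varied)
        return full

    return p_vary, make_p


# Illustrative parameter names only.
p = {"T0": 0.0, "d": 0.1, "h": 1.0, "sigma": 1e-3}
p_vary, make_p = split_params(p, vars=["T0", "d"])

# A sampler or optimiser would propose new values for the varied subset.
proposal = {"T0": 0.02, "d": 0.12}
p_full = make_p(proposal)
```

Here p_vary holds only the T0 and d entries, while p_full contains all four keys, with the proposed values for T0 and d and the original values for h and sigma.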

Parameters:
  • p (PyTree) – The full set of parameters used for likelihood calculations, potentially including both mean function parameters and hyperparameters.

  • vars (list of str, optional) – The list of key names corresponding to the parameters being varied, which are to be included in the output parameter PyTree. If specified in addition to fixed_vars, an Exception is raised.

  • fixed_vars (list of str, optional) – Alternative to vars: the list of key names corresponding to the parameters being kept fixed, which are to be excluded from the output parameter PyTree. If specified in addition to vars, an Exception is raised.

  • to_numpy (bool, optional) – If True, converts the parameter values of the input PyTree to NumPy arrays, as this is required for inference with PyMC. Defaults to True.
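
How these arguments interact can be sketched as follows. This is a hypothetical helper (select_varied) written for illustration, not the luas source. The only documented behaviours it mirrors are the Exception raised when both lists are given and the NumPy conversion. Treating "neither list given" as "vary everything" is an assumption, as are the parameter names.

```python
import numpy as np


def select_varied(p, vars=None, fixed_vars=None, to_numpy=True):
    """Hypothetical sketch of the argument handling (not the luas source)."""
    if vars is not None and fixed_vars is not None:
        # Documented behaviour: giving both lists raises an Exception.
        raise Exception("Specify either vars or fixed_vars, not both.")
    if vars is None and fixed_vars is None:
        # Assumption: with neither list given, every parameter is varied.
        vars = list(p.keys())
    elif vars is None:
        # fixed_vars lists what to exclude; everything else is varied.
        vars = [k for k in p if k not in fixed_vars]

    p_vary = {k: p[k] for k in vars}
    if to_numpy:
        # Convert each value to a NumPy array.
        p_vary = {k: np.asarray(v) for k, v in p_vary.items()}
    return p_vary


# Illustrative parameter names only.
p = {"T0": 0.0, "d": 0.1, "h": [1.0, 2.0]}
by_vars = select_varied(p, vars=["T0", "d"])
by_fixed = select_varied(p, fixed_vars=["h"])
```

Both calls select the same two keys, since listing h as fixed is the complement of listing T0 and d as varied.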

Returns:

A tuple. The first element is the input PyTree p restricted to the (key, value) pairs of the parameters to be varied. The second element is a function which takes a PyTree containing only the parameters being varied and adds back in the fixed parameters without overwriting the parameters being varied.

Return type:

(PyTree, Callable)