varying_params_wrapper
- luas.jax_convenience_fns.varying_params_wrapper(p: Any, vars: list | None = None, fixed_vars: list | None = None, to_numpy: bool | None = True) → Tuple[Any, Callable]
Takes a PyTree of parameters and returns the subset of (key, value) pairs which are to be varied. This is often useful because, for example, PyMC requires the start values used to initialise optimisation or inference to include only the parameters being fit for. Also returns a function which takes the subset of parameters being varied and returns the full set of parameters with the fixed parameters added back in.
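A minimal sketch of the idea, assuming p is a flat Python dict of parameters (real PyTrees may be nested) and using a hypothetical helper split_params that is not part of luas. The varied subset is extracted, and a closure re-inserts the fixed values while letting the varied values take priority.

```python
from typing import Any, Callable, Dict, List, Tuple


def split_params(p: Dict[str, Any], vars: List[str]) -> Tuple[Dict[str, Any], Callable]:
    """Hypothetical sketch (not the luas implementation): keep only the keys in
    `vars` and return a function that adds the fixed entries back in."""
    # Subset containing only the parameters being varied
    p_vary = {k: p[k] for k in vars}
    # Every other parameter is treated as fixed
    p_fixed = {k: v for k, v in p.items() if k not in vars}

    def make_full(p_vary_new: Dict[str, Any]) -> Dict[str, Any]:
        # Start from the fixed values, then apply the varied values on top
        # so that no parameter being varied is overwritten
        full = dict(p_fixed)
        full.update(p_vary_new)
        return full

    return p_vary, make_full


params = {"amp": 1.0, "length_scale": 0.5, "offset": 2.0}
p_vary, make_full = split_params(params, vars=["amp", "length_scale"])
# p_vary → {"amp": 1.0, "length_scale": 0.5}
full = make_full({"amp": 3.0, "length_scale": 0.7})
# full → {"offset": 2.0, "amp": 3.0, "length_scale": 0.7}
```

Returning a closure keeps the fixed values out of the optimiser's or sampler's view while still making the full parameter set available to the likelihood.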
Note: By default, parameter values are returned as NumPy arrays, as this is required for inference with PyMC. This can be turned off by setting to_numpy = False.
- Parameters:
p (PyTree) – The full set of parameters used for likelihood calculations, potentially including both mean function parameters and hyperparameters.
vars (list of str, optional) – The list of key names corresponding to the parameters being varied, which are included in the output parameter PyTree. Raises an Exception if specified in addition to fixed_vars.
fixed_vars (list of str, optional) – Alternative to vars: the parameters being kept fixed, which are excluded from the output parameter PyTree. Raises an Exception if specified in addition to vars.
to_numpy (bool, optional) – Converts parameter values from the input parameter PyTree to NumPy arrays, as this is required for inference with PyMC. Defaults to True.
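The sketch below illustrates one plausible way vars, fixed_vars and to_numpy could select and convert the varied subset. The helper select_varying is hypothetical, the ValueError stands in for whatever Exception luas actually raises, and varying every parameter when neither list is given is an assumption of this sketch.

```python
from typing import Any, Dict, List, Optional

import numpy as np


def select_varying(
    p: Dict[str, Any],
    vars: Optional[List[str]] = None,
    fixed_vars: Optional[List[str]] = None,
    to_numpy: bool = True,
) -> Dict[str, Any]:
    """Hypothetical sketch of how the arguments might pick the varied subset."""
    if vars is not None and fixed_vars is not None:
        # Specifying both lists is ambiguous, so refuse (exception type assumed)
        raise ValueError("Specify only one of vars or fixed_vars")
    if vars is None:
        # Vary every key not listed as fixed (all keys if neither list is given)
        fixed = set(fixed_vars or [])
        vars = [k for k in p if k not in fixed]
    subset = {k: p[k] for k in vars}
    if to_numpy:
        # Convert each value to a NumPy array, e.g. for use as PyMC start values
        subset = {k: np.asarray(v) for k, v in subset.items()}
    return subset


params = {"amp": 1.0, "length_scale": 0.5, "offset": 2.0}
start = select_varying(params, fixed_vars=["offset"])
# start has keys "amp" and "length_scale", each holding a NumPy array
```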
- Returns:
Returns a tuple whose first element is the input PyTree p containing only the (key, value) pairs of the parameters to be varied, and whose second element is a function which takes a PyTree containing only the parameters being varied and adds back in the fixed parameters without overwriting the parameters being varied.
- Return type:
(PyTree, Callable)