varying_params_wrapper
- luas.jax_convenience_fns.varying_params_wrapper(p: Any, vars: list | None = None, fixed_vars: list | None = None, to_numpy: bool | None = True) → Tuple[Any, Callable]
It is often useful to take a PyTree of parameters and return only the subset of (key, value) pairs that are to be varied. For example, PyMC requires the start values used to initialise optimisation or inference to include only the parameters being fit for. This function therefore also returns a function that takes the subset of parameters being varied and returns the full set of parameters with the fixed parameters added back in.
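The split-and-merge behaviour described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the luas implementation: it assumes `p` is a flat `dict` (the real function accepts a general PyTree), and the function name `split_varying_params`, the example keys, and the choice to vary every non-fixed key when neither `vars` nor `fixed_vars` is given are all hypothetical.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np


def split_varying_params(
    p: Dict[str, Any],
    vars: Optional[List[str]] = None,
    fixed_vars: Optional[List[str]] = None,
    to_numpy: bool = True,
) -> Tuple[Dict[str, Any], Callable[[Dict[str, Any]], Dict[str, Any]]]:
    # Hypothetical sketch of the documented behaviour, NOT the luas source.
    # Assumes p is a flat dict mapping parameter names to values.
    if vars is not None and fixed_vars is not None:
        raise Exception("Specify only one of vars or fixed_vars.")
    if vars is None:
        # Assumption: vary every key not listed in fixed_vars
        # (i.e. every key when neither argument is given).
        excluded = set(fixed_vars or [])
        vars = [k for k in p if k not in excluded]

    p_vary = {k: p[k] for k in vars}
    if to_numpy:
        # PyMC needs start values as NumPy arrays rather than JAX arrays.
        p_vary = {k: np.asarray(v) for k, v in p_vary.items()}

    # The fixed parameters are captured once and reused by the closure below.
    p_fixed = {k: v for k, v in p.items() if k not in p_vary}

    def make_p_full(p_vary_new: Dict[str, Any]) -> Dict[str, Any]:
        # Insert the varied values last so the fixed values never overwrite them.
        return {**p_fixed, **p_vary_new}

    return p_vary, make_p_full


# Example with hypothetical parameter names: hold "offset" fixed.
p = {"amp": 1.0, "scale": 2.0, "offset": 0.5}
p_vary, make_p_full = split_varying_params(p, fixed_vars=["offset"])
p_full = make_p_full({"amp": 1.5, "scale": 2.0})
```

Here `p_vary` contains only `amp` and `scale` (as NumPy arrays), and `p_full` restores `offset` alongside the updated varied values.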
Note: By default, parameter values are returned as NumPy arrays, as this is required for inference with PyMC. This can be turned off by setting to_numpy = False.
- Parameters:
p (PyTree) – The full set of parameters used for likelihood calculations, potentially including both mean function parameters and hyperparameters.
vars (list of str, optional) – The list of key names of the parameters being varied, which are included in the output parameter PyTree. Specifying this in addition to fixed_vars will raise an Exception.
fixed_vars (list of str, optional) – An alternative to vars: the list of key names of the parameters being kept fixed, which are excluded from the output parameter PyTree. Specifying this in addition to vars will raise an Exception.
to_numpy (bool, optional) – Whether to convert the parameter values from the input parameter PyTree to NumPy arrays, as this is required for inference with PyMC. Defaults to True.
- Returns:
A tuple whose first element is the input PyTree p containing only the (key, value) pairs of the parameters being varied, and whose second element is a function that takes an output parameter PyTree containing only the varied parameters and adds back in the fixed parameters without overwriting the varied ones.
- Return type:
(PyTree, Callable)
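A common way to use the returned (PyTree, Callable) pair is to wrap an objective written against the full parameter set, so that an optimiser or sampler only ever sees the varied parameters. The sketch below builds a stand-in for that pair by hand rather than calling luas; the toy `neg_log_like`, the `objective` wrapper, and the parameter names are hypothetical and only illustrate the pattern.

```python
import numpy as np

# Hypothetical full parameter set; "offset" is held fixed.
p = {"amp": 1.0, "scale": 2.0, "offset": 0.5}
fixed_vars = ["offset"]

# Stand-in for the (PyTree, Callable) pair returned by varying_params_wrapper.
p_vary = {k: np.asarray(v) for k, v in p.items() if k not in fixed_vars}
p_fixed = {k: p[k] for k in fixed_vars}


def make_p_full(p_vary_new):
    # Varied values are applied last, so they are never overwritten.
    return {**p_fixed, **p_vary_new}


def neg_log_like(p_full):
    # Toy objective written against the *full* parameter set.
    return (p_full["amp"] - 3.0) ** 2 + (p_full["scale"] - p_full["offset"]) ** 2


def objective(p_vary_new):
    # What an optimiser or sampler sees: only the varied parameters.
    return neg_log_like(make_p_full(p_vary_new))


start_value = objective(p_vary)
best_value = objective({"amp": 3.0, "scale": 0.5})
```

Because `make_p_full` restores the fixed values on every call, the optimiser can propose new values for `amp` and `scale` while `offset` stays at 0.5.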