Source code for luas.pymc_ext
import numpy as np
import jax
from jax.flatten_util import ravel_pytree
from .jax_convenience_fns import order_dict
import pymc as pm
from copy import deepcopy
# Parse the whole leading version component rather than just its first character,
# so that a multi-digit major version (e.g. "10.1.0") is not misread as version 1
pymc_version = int(pm.__version__.split(".")[0])

if pymc_version == 4:
    # PyMC v4 uses aesara as a backend
    import aesara.tensor as tens
    from aesara.tensor.var import TensorVariable
    from aesara.graph import Apply, Op
elif pymc_version == 5:
    # PyMC v5 uses pytensor as a backend
    import pytensor.tensor as tens
    try:
        # Newer PyTensor releases renamed the ``var`` module to ``variable``
        from pytensor.tensor.variable import TensorVariable
    except ImportError:
        # Fall back to the module path used by older PyTensor releases
        from pytensor.tensor.var import TensorVariable
    from pytensor.graph import Apply, Op
else:
    raise Exception(f"PyMC Version {pymc_version} not currently supported. Supported versions are version 4 and version 5.")
__all__ = [
"LuasPyMC",
"LuasPyMCWrapper",
]
[docs]
def LuasPyMC(name, gp = None, var_dict = None, Y = None, likelihood_fn = None, jit = True):
    """PyMC extension which can be used with a luas.GPClass.GP object for log-likelihood calculations.

    Args:
        name (str): Name of observable storing likelihood values e.g. "log_like".
        gp (object): The ``GP`` object used for log likelihood calculations. Only required
            when ``likelihood_fn`` is not given.
        var_dict (PyTree): A PyTree of parameter values for calculating the log likelihood.
            Top-level entries which are ``TensorVariable`` objects are treated as parameters
            being sampled by ``PyMC``; all other entries are treated as fixed values.
        Y (JAXArray): Data values being fit.
        likelihood_fn (Callable, optional): Can specify a different log likelihood function other than the default of
            ``GP.logP_stored``. Needs to take the same inputs and give the same order of outputs as ``GP.logP_stored``.
        jit (bool, optional): Whether to jit compile the likelihood function, as ``PyMC`` does not require the
            log likelihood to be jit compiled. Defaults to True.

    Returns:
        The value returned by ``pm.Potential(name, ...)`` wrapping the log likelihood.

    Raises:
        ValueError: If ``var_dict`` is None, or if both ``gp`` and ``likelihood_fn`` are None.

    """

    # Fail early with a clear message instead of an opaque AttributeError on None
    if var_dict is None:
        raise ValueError("var_dict must be provided to LuasPyMC.")

    # Default to using the log posterior method of the gp object which can make use of stored decompositions
    # This can save significant time if blocked Gibbs sampling or if some hyperparameters are being fixed.
    if likelihood_fn is None:
        if gp is None:
            raise ValueError("Either gp or likelihood_fn must be provided to LuasPyMC.")
        likelihood_fn = gp.logP_stored

    # PyMC requires an array of parameters while the gp object requires a PyTree/dictionary of parameter inputs
    # We therefore use the dictionary of variables to construct a make_p_dict function which can convert
    # an array of parameters into a PyTree

    # First copy the parameters which may include fixed NumPy arrays
    dict_to_build = deepcopy(var_dict)

    # Loop through each parameter and if it's a TensorVariable (i.e. a parameter being sampled by PyMC)
    # replace it with a NumPy array of the same shape. isinstance is used so subclasses of
    # TensorVariable are also recognised, and the full shape is used so that scalar and
    # multi-dimensional parameters are supported in addition to 1D vectors.
    for par, value in var_dict.items():
        if isinstance(value, TensorVariable):
            par_shape = tuple(int(dim) for dim in value.shape.eval())
            dict_to_build[par] = np.zeros(par_shape)

    # Now that we have a PyTree with the right dimensions as the input variables
    # we can use jax's ravel_pytree function to create a function which will convert
    # an array of parameters into a PyTree of parameters in a deterministic order
    # (the flattened array itself is not needed, only the unravelling function)
    _, make_p_dict = ravel_pytree(dict_to_build)

    # This is the likelihood function which can take an array of parameters and send to the
    # likelihood function as a PyTree of input parameters
    def logP_fn(p_arr, stored_values):
        return likelihood_fn(make_p_dict(p_arr), Y, stored_values)

    # Generate a likelihood function which will return the value of the log-likelihood
    # as well as the gradients of each parameter.
    # The default likelihood function gp.logP_stored also returns an auxiliary output
    # which includes the stored decomposition of the covariance matrix.
    value_and_grad_logP_fn = jax.value_and_grad(logP_fn, has_aux = True)

    if jit:
        value_and_grad_logP_fn = jax.jit(value_and_grad_logP_fn)

    # Defines the PyMC wrapper object for log-likelihood functions written in JAX
    logP_pymc = LuasPyMCWrapper(value_and_grad_logP_fn)

    # Need to sort the variables in the correct order for make_p_dict to construct the right PyTree
    # NOTE(review): presumably order_dict returns keys/values in the same order ravel_pytree uses - confirm
    par_keys_ordered, par_values_ordered = order_dict(var_dict)

    # Makes the array of parameter inputs PyMC will sample.
    # Each value is flattened first so that the concatenated vector matches the flat layout
    # expected by make_p_dict even for scalar or multi-dimensional parameters.
    p_pymc = pm.math.concatenate(
        [tens.flatten(tens.as_tensor_variable(value)) for value in par_values_ordered]
    )

    # Returns the log likelihood using the custom potential provided by PyMC
    return pm.Potential(name, logP_pymc(p_pymc))
class LuasPyMCWrapper(Op):
    """``Aesara``/``PyTensor`` ``Op`` (depending on the ``PyMC`` version) wrapping a JAX
    value-and-gradient log-likelihood function so that it can be used inside a ``PyMC`` model.

    Adapted from the ``PyMC`` tutorial 'How to wrap a JAX function for use in PyMC'
    (https://www.pymc.io/projects/examples/en/latest/howto/wrapping_jax_function.html).

    Args:
        value_and_grad_logP_fn (Callable): Called as ``fn(params, stored_values)`` and expected
            to return ``((value, stored_values), gradient)``.

    """

    # Only the first output (the log-likelihood value) is exposed when the Op is called;
    # the gradient output stays attached to the Apply node.
    default_output = 0

    def __init__(self, value_and_grad_logP_fn):
        # Function returning the log likelihood value together with its gradients
        self.v_and_g_logP = value_and_grad_logP_fn

        # Holds the stored decomposition of the covariance matrix between evaluations
        self.stored_values = {}

    def make_node(self, params):
        param_tensor = tens.as_tensor_variable(params)
        node_inputs = [param_tensor]

        # One scalar output for the function value followed by one gradient output
        # with the same type as the parameter input
        node_outputs = [tens.dscalar(), param_tensor.type()]

        return Apply(self, node_inputs, node_outputs)

    def perform(self, node, inputs, outputs):
        # A single call yields the log likelihood, the updated stored decomposition
        # of the covariance matrix and the gradients of the log likelihood
        evaluation = self.v_and_g_logP(*inputs, self.stored_values)
        (logp_value, updated_store), logp_grad = evaluation

        # Keep the decomposition around for the next evaluation
        self.stored_values = updated_store

        outputs[0][0] = np.asarray(logp_value, dtype=node.outputs[0].dtype)
        outputs[1][0] = np.asarray(logp_grad, dtype=node.outputs[1].dtype)

    def grad(self, inputs, output_gradients):
        # The Op already computes its own gradients, so apply it again to the inputs
        op_value = self(*inputs)

        # The gradient output is hidden by ``default_output = 0`` but remains reachable
        # through the Apply node that owns the returned value
        stored_gradient = op_value.owner.outputs[1]

        upstream_gradient = output_gradients[0]
        return [upstream_gradient * stored_gradient]