"""luas.numpyro_ext: NumPyro distribution wrapper for ``luas`` GP objects."""

import numpyro.distributions as dist
from numpyro.distributions.util import promote_shapes, validate_sample
from jax import lax

# Public API of this module: only the NumPyro distribution wrapper is exported.
__all__ = ["LuasNumPyro"]

class LuasNumPyro(dist.Distribution):
    """Custom ``NumPyro`` distribution which allows a ``luas.GPClass.GP`` object to be used with ``NumPyro``.

    The distribution has an empty batch shape and an event shape of
    ``(gp.N_l, gp.N_t)``. Only :meth:`log_prob` is defined here (no ``sample``
    method), so it is meant to be used for observed sites, e.g.
    ``numpyro.sample("obs", LuasNumPyro(gp=gp, var_dict=p), obs=Y)``.

    Args:
        gp (object): The ``GP`` object used for log likelihood calculations. It must
            expose ``N_l`` and ``N_t`` (used to build the event shape) and, unless
            ``likelihood_fn`` is given, a ``logP`` method.
        var_dict (PyTree): A PyTree of parameter values for calculating the log
            likelihood. It is passed as the first argument to the log likelihood
            function.
        likelihood_fn (Callable, optional): Can specify a different log likelihood
            function than the default of ``GP.logP``. It is called as
            ``likelihood_fn(var_dict, Y)``. For example, ``GP.logP_hessianable`` may
            be used, which is more numerically stable if performing hessian/second
            order derivative calculations.
        validate_args (bool, optional): Forwarded unchanged to
            ``numpyro.distributions.Distribution``.

    Raises:
        ValueError: If ``gp`` is not provided.
    """

    # NumPyro expects ``support`` to be a constraint *object* (it reads attributes
    # such as ``support.is_discrete`` and calls ``support(value)`` during sample
    # validation), not a method. The last two dims are reinterpreted as the event
    # so that the validity mask reduces to the (scalar) batch shape and matches the
    # scalar log likelihood instead of broadcasting it to an (N_l, N_t) array.
    support = dist.constraints.independent(dist.constraints.real, 2)

    def __init__(
        self,
        gp=None,
        var_dict=None,
        likelihood_fn=None,
        validate_args=None,
    ):
        # ``gp`` is required: it supplies the event shape and the default likelihood.
        if gp is None:
            raise ValueError("LuasNumPyro requires a `gp` object, got None.")

        self.gp = gp
        self.p = var_dict

        # If using NumPyro functionality which makes use of the hessian of the log likelihood
        # (i.e. numpyro.infer.autoguide.AutoLaplaceApproximation) then might be useful to set
        # likelihood_fn = gp.logP_hessianable as the default log likelihood is faster but not
        # as numerically stable for second order derivatives.
        self.logP_fn = self.gp.logP if likelihood_fn is None else likelihood_fn

        super().__init__(
            batch_shape=(),
            event_shape=(self.gp.N_l, self.gp.N_t),
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, Y):
        """Return the GP log likelihood of the data ``Y``.

        Args:
            Y (JAXArray): Observed data; presumably of shape ``(gp.N_l, gp.N_t)``
                to match the event shape — TODO confirm against ``GP.logP``.

        Returns:
            The value of ``self.logP_fn(var_dict, Y)``. When argument validation
            is enabled, ``validate_sample`` masks it using ``support``.
        """
        return self.logP_fn(self.p, Y)