import jax.numpy as jnp
from jax import vmap
from typing import Callable
from .luas_types import JAXArray, Scalar
__all__ = [
"evaluate_kernel",
"distanceL1",
"distanceL2Sq",
"exp",
"squared_exp",
"matern32",
"matern52",
"rational_quadratic",
"exp_sine_squared",
"cosine",
]
def evaluate_kernel(kernel_fn: Callable, x: JAXArray, y: JAXArray, *args) -> JAXArray:
"""Uses the ``jax.vmap`` function to efficiently build the covariance matrix from
a given kernel function.
Args:
kernel_fn (Callable): The desired kernel function to use
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
l (Scalar): Length scale
Returns:
JAXArray: The constructed covariance matrix
"""
K = vmap(lambda x1: vmap(lambda y1: kernel_fn(x1, y1, *args))(y))(x)
return K
def distanceL1(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
r"""Evaluates the L1 norm of two input vectors divided by a length scale.
.. math::
L1(x, y) = \frac{|x - y|}{L}
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
L (Scalar): Length scale
Returns:
Scalar: L1 norm between two input vectors
"""
return jnp.sum(jnp.abs(x - y)/L)
def distanceL2Sq(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
r"""Evaluates the Squared L2 norm of two input vectors divided by the length scale ``L``.
.. math::
L2^2(x, y) = \frac{|x - y|^2}{L^2}
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
L (Scalar): Length scale
Returns:
Scalar: L2 norm between two input vectors
"""
return jnp.sum(jnp.square(x - y)/L**2)
[docs]
def squared_exp(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
r"""Squared exponential kernel function, also known as the radial basis function,
used with ``luas.kernels.evaluate_kernel`` to build a covariance matrix.
.. math::
k(x, y) = \exp\Bigg( -\frac{|x - y|^2}{2L^2}\Bigg)
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
L (Scalar): Length scale
Returns:
Scalar: Covariance between two input vectors
"""
return evaluate_kernel(squared_exp_calc, x, y, L)
def squared_exp_calc(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
"""Function used by ``luas.kernels.squared_exp`` to evaluate the squared exponential kernel function.
"""
tau_sq = distanceL2Sq(x, y, L)
return jnp.exp(-0.5 * tau_sq.sum())
[docs]
def exp(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
r"""Exponential kernel function, also known as the Matern 1/2 kernel, used with ``luas.kernels.evaluate_kernel``
to build covariance matrices.
.. math::
k(x, y) = \Bigg(\frac{|x - y|}{L}\Bigg)
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
L (Scalar): Length scale
Returns:
Scalar: Covariance between two input vectors
"""
return evaluate_kernel(exp_calc, x, y, L)
def exp_calc(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
"""Function used by ``luas.kernels.exp`` to evaluate the Exponential kernel function.
"""
delta_t = distanceL1(x, y, L).sum()
return jnp.exp(-delta_t)
[docs]
def matern32(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
r"""Matern 3/2 kernel function, used with ``luas.kernels.evaluate_kernel``
to build covariance matrices.
.. math::
k(x, y) = \Bigg(1 + \sqrt{3} \frac{|x - y|}{L}\Bigg) \exp\Bigg( -\sqrt{3} \frac{|x - y|}{L}\Bigg)
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
L (Scalar): Length scale
Returns:
Scalar: Covariance between two input vectors
"""
return evaluate_kernel(matern32_calc, x, y, L)
def matern32_calc(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
"""Function used by matern32 to evaluate the Matern 3/2 kernel function.
"""
delta_t = jnp.sqrt(3)*distanceL1(x, y, L).sum()
return (1+delta_t)*jnp.exp(-delta_t)
[docs]
def matern52(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
r"""Matern 5/2 kernel function, used with ``luas.kernels.evaluate_kernel``
to build covariance matrices.
.. math::
k(x, y) = \Bigg(1 + \sqrt{5} \frac{|x - y|}{L} + \frac{5|x - y|^2}{3L^2}\Bigg) \exp\Bigg( -\sqrt{5}\frac{|x - y|}{L}\Bigg)
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
L (Scalar): Length scale
Returns:
Scalar: Covariance between two input vectors
"""
return evaluate_kernel(matern52_calc, x, y, L)
def matern52_calc(x: JAXArray, y: JAXArray, L: Scalar) -> JAXArray:
"""Function used by matern52 to evaluate the Matern 5/2 kernel function.
"""
delta_t = jnp.sqrt(5)*distanceL1(x, y, L).sum()
return (1+delta_t+jnp.square(delta_t)/3)*jnp.exp(-delta_t)
[docs]
def rational_quadratic(x: JAXArray, y: JAXArray, L: Scalar, alpha: Scalar) -> JAXArray:
r"""Rational quadratic kernel function, used with ``luas.kernels.evaluate_kernel``
to build covariance matrices.
.. math::
k(x, y) = \Bigg(1 + \frac{|x - y|^2}{2 \alpha L^2}\Bigg)^{-\alpha}
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
L (Scalar): Length scale
alpha (Scalar): Scale mixture parameter
Returns:
Scalar: Covariance between two input vectors
"""
return evaluate_kernel(rational_quadratic_calc, x, y, L, alpha)
def rational_quadratic_calc(x: JAXArray, y: JAXArray, L: Scalar, alpha: Scalar) -> JAXArray:
"""Function used by rational_quadratic to evaluate the rational quadratic kernel function.
"""
tau_sq = distanceL2Sq(x, y, L).sum()
return (1. + 0.5*tau_sq/alpha)**(-alpha)
[docs]
def exp_sine_squared(x: JAXArray, y: JAXArray, L: Scalar, P: Scalar) -> JAXArray:
r"""Exponential sine squared kernel, used with evaluate_kernel
to build covariance matrices which have periodic covariance.
.. math::
k(x, y) = \exp\Bigg( -\frac{2 \sin^2(\pi(x - y)/P)}{L^2}\Bigg)
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
L (Scalar): Length scale
P (Scalar): Period
Returns:
JAXArray: Covariance between two input vectors
"""
return evaluate_kernel(exp_sine_squared_calc, x, y, L, P)
def exp_sine_squared_calc(x: JAXArray, y: JAXArray, L: Scalar, P: Scalar) -> JAXArray:
"""Function used by exp_sine_squared to evaluate the exponential sine squared kernel function.
"""
tau_sq = (jnp.sum(jnp.square(jnp.sin(jnp.pi*(x - y)/P)/L))).sum()
return jnp.exp(-2.0 * tau_sq)
[docs]
def cosine(x: JAXArray, y: JAXArray, P: Scalar) -> JAXArray:
r"""Cosine kernel, used with ``luas.kernels.evaluate_kernel``
to build covariance matrices which have periodic covariance.
.. math::
k(x, y) = \cos\Bigg(\frac{2\pi|x - y|}{P}\Bigg)
Args:
x (JAXArray): Input vector 1
y (JAXArray): Input vector 2
P (Scalar): Period
Returns:
JAXArray: Covariance between two input vectors
"""
return evaluate_kernel(cosine_calc, x, y, P)
def cosine_calc(x: JAXArray, y: JAXArray, P: Scalar) -> JAXArray:
"""Function used by cosine to evaluate the cosine kernel function.
"""
delta_t = distanceL1(x, y, P).sum()
return jnp.cos(2*jnp.pi*delta_t)