From d9407b6445ae8b7aeaaabbdf36820d32f04dc51c Mon Sep 17 00:00:00 2001 From: anjawa Date: Fri, 10 Apr 2026 11:34:58 +0200 Subject: [PATCH] GibbsKernel added Unittests added + docs updated --- docs/source/kernels.rst | 7 +++ gpytorch/kernels/__init__.py | 2 + gpytorch/kernels/gibbs_kernel.py | 82 +++++++++++++++++++++++++++++++ test/kernels/test_gibbs_kernel.py | 80 ++++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+) create mode 100644 gpytorch/kernels/gibbs_kernel.py create mode 100644 test/kernels/test_gibbs_kernel.py diff --git a/docs/source/kernels.rst b/docs/source/kernels.rst index f16495745..6d73e9789 100644 --- a/docs/source/kernels.rst +++ b/docs/source/kernels.rst @@ -43,6 +43,13 @@ Standard Kernels :members: +:hidden:`GibbsKernel` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: GibbsKernel + :members: + + :hidden:`LinearKernel` ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py index 464ead375..8e031341d 100644 --- a/gpytorch/kernels/__init__.py +++ b/gpytorch/kernels/__init__.py @@ -10,6 +10,7 @@ from .cylindrical_kernel import CylindricalKernel from .distributional_input_kernel import DistributionalInputKernel from .gaussian_symmetrized_kl_kernel import GaussianSymmetrizedKLKernel +from .gibbs_kernel import GibbsKernel from .grid_interpolation_kernel import GridInterpolationKernel from .grid_kernel import GridKernel from .hamming_kernel import HammingIMQKernel @@ -49,6 +50,7 @@ "CosineKernel", "DistributionalInputKernel", "GaussianSymmetrizedKLKernel", + "GibbsKernel", "GridKernel", "GridInterpolationKernel", "HammingIMQKernel", diff --git a/gpytorch/kernels/gibbs_kernel.py b/gpytorch/kernels/gibbs_kernel.py new file mode 100644 index 000000000..bbddf1754 --- /dev/null +++ b/gpytorch/kernels/gibbs_kernel.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +from copy import deepcopy + +import torch +from torch import nn + +from .kernel import Kernel + + +class GibbsKernel(Kernel): + r""" + Gibbs kernel with input-dependent lengthscale :math:`\ell(x)` (Gibbs, 1997) + + .. math:: + k(x, x') = \sqrt{\frac{2\ell(x)\ell(x')}{\ell(x)^2 + \ell(x')^2}} + \exp\left(-\frac{(x-x')^2}{\ell(x)^2 + \ell(x')^2}\right) + + :param lengthscale_fn: A callable torch.nn.Module mapping inputs to + positive lengthscales. Must output tensors of shape (... x N x 1) + for input of shape (... x N x D) + :type lengthscale_fn: torch.nn.Module + + Example:: + + class LengthscaleMLP(torch.nn.Module): + def __init__(self, in_dim=1, hidden=32): + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Linear(in_dim, hidden), + torch.nn.ReLU(), + torch.nn.Linear(hidden, 1), + torch.nn.Softplus(), + ) + + def forward(self, x): + return self.net(x) + + kernel = GibbsKernel(lengthscale_fn=LengthscaleMLP(in_dim=1)) + """ + + is_stationary = False + has_lengthscale = False + + def __init__(self, lengthscale_fn: nn.Module, **kwargs): + if kwargs.get("ard_num_dims") is not None: + raise NotImplementedError("GibbsKernel does not support ARD.") + super().__init__(**kwargs) + self.lengthscale_fn = lengthscale_fn + + # Update batch_shape explicitly: + # Base class derives new batch_shape from parameters, + # but GibbsKernel has none + def __getitem__(self, index): + if len(self.batch_shape) == 0: + return self + new_kernel = deepcopy(self) + index = index if isinstance(index, tuple) else (index,) + new_kernel.batch_shape = torch.empty(self.batch_shape)[index].shape + return new_kernel + + def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params): + x1_eq_x2 = torch.equal(x1, x2) + + l1 = self.lengthscale_fn(x1) + if l1.shape[-1] != 1: + raise ValueError(f"lengthscale_fn must return shape (..., k, 1), got (..., k, {l1.shape[-1]})") + l2 = l1 if x1_eq_x2 else self.lengthscale_fn(x2) + + dist_sq = self.covar_dist(x1, x2, square_dist=True, diag=diag, **params) + + if diag: + S = (l1.pow(2) + l2.pow(2)).squeeze(-1) + prod = (l1 * l2).squeeze(-1) + else: + S = l1.pow(2) + l2.pow(2).transpose(-2, -1) + prod = l1 * l2.transpose(-2, -1) + + prefactor = (2.0 * prod / S).sqrt() + return prefactor * (-dist_sq / S).exp() diff --git a/test/kernels/test_gibbs_kernel.py b/test/kernels/test_gibbs_kernel.py new file mode 100644 index 000000000..d1cc80716 --- /dev/null +++ b/test/kernels/test_gibbs_kernel.py @@ -0,0 +1,80 @@ +import unittest + +import torch +from torch import nn + +from gpytorch.kernels import GibbsKernel, RBFKernel +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase + + +class ConstantLengthscale(nn.Module): + r"""Constant :math:`\ell(x) = \exp(c)`""" + + def __init__(self, value: float = 1.0): + super().__init__() + self.log_value = nn.Parameter(torch.tensor(value).log()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.log_value.exp().expand(*x.shape[:-1], 1) + + +class MLPLengthscale(nn.Module): + """Small MLP, non-constant lengthscale function.""" + + def __init__(self, in_dim: int = 1, hidden: int = 16): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, hidden), + nn.ReLU(), + nn.Linear(hidden, 1), + ) + nn.init.normal_(self.net[-1].weight, std=0.01) + nn.init.zeros_(self.net[-1].bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.exp(self.net(x)) + + +class TestGibbsKernel(BaseKernelTestCase, unittest.TestCase): + def create_data_no_batch(self): + return torch.randn(50, 10) + + def create_kernel_no_ard(self, **kwargs): + return GibbsKernel(ConstantLengthscale(), **kwargs) + + def setUp(self): + self.lfn = ConstantLengthscale(value=1.0) + self.kernel = GibbsKernel(self.lfn) + + def test_diagonal_is_one(self): + r""":math:`k(x, x) = 1` for all :math:`x`.""" + for lfn in [ConstantLengthscale(), MLPLengthscale(in_dim=2)]: + kernel = GibbsKernel(lfn) + x = torch.randn(20, 2) + K = kernel(x).to_dense() + self.assertTrue(torch.allclose(K.diagonal(), torch.ones(20), atol=1e-5)) + + def test_reduces_to_rbf_with_constant_lengthscale(self): + r"""With constant :math:`\ell(x) = \ell`, Gibbs reduces to RBF.""" + l = 1.5 + kernel_gibbs = GibbsKernel(ConstantLengthscale(value=l)) + kernel_rbf = RBFKernel() + kernel_rbf.lengthscale = l + + x1 = torch.randn(8, 1) + x2 = torch.randn(6, 1) + + K_gibbs = kernel_gibbs(x1, x2).to_dense() + K_rbf = kernel_rbf(x1, x2).to_dense() + + self.assertTrue(torch.allclose(K_gibbs, K_rbf, atol=1e-5)) + + def test_gradient_flows_to_lengthscale_fn(self): + """Gradients propagate through lengthscale_fn.""" + kernel = GibbsKernel(MLPLengthscale(in_dim=2)) + x = torch.randn(8, 2) + kernel(x).to_dense().sum().backward() + + for name, param in kernel.lengthscale_fn.named_parameters(): + self.assertIsNotNone(param.grad, f"No gradient for {name}") + self.assertFalse(torch.all(param.grad == 0), f"Zero gradient for {name}")