Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ Standard Kernels
:members:


:hidden:`GibbsKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GibbsKernel
:members:


:hidden:`LinearKernel`
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .cylindrical_kernel import CylindricalKernel
from .distributional_input_kernel import DistributionalInputKernel
from .gaussian_symmetrized_kl_kernel import GaussianSymmetrizedKLKernel
from .gibbs_kernel import GibbsKernel
from .grid_interpolation_kernel import GridInterpolationKernel
from .grid_kernel import GridKernel
from .hamming_kernel import HammingIMQKernel
Expand Down Expand Up @@ -49,6 +50,7 @@
"CosineKernel",
"DistributionalInputKernel",
"GaussianSymmetrizedKLKernel",
"GibbsKernel",
"GridKernel",
"GridInterpolationKernel",
"HammingIMQKernel",
Expand Down
82 changes: 82 additions & 0 deletions gpytorch/kernels/gibbs_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python3

from __future__ import annotations

from copy import deepcopy

import torch
from torch import nn

from .kernel import Kernel


class GibbsKernel(Kernel):
r"""
Gibbs kernel with input-dependent lengthscale :math:`\ell(x)` (Gibbs, 1997)

.. math::
k(x, x') = \sqrt{\frac{2\ell(x)\ell(x')}{\ell(x)^2 + \ell(x')^2}}
\exp\left(-\frac{(x-x')^2}{\ell(x)^2 + \ell(x')^2}\right)

:param lengthscale_fn: A callable torch.nn.Module mapping inputs to
positive lengthscales. Must output tensors of shape (... x N x 1)
for input of shape (... x N x D)
:type lengthscale_fn: torch.nn.Module

Example::

class LengthscaleMLP(torch.nn.Module):
def __init__(self, in_dim=1, hidden=32):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Linear(in_dim, hidden),
torch.nn.ReLU(),
torch.nn.Linear(hidden, 1),
torch.nn.Softplus(),
)

def forward(self, x):
return self.net(x)

kernel = GibbsKernel(lengthscale_fn=LengthscaleMLP(in_dim=1))
"""

is_stationary = False
has_lengthscale = False

def __init__(self, lengthscale_fn: nn.Module, **kwargs):
if kwargs.get("ard_num_dims") is not None:
raise NotImplementedError("GibbsKernel does not support ARD.")
super().__init__(**kwargs)
self.lengthscale_fn = lengthscale_fn

# Update batch_shape explicitly:
# Base class derives new batch_shape from parameters,
# but GibbsKernel has none
def __getitem__(self, index):
if len(self.batch_shape) == 0:
return self
new_kernel = deepcopy(self)
index = index if isinstance(index, tuple) else (index,)
new_kernel.batch_shape = torch.empty(self.batch_shape)[index].shape
return new_kernel

def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params):
x1_eq_x2 = torch.equal(x1, x2)

l1 = self.lengthscale_fn(x1)
if l1.shape[-1] != 1:
raise ValueError(f"lengthscale_fn must return shape (..., k, 1), got (..., k, {l1.shape[-1]})")
l2 = l1 if x1_eq_x2 else self.lengthscale_fn(x2)

dist_sq = self.covar_dist(x1, x2, square_dist=True, diag=diag, **params)

if diag:
S = (l1.pow(2) + l2.pow(2)).squeeze(-1)
prod = (l1 * l2).squeeze(-1)
else:
S = l1.pow(2) + l2.pow(2).transpose(-2, -1)
prod = l1 * l2.transpose(-2, -1)

prefactor = (2.0 * prod / S).sqrt()
return prefactor * (-dist_sq / S).exp()
80 changes: 80 additions & 0 deletions test/kernels/test_gibbs_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import unittest

import torch
from torch import nn

from gpytorch.kernels import GibbsKernel, RBFKernel
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class ConstantLengthscale(nn.Module):
r"""Constant :math:`\ell(x) = \exp(c)`"""

def __init__(self, value: float = 1.0):
super().__init__()
self.log_value = nn.Parameter(torch.tensor(value).log())

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.log_value.exp().expand(*x.shape[:-1], 1)


class MLPLengthscale(nn.Module):
"""Small MLP, non-constant lengthscale function."""

def __init__(self, in_dim: int = 1, hidden: int = 16):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden),
nn.ReLU(),
nn.Linear(hidden, 1),
)
nn.init.normal_(self.net[-1].weight, std=0.01)
nn.init.zeros_(self.net[-1].bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.exp(self.net(x))


class TestGibbsKernel(BaseKernelTestCase, unittest.TestCase):
def create_data_no_batch(self):
return torch.randn(50, 10)

def create_kernel_no_ard(self, **kwargs):
return GibbsKernel(ConstantLengthscale(), **kwargs)

def setUp(self):
self.lfn = ConstantLengthscale(value=1.0)
self.kernel = GibbsKernel(self.lfn)

def test_diagonal_is_one(self):
r""":math:`k(x, x) = 1` for all :math:`x`."""
for lfn in [ConstantLengthscale(), MLPLengthscale(in_dim=2)]:
kernel = GibbsKernel(lfn)
x = torch.randn(20, 2)
K = kernel(x).to_dense()
self.assertTrue(torch.allclose(K.diagonal(), torch.ones(20), atol=1e-5))

def test_reduces_to_rbf_with_constant_lengthscale(self):
r"""With constant :math:`\ell(x) = \ell`, Gibbs reduces to RBF."""
l = 1.5
kernel_gibbs = GibbsKernel(ConstantLengthscale(value=l))
kernel_rbf = RBFKernel()
kernel_rbf.lengthscale = l

x1 = torch.randn(8, 1)
x2 = torch.randn(6, 1)

K_gibbs = kernel_gibbs(x1, x2).to_dense()
K_rbf = kernel_rbf(x1, x2).to_dense()

self.assertTrue(torch.allclose(K_gibbs, K_rbf, atol=1e-5))

def test_gradient_flows_to_lengthscale_fn(self):
"""Gradients propagate through lengthscale_fn."""
kernel = GibbsKernel(MLPLengthscale(in_dim=2))
x = torch.randn(8, 2)
kernel(x).to_dense().sum().backward()

for name, param in kernel.lengthscale_fn.named_parameters():
self.assertIsNotNone(param.grad, f"No gradient for {name}")
self.assertFalse(torch.all(param.grad == 0), f"Zero gradient for {name}")
Loading