Skip to content

Add SphericalLinear kernel#2742

Merged
gpleiss merged 4 commits into
cornellius-gp:mainfrom
colmont:spherical-linear
Jun 8, 2026
Merged

Add SphericalLinear kernel#2742
gpleiss merged 4 commits into
cornellius-gp:mainfrom
colmont:spherical-linear

Conversation

@colmont

@colmont colmont commented Mar 30, 2026

Copy link
Copy Markdown
Contributor

No description provided.

@gpleiss gpleiss left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally good; a few changes

stereographic projection onto a unit sphere, and :math:`(b_0, b_1)` are learned
mixture weights (via softmax, so :math:`b_0 + b_1 = 1`).

This kernel was proposed in `We Still Don't Understand High-Dimensional Bayesian Optimization`.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: make this into a link, rather than having the arxiv link separate.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

self.bounds = bounds

# Learned mixture coefficients: softmax([raw_coeffs]) -> [constant, linear]
self.raw_coeffs = nn.Parameter(torch.zeros(*self.batch_shape, 2))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have getters and setters for the non-raw values, so that they can be initialized appropriately (e.g. self.coeffs = <blah>)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

self.raw_coeffs = nn.Parameter(torch.zeros(*self.batch_shape, 2))

# Global lengthscale: sigmoid(raw_glob_ls) * max_sq_norm
self.raw_glob_ls = nn.Parameter(torch.zeros(*self.batch_shape, 1))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

ard_num_dims: int | None = None,
lengthscale_prior: Prior | None = None,
lengthscale_constraint: Interval | None = None,
normalize_lengthscale: bool = False,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should set it to true? And we can add a comment that it was set to False in the original paper.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

return torch.ones(x1.shape[:-1], dtype=x1.dtype, device=x1.device)

if self.normalize_lengthscale: # Enforce L2 norm = 1
lengthscale = torch.softmax(self.lengthscale, dim=-1).sqrt()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it would be better if we just divided the lengthscale by its norm, so as not to distort the geometry of the lengthscale.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried this on a small example (SVM benchmark in high-dim BO): dividing the lengthscale by its norm is about 10x slower than using the softmax, and performance is not necessarily better. This was just a small local test with a few iterations (20-30); I have not yet double-checked this on the cluster for more problems/seeds/iterations. What do you think, worth investigating more or do you already have a strong preference on which reparameterization to go for?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've done some more rigorous benchmarking (the 6 HDBO problems from the paper for 10 seeds each). Softmax constraint is definitely faster than dividing by the norm, and doesn't lead to worse performance (maybe even slightly better?). I would propose keeping softmax constraint and having it on by default.

projected = project_onto_unit_sphere(x)
self.assertEqual(projected.shape, torch.Size([10, 6]))
norms = projected.norm(dim=-1)
self.assertAllClose(norms, torch.ones(10), rtol=1e-5, atol=1e-5)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test that inputs that are on the unit sphere get an effectively identity mapping, to ensure we did the inverse stereographic projection correctly?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

kernel = SphericalLinearKernel(bounds=UNIT_BOUNDS_3D, lengthscale_prior=NormalPrior(0, 1))
pickle.loads(pickle.dumps(kernel))

def test_consistency_square_vs_rectangular(self):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got rid of it

def test_pickle_with_prior(self):
"""Kernel with prior should survive pickle round-trip."""
kernel = SphericalLinearKernel(bounds=UNIT_BOUNDS_3D, lengthscale_prior=NormalPrior(0, 1))
pickle.loads(pickle.dumps(kernel))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got rid of it

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
kernel = SphericalLinearKernel(bounds=bounds)
loaded = pickle.loads(pickle.dumps(kernel))
self.assertAllClose(loaded.bounds, bounds)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got rid of it

"""Should accept valid priors and reject invalid ones."""
SphericalLinearKernel(bounds=UNIT_BOUNDS_3D, lengthscale_prior=None)
SphericalLinearKernel(bounds=UNIT_BOUNDS_3D, lengthscale_prior=NormalPrior(0, 1))
self.assertRaises(TypeError, SphericalLinearKernel, UNIT_BOUNDS_3D, lengthscale_prior=1)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got rid of it

@colmont

colmont commented Apr 25, 2026

Copy link
Copy Markdown
Contributor Author

I tried to address all of your comments as best as possible, and left a longer reply for the reparameterization question. When playing around with this model for the mujoco-humanoid problem in high-dim BO, my jobs instantly OOMed... even before acquiring the first point. I believe I found a bug in GPyTorch? I think when d >= n, the linear kernels should be routed to DefaultPredictionStrategy instead of LinearPredictionStrategy. I tried to fix this as best as I could in a second commit, I hope it's kinda clear!

@gpleiss gpleiss enabled auto-merge (squash) June 8, 2026 20:05
@gpleiss gpleiss merged commit 7c4ec80 into cornellius-gp:main Jun 8, 2026
7 checks passed
@gpleiss

gpleiss commented Jun 8, 2026

Copy link
Copy Markdown
Member

Thanks @colmont !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants