From 863b80415a10734cc0b8481cbcb71d8713a6b62d Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 19 Mar 2026 15:26:55 -0400 Subject: [PATCH 1/2] Fix memory leak in DefaultPredictionStrategy cache hooks (#2631) Root cause ---------- When `detach_test_caches` is False (as used by BoTorch via `propagate_grads(True)`), the `_mean_cache` and `_exact_predictive_covar_inv_quad_form_cache` properties register a `clear_cache_hook` on the cached tensor's `grad_fn`: wrapper = functools.partial(clear_cache_hook, self) mean_cache.grad_fn.register_hook(wrapper) This creates a reference cycle: prediction_strategy -> _memoize_cache -> cached tensor -> grad_fn (C++ object) -> hook closure -> prediction_strategy Because PyTorch's `grad_fn` is a C++ object, Python's cycle garbage collector cannot traverse the C++/Python boundary to detect the cycle. The entire chain becomes uncollectable. For fantasy models, the cached tensor's computation graph holds references back to the parent model's caches, so each iteration adds another uncollectable chain, causing memory to grow indefinitely until OOM. Fix --- Added `register_cache_clear_hook(tsr, module)` to `gpytorch/utils/memoize.py`. This helper uses `weakref.ref(module)` when registering the backward hook on `tsr.grad_fn`, breaking the reference cycle. When no external strong references to the prediction strategy remain, it and its caches are garbage-collected normally, and the hook becomes a no-op. Updated all three call sites to use the new helper: - `DefaultPredictionStrategy._exact_predictive_covar_inv_quad_form_cache` - `DefaultPredictionStrategy._mean_cache` - `_variational_strategy._add_cache_hook` Validation ---------- Repro script (200 iterations of fantasy model creation + evaluation with `detach_test_caches(False)`): import torch, tracemalloc from gpytorch import settings as gpt_settings from gpytorch.distributions import MultivariateNormal from gpytorch.kernels import RBFKernel from gpytorch.likelihoods import GaussianLikelihood from gpytorch.means import ConstantMean from gpytorch.models import ExactGP d = 10 class SimpleGP(ExactGP): def __init__(self, train_inputs, train_targets): super().__init__(train_inputs, train_targets, GaussianLikelihood()) self.mean_module = ConstantMean() self.covar_module = RBFKernel() def forward(self, x): return MultivariateNormal(self.mean_module(x), self.covar_module(x)) gp = SimpleGP( train_inputs=torch.rand(256, d, dtype=torch.double), train_targets=torch.rand(256, dtype=torch.double), ).eval() gp(torch.rand(5, d, dtype=torch.double)) X = torch.rand(128, 5, d, dtype=torch.double) Y = torch.rand(128, 5, dtype=torch.double) tracemalloc.start() for i in range(200): fantasy_model = gp.get_fantasy_model(inputs=X, targets=Y).eval() with gpt_settings.detach_test_caches(False): fantasy_model(torch.rand(32, d, dtype=torch.double)) if (i + 1) % 50 == 0: current, peak = tracemalloc.get_traced_memory() print(f"Iter {i+1}: current={current/1024/1024:.1f}MB, " f"peak={peak/1024/1024:.1f}MB") Results before fix: memory grows ~50MB+ per 50 iterations, OOM by ~200. Results after fix: memory stays flat at ~1-2 MB across all 200 iterations. Gradient propagation through fantasy model predictions with `detach_test_caches(False)` was also verified to still work correctly. All 59 existing tests in test_exact_gp.py and test_derivative_gp_fantasy.py pass. --- .../models/exact_prediction_strategies.py | 12 +++--------- gpytorch/utils/memoize.py | 19 +++++++++++++++++++ gpytorch/variational/_variational_strategy.py | 7 ++----- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 30d43ebe2..5d8c39bfa 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -25,7 +25,7 @@ from .. import settings from ..distributions import MultitaskMultivariateNormal from ..lazy import LazyEvaluatedKernelTensor -from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache +from ..utils.memoize import add_to_cache, cached, pop_from_cache, register_cache_clear_hook def prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood): @@ -108,10 +108,7 @@ def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root if settings.detach_test_caches.on(): res = res.detach() - if res.grad_fn is not None: - wrapper = functools.partial(clear_cache_hook, self) - functools.update_wrapper(wrapper, clear_cache_hook) - res.grad_fn.register_hook(wrapper) + register_cache_clear_hook(res, self) return res @@ -313,10 +310,7 @@ def _mean_cache(self, nan_policy: str) -> Tensor: if settings.detach_test_caches.on(): mean_cache = mean_cache.detach() - if mean_cache.grad_fn is not None: - wrapper = functools.partial(clear_cache_hook, self) - functools.update_wrapper(wrapper, clear_cache_hook) - mean_cache.grad_fn.register_hook(wrapper) + register_cache_clear_hook(mean_cache, self) return mean_cache diff --git a/gpytorch/utils/memoize.py b/gpytorch/utils/memoize.py index 1a7c1d890..10a1a4dc4 100644 --- a/gpytorch/utils/memoize.py +++ b/gpytorch/utils/memoize.py @@ -4,6 +4,7 @@ import functools import pickle +import weakref from .errors import CachingError @@ -46,6 +47,24 @@ def clear_cache_hook(module, *args, **kwargs): module._memoize_cache = {} +def register_cache_clear_hook(tsr, module): + """Register a backward hook on tsr's grad_fn that clears module's cache. + + Uses a weak reference to module to avoid creating an uncollectable + reference cycle through the C++ grad_fn object (which Python's cycle + GC cannot see through). + """ + if tsr.grad_fn is not None: + weak_module = weakref.ref(module) + + def hook(*args, **kwargs): + obj = weak_module() + if obj is not None: + obj._memoize_cache = {} + + tsr.grad_fn.register_hook(hook) + + def _cached(method=None, name=None): """A decorator allowing for specifying the name of a cache, allowing it to be modified elsewhere. This variant honors the calling args to the decorated function. diff --git a/gpytorch/variational/_variational_strategy.py b/gpytorch/variational/_variational_strategy.py index f157b59f6..f2a859b85 100644 --- a/gpytorch/variational/_variational_strategy.py +++ b/gpytorch/variational/_variational_strategy.py @@ -18,7 +18,7 @@ from ..models import ApproximateGP, ExactGP from ..models.exact_prediction_strategies import DefaultPredictionStrategy from ..module import Module -from ..utils.memoize import add_to_cache, cached, clear_cache_hook +from ..utils.memoize import add_to_cache, cached, clear_cache_hook, register_cache_clear_hook from . import _VariationalDistribution @@ -42,10 +42,7 @@ def forward(self, x: Tensor, **kwargs) -> MultivariateNormal: def _add_cache_hook(tsr: Tensor, pred_strat: DefaultPredictionStrategy) -> Tensor: - if tsr.grad_fn is not None: - wrapper = functools.partial(clear_cache_hook, pred_strat) - functools.update_wrapper(wrapper, clear_cache_hook) - tsr.grad_fn.register_hook(wrapper) + register_cache_clear_hook(tsr, pred_strat) return tsr From 2a38f8c4a852a41438796ce428c57db0bc4e0b51 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 19 Mar 2026 15:51:19 -0400 Subject: [PATCH 2/2] Remove unused functools imports --- gpytorch/models/exact_prediction_strategies.py | 1 - gpytorch/variational/_variational_strategy.py | 1 - 2 files changed, 2 deletions(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 5d8c39bfa..73c83bc20 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import string import warnings diff --git a/gpytorch/variational/_variational_strategy.py b/gpytorch/variational/_variational_strategy.py index f2a859b85..cef6cae29 100644 --- a/gpytorch/variational/_variational_strategy.py +++ b/gpytorch/variational/_variational_strategy.py @@ -2,7 +2,6 @@ from __future__ import annotations -import functools from abc import ABC, abstractproperty from copy import deepcopy