diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py
index 30d43ebe2..73c83bc20 100644
--- a/gpytorch/models/exact_prediction_strategies.py
+++ b/gpytorch/models/exact_prediction_strategies.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 
-import functools
 import string
 import warnings
 
@@ -25,7 +24,7 @@
 from .. import settings
 from ..distributions import MultitaskMultivariateNormal
 from ..lazy import LazyEvaluatedKernelTensor
-from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache
+from ..utils.memoize import add_to_cache, cached, pop_from_cache, register_cache_clear_hook
 
 
 def prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood):
@@ -108,10 +107,7 @@ def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root
         if settings.detach_test_caches.on():
             res = res.detach()
 
-        if res.grad_fn is not None:
-            wrapper = functools.partial(clear_cache_hook, self)
-            functools.update_wrapper(wrapper, clear_cache_hook)
-            res.grad_fn.register_hook(wrapper)
+        register_cache_clear_hook(res, self)
 
         return res
 
@@ -313,10 +309,7 @@ def _mean_cache(self, nan_policy: str) -> Tensor:
         if settings.detach_test_caches.on():
             mean_cache = mean_cache.detach()
 
-        if mean_cache.grad_fn is not None:
-            wrapper = functools.partial(clear_cache_hook, self)
-            functools.update_wrapper(wrapper, clear_cache_hook)
-            mean_cache.grad_fn.register_hook(wrapper)
+        register_cache_clear_hook(mean_cache, self)
 
         return mean_cache
 
diff --git a/gpytorch/utils/memoize.py b/gpytorch/utils/memoize.py
index 1a7c1d890..10a1a4dc4 100644
--- a/gpytorch/utils/memoize.py
+++ b/gpytorch/utils/memoize.py
@@ -4,6 +4,7 @@
 
 import functools
 import pickle
+import weakref
 
 from .errors import CachingError
 
@@ -46,6 +47,24 @@ def clear_cache_hook(module, *args, **kwargs):
     module._memoize_cache = {}
 
 
+def register_cache_clear_hook(tsr, module):
+    """Register a backward hook on tsr's grad_fn that clears module's cache.
+
+    Uses a weak reference to module to avoid creating an uncollectable
+    reference cycle through the C++ grad_fn object (which Python's cycle
+    GC cannot see through).
+    """
+    if tsr.grad_fn is not None:
+        weak_module = weakref.ref(module)
+
+        def hook(*args, **kwargs):
+            obj = weak_module()
+            if obj is not None:
+                obj._memoize_cache = {}
+
+        tsr.grad_fn.register_hook(hook)
+
+
 def _cached(method=None, name=None):
     """A decorator allowing for specifying the name of a cache, allowing it to be
     modified elsewhere. This variant honors the calling args to the decorated function.
diff --git a/gpytorch/variational/_variational_strategy.py b/gpytorch/variational/_variational_strategy.py
index f157b59f6..cef6cae29 100644
--- a/gpytorch/variational/_variational_strategy.py
+++ b/gpytorch/variational/_variational_strategy.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import functools
 from abc import ABC, abstractproperty
 from copy import deepcopy
 
@@ -18,7 +17,7 @@
 from ..models import ApproximateGP, ExactGP
 from ..models.exact_prediction_strategies import DefaultPredictionStrategy
 from ..module import Module
-from ..utils.memoize import add_to_cache, cached, clear_cache_hook
+from ..utils.memoize import add_to_cache, cached, clear_cache_hook, register_cache_clear_hook
 from . import _VariationalDistribution
 
 
@@ -42,10 +41,7 @@ def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
 
 
 def _add_cache_hook(tsr: Tensor, pred_strat: DefaultPredictionStrategy) -> Tensor:
-    if tsr.grad_fn is not None:
-        wrapper = functools.partial(clear_cache_hook, pred_strat)
-        functools.update_wrapper(wrapper, clear_cache_hook)
-        tsr.grad_fn.register_hook(wrapper)
+    register_cache_clear_hook(tsr, pred_strat)
     return tsr
 
 