diff --git a/README.md b/README.md index 25956cf7..ecee838c 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # BackPACK BackPACK: Packing more into backprop -[![Travis](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack) +[![RTD](https://readthedocs.org/projects/backpack/badge/?version=master)]() [![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack) -[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-370/) +[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/) BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient. diff --git a/backpack/core/derivatives/automatic.py b/backpack/core/derivatives/automatic.py new file mode 100644 index 00000000..4d35b683 --- /dev/null +++ b/backpack/core/derivatives/automatic.py @@ -0,0 +1,186 @@ +"""Automatic derivative implementation via ``torch.autograd``.""" + +from abc import abstractmethod +from typing import Dict, List, Optional, Protocol, Tuple, Union + +from torch import Tensor, allclose, cat, enable_grad, stack +from torch.autograd import grad +from torch.nn import Module, Parameter + +from backpack.core.derivatives import shape_check +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils.subsampling import subsample + + +class ForwardCallable(Protocol): + """Type-annotation for functions performing a forward pass.""" + + def __call__( + self, + x: Tensor, + *params_args: Union[Parameter, Tensor], + **params_kwargs: Union[Parameter, Tensor, None], + ) -> Tensor: # noqa: D102 + pass + + +class AutomaticDerivatives(BaseParameterDerivatives): + """Implements derivatives for an arbitrary layer using ``torch.autograd``. + + This class can be used to support new layers without implementing their + derivatives. However, this comes at the cost of performance, since the + autograd-based implementation is often not as efficient as a hand-crafted one. + + Attributes: + BATCH_AXIS: Index of the layer input's batch axis. Default: ``0``. + """ + + BATCH_AXIS: int = 0 + + @staticmethod + @abstractmethod + def as_functional(module: Module) -> ForwardCallable: + """Return a function that performs the layer's forward pass. + + Args: + module: Layer for which to return the forward function. + + Returns: + Function that performs the forward pass of the layer and returns a tensor + representing the result. First argument must be the input tensor, and + subsequent keyword arguments must be the layer's parameters. + + Note: + One way to automate this procedure would be via + ``torch.func.functional_call``. However, this does not work at the moment + because the passed layer has hooks. For now, this function must thus + be specified explicitly. + """ + raise NotImplementedError("Must be implemented by a child class.") + + @classmethod + def forward_pass( + cls, module: Module, subsampling: Optional[List[int]] = None + ) -> Tuple[Tensor, Dict[str, Tensor], Tensor]: + """Perform a forward pass through the layer. + + Args: + module: Layer for which to perform the forward pass. + subsampling: Indices of the batch axis to keep. If ``None``, all indices + are kept.Default: ``None``. + + Returns: + The sub-sampled tensor used as input to the forward pass, the parameters, + and the output. + + Raises: + RuntimeError: If the forward function produces a different output than the + layer's forward pass. + """ + # Create an independent copy of the layer's input and parameters + input0 = module.input0.clone().detach() + input0 = subsample(input0, dim=cls.BATCH_AXIS, subsampling=subsampling) + params = { + name: param.clone().detach() for name, param in module.named_parameters() + } + + # turn on autograd for input and parameters + input0.requires_grad_(True) + for param in params.values(): + param.requires_grad_(True) + + forward_fn = cls.as_functional(module) + output = forward_fn(input0, **params) + + # make sure the layer's re-created output matches the output from the + # initial forward pass + if not allclose( + output, + subsample(module.output, dim=cls.BATCH_AXIS, subsampling=subsampling), + ): + raise RuntimeError( + "Forward function used inside `AutogradDerivatives` produced a " + + "different output than the module's forward pass. This indicates " + + "1) the layer is non-deterministic and cannot be supported by " + + "`AutogradDerivatives`, or 2) `.as_functional` is incorrect." + ) + + return input0, params, output + + # NOTE Explicitly turn on autodiff as this function is called during a + # backward pass. + @enable_grad() + def _jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: Optional[List[int]] = None, + ) -> Tensor: + # regenerate computation graph for differentiation + input0, _, output = self.forward_pass(module, subsampling=subsampling) + return grad(output, input0, grad_outputs=mat, is_grads_batched=True)[0] + + # NOTE Explicitly turn on autodiff as this function is called during a + # backward pass. + @enable_grad() + @shape_check.param_mjp_accept_vectors + def param_mjp( + self, + param_str: str, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + subsampling: Optional[List[int]] = None, + ) -> Tensor: + """Compute matrix-Jacobian products (MJPs) of the module w.r.t. a parameter. + + Handles both vector and matrix inputs. Preserves input format in output. + + Args: + param_str: Attribute name under which the parameter is stored in the module. + module: Module whose Jacobian will be applied. Must provide access to IO. + g_inp: Gradients w.r.t. module input. + g_out: Gradients w.r.t. module output. + mat: Matrix the Jacobian will be applied to. Has shape + ``[V, *module.output.shape]`` (matrix case) or same shape as + ``module.output`` (vector case). If used with subsampling, has dimension + len(subsampling) instead of batch size along the batch axis. + sum_batch: Sum out the MJP's batch axis. Default: ``True``. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). + + Returns: + Matrix-Jacobian products. Has shape ``[V, *param_shape]`` when batch + summation is enabled (same shape as parameter in the vector case). Without + batch summation, the result has shape ``[V, N, *param_shape]`` (vector case + has shape ``[N, *param_shape]``). If used with subsampling, the batch size N + is replaced by len(subsampling). + """ + batch_size = module.input0.shape[self.BATCH_AXIS] + subsampling = list(range(batch_size)) if subsampling is None else subsampling + + # contains the MJPs for each sample along the batch dimension + sample_vjps = [] + + for sample_idx, sample in enumerate(subsampling): + # regenerate computation graph for differentiation + _, params, output = self.forward_pass(module, subsampling=[sample]) + # shape [V, *module.param_str.shape] + vjps = grad( + output, + params[param_str], + grad_outputs=mat[:, [sample_idx]], + is_grads_batched=True, + )[0] + sample_vjps.append(vjps.sum(self.BATCH_AXIS) if sum_batch else vjps) + + # shape [V, B, *module.param_str.shape] or [V, *module.param_str.shape] + return ( + cat(sample_vjps, dim=self.BATCH_AXIS) + if sum_batch + else stack(sample_vjps, dim=self.BATCH_AXIS + 1) + ) diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py index 75cb6e73..04d09868 100644 --- a/backpack/utils/subsampling.py +++ b/backpack/utils/subsampling.py @@ -1,11 +1,13 @@ """Utility functions to enable mini-batch subsampling in extensions.""" -from typing import List +from typing import List, Optional from torch import Tensor -def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor: +def subsample( + tensor: Tensor, dim: int = 0, subsampling: Optional[List[int]] = None +) -> Tensor: """Select samples from a tensor along a dimension. Args: diff --git a/docs_src/examples/use_cases/example_automatic_support.py b/docs_src/examples/use_cases/example_automatic_support.py new file mode 100644 index 00000000..23113672 --- /dev/null +++ b/docs_src/examples/use_cases/example_automatic_support.py @@ -0,0 +1,324 @@ +"""Automatic support +==================== + +This tutorial explains how to support new layers in BackPACK without knowledge +about derivatives or autodiff internals. + +This is possible through a new, and experimental, :class:`AutomaticDerivatives` +class, which uses PyTorch's :mod:`torch.autograd` under the hood. +It makes it easy to quickly support new layers. However, this comes at the cost +of performance, because the autograd-based solution simply cannot avoid internal +re-computation and for loops. + +If you want to support a new layer efficiently, please check out the +:ref:`Custom module example`. + +The automatic support we describe in this tutorial works as follows: + +1. Define a derivative class for your layer. All you have to do is specify the forward + pass. The derivatives required by BackPACK will be derived from that. + +2. Define a module extension for the BackPACK extension you wish to compute, and + feed the above derivatives into it. + +3. Register the mapping between module and module extension. + +We will demonstrate these steps for the group normalization layer +(:class:`torch.nn.GroupNorm`), which currently has no efficient support in +BackPACK. + +Let's get the imports out of our way. + +""" # noqa: B950 + +from typing import Optional + +from torch import Tensor, cuda, device, manual_seed, rand, zeros +from torch.autograd import grad +from torch.nn import ( + Conv2d, + Flatten, + GroupNorm, + Linear, + MSELoss, + Parameter, + ReLU, + Sequential, + Sigmoid, +) +from torch.nn.functional import group_norm +from torch.nn.utils.convert_parameters import parameters_to_vector + +from backpack import backpack, extend, extensions +from backpack.core.derivatives.automatic import AutomaticDerivatives, ForwardCallable +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule +from backpack.hessianfree.ggnvp import ggn_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list + +# make deterministic +manual_seed(0) + +dev = device("cuda" if cuda.is_available() else "cpu") + +# %% +# +# Define a derivative class +# ------------------------- +# +# The heavy lifting inside BackPACK is abstracted into a class that implements +# all kinds of derivatives. BackPACK's core provides a class called +# :class:`AutomaticDerivatives` that can be used to support new layers without +# implementing their derivatives. The derivatives are simply implemented using +# :mod:`torch.autograd`. This is less efficient than hand-crafted derivatives, but +# requires less human time and autodiff expertise. +# +# +# To create a new derivatives class, inherit from :class:`AutomaticDerivatives` +# and implement its abstract method :func:`as_functional` which returns a +# function that performs the layer's forward pass. +# +# Let's create such a class for the group normalization layer: + + +class GroupNormAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.GroupNorm``.""" + + @staticmethod + def as_functional(module: GroupNorm) -> ForwardCallable: + """Return the ``GroupNorm`` layer's forward pass function. + + Args: + module: The ``GroupNorm`` layer whose forward pass function is returned. + + Returns: + The ``GroupNorm`` layer's forward pass function which consumes the layer + input and parameters and produces the output. + """ + + def forward( + x: Tensor, weight: Optional[Parameter], bias: Optional[Parameter] + ) -> Tensor: + """Map layer input and parameters to layer output.""" + return group_norm(x, module.num_groups, weight, bias, module.eps) + + return forward + + +# %% +# +# Define a module extension +# ------------------------- +# +# Module extensions in BackPACK define what computations are carried out for a specific +# layer and extensions. +# +# Let's support per-datum gradients (i.e. BackPACK's :class:`BatchGrad +# ` extension) for the group normalization layer. To that, +# we have to define a module extension that uses the derivatives class we just created: + + +class BatchGradGroupNormAutomatic(BatchGradBase): + """BatchGrad extension for ``torch.nn.GroupNorm`` using automatic derivatives.""" + + def __init__(self): + """Initialize the extension.""" + super().__init__( + derivatives=GroupNormAutomaticDerivatives(), + # ``params`` only needs to be specified if the layer has learnable params + params=["weight", "bias"], + ) + + +# %% +# +# Register the module extension +# ----------------------------- +# +# To tell BackPACK to execute the above module extension when it encounters a group +# normalization layer during a backward pass with the +# :class:`BatchGrad ` the model and loss function, +# then call the :class:`with backpack(...) ` context manager with +# the new extension, and finally compare the computation's results. + +model = extend(model) +lossfunc = extend(lossfunc) + +# Remember that we registered the mapping between the group normalization layer and +# our module extension in ``ext`` earlier +with backpack(ext): + loss = lossfunc(model(X), y) + loss.backward() + +grad_batch = [p.grad_batch for p in params] + +# compare +if len(grad_batch) != len(grad_batch_true): + raise AssertionError("Parameter list structure does not match.") + +if not all(g.allclose(g_true) for g, g_true in zip(grad_batch, grad_batch_true)): + raise AssertionError("Per-datum gradients do not match.") +else: + print("Per-datum gradients match.") + +# %% +# +# It works! +# +# We can now compute per-datum gradients for the parameters of a group norm layer. + +# %% +# +# Repeat for other extensions +# --------------------------- +# +# So far, we demonstrated everything for one extension. Other extensions follow the same +# process. +# +# We illustrate this here for the exact GGN diagonal +# (:class:`DiagGGNExact `). +# +# Let's re-create the synthetic data and model to avoid side effects from before. + +BATCH_SIZE = 10 +X = rand(BATCH_SIZE, 3, 28, 28, device=dev) +y = rand(BATCH_SIZE, 4, device=dev) + +model = Sequential( + Conv2d(3, 4, 5, stride=3), + ReLU(), + GroupNorm(2, 4), + Conv2d(4, 2, 3, stride=2), + Sigmoid(), + Flatten(), + Linear(18, 4), +).to(dev) +lossfunc = MSELoss() + +# %% +# +# First, let's compute our ground truth using PyTorch's autodiff. + +params = [p for p in model.parameters() if p.requires_grad] +ggn_dim = sum(p.numel() for p in params) +diag_ggn_flat = zeros(ggn_dim, device=X.device, dtype=X.dtype) + +outputs = model(X) +loss = lossfunc(outputs, y) + +# compute GGN-vector products with all one-hot vectors +for d in range(ggn_dim): + # create unit vector d + e_d = zeros(ggn_dim, device=X.device, dtype=X.dtype) + e_d[d] = 1.0 + # convert to list format + e_d = vector_to_parameter_list(e_d, params) + + # multiply GGN onto the unit vector -> get back column d of the GGN + ggn_e_d = ggn_vector_product(loss, outputs, model, e_d) + # flatten + ggn_e_d = parameters_to_vector(ggn_e_d) + + # extract the d-th entry (which is on the GGN's diagonal) + diag_ggn_flat[d] = ggn_e_d[d] + +print(f"Tr(GGN, autograd): {diag_ggn_flat.sum():.3f}") + +# %% +# +# Next, let's define a module extension that tells +# :class:`DiagGGNExact ` what to do if it encounters +# a group normalization layer. + + +class DiagGGNExactGroupNormAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.GroupNorm`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__( + GroupNormAutomaticDerivatives(), params=["weight", "bias"], sum_batch=True + ) + + +# %% +# +# Register the mapping ... + +ext = extensions.DiagGGNExact() +ext.set_module_extension(GroupNorm, DiagGGNExactGroupNormAutomatic()) + +# %% +# +# ... and run a backward pass with BackPACK + +model = extend(model) +lossfunc = extend(lossfunc) + +with backpack(ext): + loss = lossfunc(model(X), y) + loss.backward() + +# %% +# +# Finally, let's collect the result and compare it with autograd. + +diag_ggn_flat_backpack = parameters_to_vector([p.diag_ggn_exact for p in params]) +print(f"Tr(GGN, BackPACK): {diag_ggn_flat_backpack.sum():.3f}") + +if not diag_ggn_flat.allclose(diag_ggn_flat_backpack): + raise AssertionError("Exact GGN diagonals do not match.") +else: + print("Exact GGN diagonals match.") + +# %% +# +# Works as well, great! diff --git a/fully_documented.txt b/fully_documented.txt index ebf4bd3b..425b3697 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -19,6 +19,7 @@ backpack/core/derivatives/sum_module.py backpack/core/derivatives/dropout.py backpack/core/derivatives/slicing.py backpack/core/derivatives/bcewithlogitsloss.py +backpack/core/derivatives/automatic.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py @@ -130,3 +131,6 @@ test/utils/conv_transpose.py test/custom_module/ test/test_retain_graph.py test/test_batch_first.py +test/test_automatic_support.py +test/automatic_derivatives.py +test/automatic_extensions.py diff --git a/test/automatic_derivatives.py b/test/automatic_derivatives.py new file mode 100644 index 00000000..2d2721d2 --- /dev/null +++ b/test/automatic_derivatives.py @@ -0,0 +1,59 @@ +"""Define derivatives of layers computed via autodiff.""" + +from typing import Callable, Optional + +from torch import Tensor +from torch.nn import Linear, ReLU, Sigmoid +from torch.nn.functional import linear, relu, sigmoid + +from backpack.core.derivatives.automatic import AutomaticDerivatives + + +class LinearAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.Linear``.""" + + @staticmethod + def as_functional( + module: Linear, + ) -> Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]: + """Return the linear layer's forward pass function. + + Args: + module: A linear layer. + + Returns: + The linear layer's forward pass function. + """ + return linear + + +class ReLUAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.ReLU``.""" + + @staticmethod + def as_functional(module: ReLU) -> Callable[[Tensor], Tensor]: + """Return the ReLU layer's forward pass function. + + Args: + module: A ReLU layer. + + Returns: + The ReLU layer's forward pass function. + """ + return relu + + +class SigmoidAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.Sigmoid``.""" + + @staticmethod + def as_functional(module: Sigmoid) -> Callable[[Tensor], Tensor]: + """Return the Sigmoid layer's forward pass function. + + Args: + module: A Sigmoid layer. + + Returns: + The Sigmoid layer's forward pass function. + """ + return sigmoid diff --git a/test/automatic_extensions.py b/test/automatic_extensions.py new file mode 100644 index 00000000..3cc46234 --- /dev/null +++ b/test/automatic_extensions.py @@ -0,0 +1,70 @@ +"""Define layer extensions with derivatives based on autodiff.""" + +from test.automatic_derivatives import ( + LinearAutomaticDerivatives, + ReLUAutomaticDerivatives, + SigmoidAutomaticDerivatives, +) + +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule + + +class DiagGGNExactReLUAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.ReLU`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(ReLUAutomaticDerivatives(), sum_batch=True) + + +class DiagGGNExactSigmoidAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.Sigmoid`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(SigmoidAutomaticDerivatives(), sum_batch=True) + + +class DiagGGNExactLinearAutomatic(DiagGGNBaseModule): + """GGN diag. computation for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__( + LinearAutomaticDerivatives(), params=["weight", "bias"], sum_batch=True + ) + + +class BatchDiagGGNExactReLUAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.ReLU`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(ReLUAutomaticDerivatives(), sum_batch=False) + + +class BatchDiagGGNExactSigmoidAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.Sigmoid`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(SigmoidAutomaticDerivatives(), sum_batch=False) + + +class BatchDiagGGNExactLinearAutomatic(DiagGGNBaseModule): + """GGN diag. computation for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__( + LinearAutomaticDerivatives(), params=["weight", "bias"], sum_batch=False + ) + + +class BatchGradLinearAutomatic(BatchGradBase): + """Batch gradients for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(LinearAutomaticDerivatives(), params=["weight", "bias"]) diff --git a/test/test_automatic_support.py b/test/test_automatic_support.py new file mode 100644 index 00000000..4a0b7926 --- /dev/null +++ b/test/test_automatic_support.py @@ -0,0 +1,121 @@ +"""Test automatic support of new layers.""" + +from test.automatic_extensions import ( + BatchDiagGGNExactLinearAutomatic, + BatchDiagGGNExactReLUAutomatic, + BatchDiagGGNExactSigmoidAutomatic, + BatchGradLinearAutomatic, + DiagGGNExactLinearAutomatic, + DiagGGNExactReLUAutomatic, + DiagGGNExactSigmoidAutomatic, +) +from test.test___init__ import DEVICES, DEVICES_ID +from test.utils import popattr +from typing import List, Union + +from pytest import mark, raises +from torch import allclose, device, manual_seed, rand +from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid + +from backpack import backpack, extend, extensions + + +@mark.parametrize("batched", [False, True], ids=["DiagGGNExact", "BatchDiagGGNExact"]) +@mark.parametrize("dev", DEVICES, ids=DEVICES_ID) +def test_automatic_support_diag_ggn_exact(dev: device, batched: bool): + """Test GGN diagonal computation via automatic derivatives. + + Args: + dev: The device on which to run the test. + batched: Whether to compute the batched or summed GGN diagonal. + """ + manual_seed(0) + X, y = rand(10, 5, device=dev), rand(10, 3, device=dev) + + model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3), Sigmoid()).to(dev)) + loss_func = extend(MSELoss().to(dev)) + savefield = "diag_ggn_exact_batch" if batched else "diag_ggn_exact" + + # ground truth + ext = extensions.BatchDiagGGNExact() if batched else extensions.DiagGGNExact() + with backpack(ext): + loss = loss_func(model(X), y) + loss.backward() + manual = [popattr(p, savefield) for p in model.parameters()] + + # same quantity with automatic support + ext = extensions.BatchDiagGGNExact() if batched else extensions.DiagGGNExact() + new_mappings = { + ReLU: ( + BatchDiagGGNExactReLUAutomatic() if batched else DiagGGNExactReLUAutomatic() + ), + Linear: ( + BatchDiagGGNExactLinearAutomatic() + if batched + else DiagGGNExactLinearAutomatic() + ), + Sigmoid: ( + BatchDiagGGNExactSigmoidAutomatic() + if batched + else DiagGGNExactSigmoidAutomatic() + ), + } + for layer_cls, extension in new_mappings.items(): + # make sure we need to turn on explicit overwriting + with raises(ValueError): + ext.set_module_extension(layer_cls, extension) + ext.set_module_extension(layer_cls, extension, overwrite=True) + + with backpack(ext): + loss = loss_func(model(X), y) + loss.backward() + automatic = [popattr(p, savefield) for p in model.parameters()] + + assert len(manual) == len(automatic) + for m, a in zip(manual, automatic): + assert allclose(m, a) + + +SUBSAMPLINGS = [None, [7, 2, 4]] +SUBSAMPLING_IDS = [f"subsampling={subsampling}" for subsampling in SUBSAMPLINGS] + + +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("dev", DEVICES, ids=DEVICES_ID) +def test_automatic_support_batch_grad(dev: device, subsampling: Union[None, List[int]]): + """Test per-example gradient computation via automatic derivatives. + + Args: + dev: The device on which to run the test. + subsampling: Indices of active samples. ``None`` means full batch. + """ + manual_seed(0) + X, y = rand(10, 5, device=dev), rand(10, 3, device=dev) + + model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3), Sigmoid()).to(dev)) + loss_func = extend(MSELoss().to(dev)) + savefield = "grad_batch" + + # ground truth + with backpack(extensions.BatchGrad(subsampling=subsampling)): + loss = loss_func(model(X), y) + loss.backward() + manual = [popattr(p, savefield) for p in model.parameters()] + + # same quantity with automatic support + ext = extensions.BatchGrad(subsampling=subsampling) + new_mappings = {Linear: BatchGradLinearAutomatic()} + for layer_cls, extension in new_mappings.items(): + # make sure we need to turn on explicit overwriting + with raises(ValueError): + ext.set_module_extension(layer_cls, extension) + ext.set_module_extension(layer_cls, extension, overwrite=True) + + with backpack(ext): + loss = loss_func(model(X), y) + loss.backward() + automatic = [popattr(p, savefield) for p in model.parameters()] + + assert len(manual) == len(automatic) + for m, a in zip(manual, automatic): + assert allclose(m, a) diff --git a/test/utils/__init__.py b/test/utils/__init__.py index 40711349..c6b0c7f2 100644 --- a/test/utils/__init__.py +++ b/test/utils/__init__.py @@ -1,6 +1,6 @@ """Helper functions for tests.""" -from typing import List +from typing import Any, List def chunk_sizes(total_size: int, num_chunks: int) -> List[int]: @@ -25,3 +25,18 @@ def chunk_sizes(total_size: int, num_chunks: int) -> List[int]: sizes.append(rest) return sizes + + +def popattr(obj: Any, name: str) -> Any: + """Pop an attribute from an object. + + Args: + obj: The object from which to pop the attribute. + name: The name of the attribute to pop. + + Returns: + The attribute's value. + """ + value = getattr(obj, name) + delattr(obj, name) + return value