diff --git a/README.md b/README.md
index 25956cf7..ecee838c 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
#
BackPACK: Packing more into backprop
-[](https://travis-ci.org/f-dangel/backpack)
+[]()
[](https://coveralls.io/github/f-dangel/backpack)
-[](https://www.python.org/downloads/release/python-370/)
+[](https://www.python.org/downloads/release/python-390/)
BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient.
diff --git a/backpack/core/derivatives/automatic.py b/backpack/core/derivatives/automatic.py
new file mode 100644
index 00000000..4d35b683
--- /dev/null
+++ b/backpack/core/derivatives/automatic.py
@@ -0,0 +1,186 @@
+"""Automatic derivative implementation via ``torch.autograd``."""
+
+from abc import abstractmethod
+from typing import Dict, List, Optional, Protocol, Tuple, Union
+
+from torch import Tensor, allclose, cat, enable_grad, stack
+from torch.autograd import grad
+from torch.nn import Module, Parameter
+
+from backpack.core.derivatives import shape_check
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.utils.subsampling import subsample
+
+
+class ForwardCallable(Protocol):
+    """Structural type of a function that carries out a layer's forward pass.
+
+    Such a function receives the layer input as first argument, followed by the
+    layer's parameters, and returns the layer output.
+    """
+
+    def __call__(
+        self,
+        x: Tensor,
+        *params_args: Union[Parameter, Tensor],
+        **params_kwargs: Union[Parameter, Tensor, None],
+    ) -> Tensor:  # noqa: D102
+        ...
+
+
+class AutomaticDerivatives(BaseParameterDerivatives):
+    """Implements derivatives for an arbitrary layer using ``torch.autograd``.
+
+    This class can be used to support new layers without implementing their
+    derivatives. However, this comes at the cost of performance, since the
+    autograd-based implementation is often not as efficient as a hand-crafted one.
+
+    Child classes only need to implement :meth:`as_functional`. All derivatives
+    are obtained by re-running the forward pass on detached copies of the layer's
+    input and parameters, and differentiating through it.
+
+    Attributes:
+        BATCH_AXIS: Index of the layer input's batch axis. Default: ``0``.
+    """
+
+    BATCH_AXIS: int = 0
+
+    @staticmethod
+    @abstractmethod
+    def as_functional(module: Module) -> ForwardCallable:
+        """Return a function that performs the layer's forward pass.
+
+        Args:
+            module: Layer for which to return the forward function.
+
+        Returns:
+            Function that performs the forward pass of the layer and returns a tensor
+            representing the result. First argument must be the input tensor, and
+            subsequent keyword arguments must be the layer's parameters.
+
+        Note:
+            One way to automate this procedure would be via
+            ``torch.func.functional_call``. However, this does not work at the moment
+            because the passed layer has hooks. For now, this function must thus
+            be specified explicitly.
+        """
+        raise NotImplementedError("Must be implemented by a child class.")
+
+    @classmethod
+    def forward_pass(
+        cls, module: Module, subsampling: Optional[List[int]] = None
+    ) -> Tuple[Tensor, Dict[str, Tensor], Tensor]:
+        """Perform a forward pass through the layer.
+
+        The pass operates on detached copies of the layer's input and parameters,
+        so the resulting graph is independent of the one built by the user.
+
+        Args:
+            module: Layer for which to perform the forward pass.
+            subsampling: Indices of the batch axis to keep. If ``None``, all indices
+                are kept. Default: ``None``.
+
+        Returns:
+            The sub-sampled tensor used as input to the forward pass, the parameters
+            (keyed by their name in ``module.named_parameters()``), and the output.
+
+        Raises:
+            RuntimeError: If the forward function produces a different output than the
+                layer's forward pass.
+        """
+        # Create an independent copy of the layer's input and parameters
+        input0 = module.input0.clone().detach()
+        input0 = subsample(input0, dim=cls.BATCH_AXIS, subsampling=subsampling)
+        params = {
+            name: param.clone().detach() for name, param in module.named_parameters()
+        }
+
+        # turn on autograd for input and parameters
+        input0.requires_grad_(True)
+        for param in params.values():
+            param.requires_grad_(True)
+
+        # parameters are handed over by name, see ``as_functional``
+        forward_fn = cls.as_functional(module)
+        output = forward_fn(input0, **params)
+
+        # make sure the layer's re-created output matches the output from the
+        # initial forward pass
+        if not allclose(
+            output,
+            subsample(module.output, dim=cls.BATCH_AXIS, subsampling=subsampling),
+        ):
+            raise RuntimeError(
+                "Forward function used inside `AutomaticDerivatives` produced a "
+                + "different output than the module's forward pass. This indicates "
+                + "1) the layer is non-deterministic and cannot be supported by "
+                + "`AutomaticDerivatives`, or 2) `.as_functional` is incorrect."
+            )
+
+        return input0, params, output
+
+    # NOTE Explicitly turn on autodiff as this function is called during a
+    # backward pass.
+    @enable_grad()
+    def _jac_t_mat_prod(
+        self,
+        module: Module,
+        g_inp: Tuple[Tensor],
+        g_out: Tuple[Tensor],
+        mat: Tensor,
+        subsampling: Optional[List[int]] = None,
+    ) -> Tensor:
+        """Apply the transposed input-output Jacobian to a matrix via autograd.
+
+        Args:
+            module: Layer whose Jacobian is applied. Must provide access to IO.
+            g_inp: Gradients w.r.t. module input. Not used.
+            g_out: Gradients w.r.t. module output. Not used.
+            mat: Stack of vectors to which the transposed Jacobian is applied. The
+                leading axis enumerates the vectors; the remaining axes must match
+                the (sub-sampled) layer output.
+            subsampling: Indices of the batch axis to keep. If ``None``, all indices
+                are kept. Default: ``None``.
+
+        Returns:
+            Gradients of the re-created output w.r.t. the (sub-sampled) input, one
+            per vector in ``mat``, stacked along the leading axis.
+        """
+        # regenerate computation graph for differentiation
+        input0, _, output = self.forward_pass(module, subsampling=subsampling)
+        return grad(output, input0, grad_outputs=mat, is_grads_batched=True)[0]
+
+    # NOTE Explicitly turn on autodiff as this function is called during a
+    # backward pass.
+    @enable_grad()
+    @shape_check.param_mjp_accept_vectors
+    def param_mjp(
+        self,
+        param_str: str,
+        module: Module,
+        g_inp: Tuple[Tensor],
+        g_out: Tuple[Tensor],
+        mat: Tensor,
+        sum_batch: bool = True,
+        subsampling: Optional[List[int]] = None,
+    ) -> Tensor:
+        """Compute matrix-Jacobian products (MJPs) of the module w.r.t. a parameter.
+
+        Handles both vector and matrix inputs. Preserves input format in output.
+
+        Args:
+            param_str: Attribute name under which the parameter is stored in the module.
+            module: Module whose Jacobian will be applied. Must provide access to IO.
+            g_inp: Gradients w.r.t. module input.
+            g_out: Gradients w.r.t. module output.
+            mat: Matrix the Jacobian will be applied to. Has shape
+                ``[V, *module.output.shape]`` (matrix case) or same shape as
+                ``module.output`` (vector case). If used with subsampling, has dimension
+                len(subsampling) instead of batch size along the batch axis.
+            sum_batch: Sum out the MJP's batch axis. Default: ``True``.
+            subsampling: Indices of samples along the output's batch dimension that
+                should be considered. Defaults to ``None`` (use all samples).
+
+        Returns:
+            Matrix-Jacobian products. Has shape ``[V, *param_shape]`` when batch
+            summation is enabled (same shape as parameter in the vector case). Without
+            batch summation, the result has shape ``[V, N, *param_shape]`` (vector case
+            has shape ``[N, *param_shape]``). If used with subsampling, the batch size N
+            is replaced by len(subsampling).
+        """
+        batch_size = module.input0.shape[self.BATCH_AXIS]
+        subsampling = list(range(batch_size)) if subsampling is None else subsampling
+
+        # contains the MJPs for each sample along the batch dimension
+        sample_vjps = []
+
+        # Differentiating w.r.t. a parameter sums over the batch. To obtain
+        # per-sample results, each sample is processed in its own forward pass.
+        for sample_idx, sample in enumerate(subsampling):
+            # regenerate computation graph for differentiation
+            _, params, output = self.forward_pass(module, subsampling=[sample])
+            # shape [V, *module.param_str.shape]; the list index keeps a singleton
+            # batch axis in ``mat`` so it matches the single-sample output
+            vjps = grad(
+                output,
+                params[param_str],
+                grad_outputs=mat[:, [sample_idx]],
+                is_grads_batched=True,
+            )[0]
+            sample_vjps.append(vjps)
+
+        # shape [V, N, *module.param_str.shape]
+        batch_vjps = stack(sample_vjps, dim=self.BATCH_AXIS + 1)
+
+        # Summation must run over the sample axis, not the leading ``V`` axis,
+        # so that the result has shape [V, *module.param_str.shape].
+        return batch_vjps.sum(self.BATCH_AXIS + 1) if sum_batch else batch_vjps
diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py
index 75cb6e73..04d09868 100644
--- a/backpack/utils/subsampling.py
+++ b/backpack/utils/subsampling.py
@@ -1,11 +1,13 @@
"""Utility functions to enable mini-batch subsampling in extensions."""
-from typing import List
+from typing import List, Optional
from torch import Tensor
-def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor:
+def subsample(
+ tensor: Tensor, dim: int = 0, subsampling: Optional[List[int]] = None
+) -> Tensor:
"""Select samples from a tensor along a dimension.
Args:
diff --git a/docs_src/examples/use_cases/example_automatic_support.py b/docs_src/examples/use_cases/example_automatic_support.py
new file mode 100644
index 00000000..23113672
--- /dev/null
+++ b/docs_src/examples/use_cases/example_automatic_support.py
@@ -0,0 +1,324 @@
+"""Automatic support
+====================
+
+This tutorial explains how to support new layers in BackPACK without knowledge
+about derivatives or autodiff internals.
+
+This is possible through a new, and experimental, :class:`AutomaticDerivatives`
+class, which uses PyTorch's :mod:`torch.autograd` under the hood.
+It makes it easy to quickly support new layers. However, this comes at the cost
+of performance, because the autograd-based solution simply cannot avoid internal
+re-computation and for loops.
+
+If you want to support a new layer efficiently, please check out the
+:ref:`Custom module example`.
+
+The automatic support we describe in this tutorial works as follows:
+
+1. Define a derivative class for your layer. All you have to do is specify the forward
+ pass. The derivatives required by BackPACK will be derived from that.
+
+2. Define a module extension for the BackPACK extension you wish to compute, and
+ feed the above derivatives into it.
+
+3. Register the mapping between module and module extension.
+
+We will demonstrate these steps for the group normalization layer
+(:class:`torch.nn.GroupNorm`), which currently has no efficient support in
+BackPACK.
+
+Let's get the imports out of our way.
+
+""" # noqa: B950
+
+from typing import Optional
+
+from torch import Tensor, cuda, device, manual_seed, rand, zeros
+from torch.autograd import grad
+from torch.nn import (
+ Conv2d,
+ Flatten,
+ GroupNorm,
+ Linear,
+ MSELoss,
+ Parameter,
+ ReLU,
+ Sequential,
+ Sigmoid,
+)
+from torch.nn.functional import group_norm
+from torch.nn.utils.convert_parameters import parameters_to_vector
+
+from backpack import backpack, extend, extensions
+from backpack.core.derivatives.automatic import AutomaticDerivatives, ForwardCallable
+from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase
+from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
+from backpack.hessianfree.ggnvp import ggn_vector_product
+from backpack.utils.convert_parameters import vector_to_parameter_list
+
+# make deterministic
+manual_seed(0)
+
+dev = device("cuda" if cuda.is_available() else "cpu")
+
+# %%
+#
+# Define a derivative class
+# -------------------------
+#
+# The heavy lifting inside BackPACK is abstracted into a class that implements
+# all kinds of derivatives. BackPACK's core provides a class called
+# :class:`AutomaticDerivatives` that can be used to support new layers without
+# implementing their derivatives. The derivatives are simply implemented using
+# :mod:`torch.autograd`. This is less efficient than hand-crafted derivatives, but
+# requires less human time and autodiff expertise.
+#
+#
+# To create a new derivatives class, inherit from :class:`AutomaticDerivatives`
+# and implement its abstract method :func:`as_functional` which returns a
+# function that performs the layer's forward pass.
+#
+# Let's create such a class for the group normalization layer:
+
+
+class GroupNormAutomaticDerivatives(AutomaticDerivatives):
+    """Automatic derivatives for ``torch.nn.GroupNorm``."""
+
+    @staticmethod
+    def as_functional(module: GroupNorm) -> ForwardCallable:
+        """Return the ``GroupNorm`` layer's forward pass function.
+
+        Args:
+            module: The ``GroupNorm`` layer whose forward pass function is returned.
+
+        Returns:
+            The ``GroupNorm`` layer's forward pass function which consumes the layer
+            input and parameters and produces the output.
+        """
+
+        # ``weight`` and ``bias`` default to ``None`` so the function can also be
+        # called without parameters, e.g. for a layer without affine transformation.
+        def forward(
+            x: Tensor,
+            weight: Optional[Parameter] = None,
+            bias: Optional[Parameter] = None,
+        ) -> Tensor:
+            """Map layer input and parameters to layer output."""
+            return group_norm(x, module.num_groups, weight, bias, module.eps)
+
+        return forward
+
+
+# %%
+#
+# Define a module extension
+# -------------------------
+#
+# Module extensions in BackPACK define what computations are carried out for a specific
+# layer and extensions.
+#
+# Let's support per-datum gradients (i.e. BackPACK's
+# :class:`BatchGrad <backpack.extensions.BatchGrad>` extension) for the group
+# normalization layer. To do that, we have to define a module extension that uses
+# the derivatives class we just created:
+
+
+class BatchGradGroupNormAutomatic(BatchGradBase):
+    """Per-datum gradients for ``torch.nn.GroupNorm`` via automatic derivatives."""
+
+    def __init__(self):
+        """Connect the extension with the autograd-based derivatives."""
+        group_norm_derivatives = GroupNormAutomaticDerivatives()
+        # ``params`` only needs to be specified if the layer has learnable params
+        learnable = ["weight", "bias"]
+        super().__init__(derivatives=group_norm_derivatives, params=learnable)
+
+
+# %%
+#
+# Register the module extension
+# -----------------------------
+#
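+# To tell BackPACK to execute the above module extension when it encounters a group
+# normalization layer during a backward pass with the
+# :class:`BatchGrad <backpack.extensions.BatchGrad>` extension, the mapping between
+# the layer and the module extension must be registered on the extension (``ext``).
+#
+# To run the computation, first extend the model and loss function,
+# then call the :class:`with backpack(...) <backpack.backpack>` context manager with
+# the new extension, and finally compare the computation's results.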
+
+model = extend(model)
+lossfunc = extend(lossfunc)
+
+# Remember that we registered the mapping between the group normalization layer and
+# our module extension in ``ext`` earlier
+with backpack(ext):
+ loss = lossfunc(model(X), y)
+ loss.backward()
+
+grad_batch = [p.grad_batch for p in params]
+
+# compare
+if len(grad_batch) != len(grad_batch_true):
+ raise AssertionError("Parameter list structure does not match.")
+
+if not all(g.allclose(g_true) for g, g_true in zip(grad_batch, grad_batch_true)):
+ raise AssertionError("Per-datum gradients do not match.")
+else:
+ print("Per-datum gradients match.")
+
+# %%
+#
+# It works!
+#
+# We can now compute per-datum gradients for the parameters of a group norm layer.
+
+# %%
+#
+# Repeat for other extensions
+# ---------------------------
+#
+# So far, we demonstrated everything for one extension. Other extensions follow the same
+# process.
+#
+# We illustrate this here for the exact GGN diagonal
+# (:class:`DiagGGNExact <backpack.extensions.DiagGGNExact>`).
+#
+# Let's re-create the synthetic data and model to avoid side effects from before.
+
+BATCH_SIZE = 10
+X = rand(BATCH_SIZE, 3, 28, 28, device=dev)
+y = rand(BATCH_SIZE, 4, device=dev)
+
+model = Sequential(
+ Conv2d(3, 4, 5, stride=3),
+ ReLU(),
+ GroupNorm(2, 4),
+ Conv2d(4, 2, 3, stride=2),
+ Sigmoid(),
+ Flatten(),
+ Linear(18, 4),
+).to(dev)
+lossfunc = MSELoss()
+
+# %%
+#
+# First, let's compute our ground truth using PyTorch's autodiff.
+
+params = [p for p in model.parameters() if p.requires_grad]
+ggn_dim = sum(p.numel() for p in params)
+diag_ggn_flat = zeros(ggn_dim, device=X.device, dtype=X.dtype)
+
+outputs = model(X)
+loss = lossfunc(outputs, y)
+
+# compute GGN-vector products with all one-hot vectors
+for d in range(ggn_dim):
+ # create unit vector d
+ e_d = zeros(ggn_dim, device=X.device, dtype=X.dtype)
+ e_d[d] = 1.0
+ # convert to list format
+ e_d = vector_to_parameter_list(e_d, params)
+
+ # multiply GGN onto the unit vector -> get back column d of the GGN
+ ggn_e_d = ggn_vector_product(loss, outputs, model, e_d)
+ # flatten
+ ggn_e_d = parameters_to_vector(ggn_e_d)
+
+ # extract the d-th entry (which is on the GGN's diagonal)
+ diag_ggn_flat[d] = ggn_e_d[d]
+
+print(f"Tr(GGN, autograd): {diag_ggn_flat.sum():.3f}")
+
+# %%
+#
+# Next, let's define a module extension that tells
+# :class:`DiagGGNExact <backpack.extensions.DiagGGNExact>` what to do if it
+# encounters a group normalization layer.
+
+
+class DiagGGNExactGroupNormAutomatic(DiagGGNBaseModule):
+    """Exact GGN diagonal for ``torch.nn.GroupNorm`` via automatic derivatives."""
+
+    def __init__(self):
+        """Connect the extension with the autograd-based derivatives."""
+        group_norm_derivatives = GroupNormAutomaticDerivatives()
+        super().__init__(
+            group_norm_derivatives, params=["weight", "bias"], sum_batch=True
+        )
+
+
+# %%
+#
+# Register the mapping ...
+
+ext = extensions.DiagGGNExact()
+ext.set_module_extension(GroupNorm, DiagGGNExactGroupNormAutomatic())
+
+# %%
+#
+# ... and run a backward pass with BackPACK
+
+model = extend(model)
+lossfunc = extend(lossfunc)
+
+with backpack(ext):
+ loss = lossfunc(model(X), y)
+ loss.backward()
+
+# %%
+#
+# Finally, let's collect the result and compare it with autograd.
+
+diag_ggn_flat_backpack = parameters_to_vector([p.diag_ggn_exact for p in params])
+print(f"Tr(GGN, BackPACK): {diag_ggn_flat_backpack.sum():.3f}")
+
+if not diag_ggn_flat.allclose(diag_ggn_flat_backpack):
+ raise AssertionError("Exact GGN diagonals do not match.")
+else:
+ print("Exact GGN diagonals match.")
+
+# %%
+#
+# Works as well, great!
diff --git a/fully_documented.txt b/fully_documented.txt
index ebf4bd3b..425b3697 100644
--- a/fully_documented.txt
+++ b/fully_documented.txt
@@ -19,6 +19,7 @@ backpack/core/derivatives/sum_module.py
backpack/core/derivatives/dropout.py
backpack/core/derivatives/slicing.py
backpack/core/derivatives/bcewithlogitsloss.py
+backpack/core/derivatives/automatic.py
backpack/extensions/__init__.py
backpack/extensions/backprop_extension.py
@@ -130,3 +131,6 @@ test/utils/conv_transpose.py
test/custom_module/
test/test_retain_graph.py
test/test_batch_first.py
+test/test_automatic_support.py
+test/automatic_derivatives.py
+test/automatic_extensions.py
diff --git a/test/automatic_derivatives.py b/test/automatic_derivatives.py
new file mode 100644
index 00000000..2d2721d2
--- /dev/null
+++ b/test/automatic_derivatives.py
@@ -0,0 +1,59 @@
+"""Define derivatives of layers computed via autodiff."""
+
+from typing import Callable, Optional
+
+from torch import Tensor
+from torch.nn import Linear, ReLU, Sigmoid
+from torch.nn.functional import linear, relu, sigmoid
+
+from backpack.core.derivatives.automatic import AutomaticDerivatives
+
+
+class LinearAutomaticDerivatives(AutomaticDerivatives):
+    """Autograd-based derivatives of ``torch.nn.Linear``."""
+
+    @staticmethod
+    def as_functional(
+        module: Linear,
+    ) -> Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]:
+        """Provide the functional form of a linear layer's forward pass.
+
+        Args:
+            module: A linear layer.
+
+        Returns:
+            ``torch.nn.functional.linear``, which maps the input, weight, and
+            (optional) bias to the layer output.
+        """
+        return linear
+
+
+class ReLUAutomaticDerivatives(AutomaticDerivatives):
+    """Autograd-based derivatives of ``torch.nn.ReLU``."""
+
+    @staticmethod
+    def as_functional(module: ReLU) -> Callable[[Tensor], Tensor]:
+        """Provide the functional form of a ReLU layer's forward pass.
+
+        Args:
+            module: A ReLU layer.
+
+        Returns:
+            ``torch.nn.functional.relu``, which maps the input to the layer output.
+        """
+        return relu
+
+
+class SigmoidAutomaticDerivatives(AutomaticDerivatives):
+    """Autograd-based derivatives of ``torch.nn.Sigmoid``."""
+
+    @staticmethod
+    def as_functional(module: Sigmoid) -> Callable[[Tensor], Tensor]:
+        """Provide the functional form of a Sigmoid layer's forward pass.
+
+        Args:
+            module: A Sigmoid layer.
+
+        Returns:
+            ``torch.nn.functional.sigmoid``, which maps the input to the layer
+            output.
+        """
+        return sigmoid
diff --git a/test/automatic_extensions.py b/test/automatic_extensions.py
new file mode 100644
index 00000000..3cc46234
--- /dev/null
+++ b/test/automatic_extensions.py
@@ -0,0 +1,70 @@
+"""Define layer extensions with derivatives based on autodiff."""
+
+from test.automatic_derivatives import (
+ LinearAutomaticDerivatives,
+ ReLUAutomaticDerivatives,
+ SigmoidAutomaticDerivatives,
+)
+
+from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase
+from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
+
+
+class DiagGGNExactReLUAutomatic(DiagGGNBaseModule):
+    """Batch-summed GGN diagonal for ``torch.nn.ReLU`` via automatic derivatives."""
+
+    def __init__(self):
+        """Pass the autograd-based ReLU derivatives to the base class."""
+        relu_derivatives = ReLUAutomaticDerivatives()
+        super().__init__(relu_derivatives, sum_batch=True)
+
+
+class DiagGGNExactSigmoidAutomatic(DiagGGNBaseModule):
+    """Batch-summed GGN diagonal for ``torch.nn.Sigmoid`` via automatic derivatives."""
+
+    def __init__(self):
+        """Pass the autograd-based Sigmoid derivatives to the base class."""
+        sigmoid_derivatives = SigmoidAutomaticDerivatives()
+        super().__init__(sigmoid_derivatives, sum_batch=True)
+
+
+class DiagGGNExactLinearAutomatic(DiagGGNBaseModule):
+    """Batch-summed GGN diagonal for ``torch.nn.Linear`` via automatic derivatives."""
+
+    def __init__(self):
+        """Pass the autograd-based linear derivatives to the base class."""
+        linear_derivatives = LinearAutomaticDerivatives()
+        super().__init__(linear_derivatives, params=["weight", "bias"], sum_batch=True)
+
+
+class BatchDiagGGNExactReLUAutomatic(DiagGGNBaseModule):
+    """Un-summed GGN diagonal for ``torch.nn.ReLU`` via automatic derivatives."""
+
+    def __init__(self):
+        """Pass the autograd-based ReLU derivatives on, keeping the batch axis."""
+        relu_derivatives = ReLUAutomaticDerivatives()
+        super().__init__(relu_derivatives, sum_batch=False)
+
+
+class BatchDiagGGNExactSigmoidAutomatic(DiagGGNBaseModule):
+    """Un-summed GGN diagonal for ``torch.nn.Sigmoid`` via automatic derivatives."""
+
+    def __init__(self):
+        """Pass the autograd-based Sigmoid derivatives on, keeping the batch axis."""
+        sigmoid_derivatives = SigmoidAutomaticDerivatives()
+        super().__init__(sigmoid_derivatives, sum_batch=False)
+
+
+class BatchDiagGGNExactLinearAutomatic(DiagGGNBaseModule):
+    """Un-summed GGN diagonal for ``torch.nn.Linear`` via automatic derivatives."""
+
+    def __init__(self):
+        """Pass the autograd-based linear derivatives on, keeping the batch axis."""
+        linear_derivatives = LinearAutomaticDerivatives()
+        super().__init__(
+            linear_derivatives, params=["weight", "bias"], sum_batch=False
+        )
+
+
+class BatchGradLinearAutomatic(BatchGradBase):
+    """Per-sample gradients for ``torch.nn.Linear`` via automatic derivatives."""
+
+    def __init__(self):
+        """Pass the autograd-based linear derivatives to the base class."""
+        linear_derivatives = LinearAutomaticDerivatives()
+        super().__init__(linear_derivatives, params=["weight", "bias"])
diff --git a/test/test_automatic_support.py b/test/test_automatic_support.py
new file mode 100644
index 00000000..4a0b7926
--- /dev/null
+++ b/test/test_automatic_support.py
@@ -0,0 +1,121 @@
+"""Test automatic support of new layers."""
+
+from test.automatic_extensions import (
+ BatchDiagGGNExactLinearAutomatic,
+ BatchDiagGGNExactReLUAutomatic,
+ BatchDiagGGNExactSigmoidAutomatic,
+ BatchGradLinearAutomatic,
+ DiagGGNExactLinearAutomatic,
+ DiagGGNExactReLUAutomatic,
+ DiagGGNExactSigmoidAutomatic,
+)
+from test.test___init__ import DEVICES, DEVICES_ID
+from test.utils import popattr
+from typing import List, Union
+
+from pytest import mark, raises
+from torch import allclose, device, manual_seed, rand
+from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid
+
+from backpack import backpack, extend, extensions
+
+
+@mark.parametrize("batched", [False, True], ids=["DiagGGNExact", "BatchDiagGGNExact"])
+@mark.parametrize("dev", DEVICES, ids=DEVICES_ID)
+def test_automatic_support_diag_ggn_exact(dev: device, batched: bool):
+    """Compare the GGN diagonal from automatic and hand-crafted derivatives.
+
+    Args:
+        dev: The device on which to run the test.
+        batched: Whether to compute the batched or summed GGN diagonal.
+    """
+    manual_seed(0)
+    X = rand(10, 5, device=dev)
+    y = rand(10, 3, device=dev)
+
+    model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3), Sigmoid()).to(dev))
+    loss_func = extend(MSELoss().to(dev))
+
+    # pick the extension, its result field, and the automatic module extensions
+    if batched:
+        ext_cls = extensions.BatchDiagGGNExact
+        savefield = "diag_ggn_exact_batch"
+        automatic_cls = {
+            ReLU: BatchDiagGGNExactReLUAutomatic,
+            Linear: BatchDiagGGNExactLinearAutomatic,
+            Sigmoid: BatchDiagGGNExactSigmoidAutomatic,
+        }
+    else:
+        ext_cls = extensions.DiagGGNExact
+        savefield = "diag_ggn_exact"
+        automatic_cls = {
+            ReLU: DiagGGNExactReLUAutomatic,
+            Linear: DiagGGNExactLinearAutomatic,
+            Sigmoid: DiagGGNExactSigmoidAutomatic,
+        }
+
+    def run(ext):
+        """Run a backward pass with ``ext`` and pop the results off the parameters."""
+        with backpack(ext):
+            loss = loss_func(model(X), y)
+            loss.backward()
+        return [popattr(p, savefield) for p in model.parameters()]
+
+    # ground truth
+    manual = run(ext_cls())
+
+    # same quantity with automatic support
+    ext = ext_cls()
+    for layer_cls, module_ext_cls in automatic_cls.items():
+        module_ext = module_ext_cls()
+        # make sure we need to turn on explicit overwriting
+        with raises(ValueError):
+            ext.set_module_extension(layer_cls, module_ext)
+        ext.set_module_extension(layer_cls, module_ext, overwrite=True)
+    automatic = run(ext)
+
+    assert len(manual) == len(automatic)
+    for expected, result in zip(manual, automatic):
+        assert allclose(expected, result)
+
+
+SUBSAMPLINGS = [None, [7, 2, 4]]
+SUBSAMPLING_IDS = [f"subsampling={subsampling}" for subsampling in SUBSAMPLINGS]
+
+
+@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS)
+@mark.parametrize("dev", DEVICES, ids=DEVICES_ID)
+def test_automatic_support_batch_grad(dev: device, subsampling: Union[None, List[int]]):
+    """Compare per-example gradients from automatic and hand-crafted derivatives.
+
+    Args:
+        dev: The device on which to run the test.
+        subsampling: Indices of active samples. ``None`` means full batch.
+    """
+    manual_seed(0)
+    X = rand(10, 5, device=dev)
+    y = rand(10, 3, device=dev)
+
+    model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3), Sigmoid()).to(dev))
+    loss_func = extend(MSELoss().to(dev))
+
+    def run(ext):
+        """Run a backward pass with ``ext`` and pop the results off the parameters."""
+        with backpack(ext):
+            loss = loss_func(model(X), y)
+            loss.backward()
+        return [popattr(p, "grad_batch") for p in model.parameters()]
+
+    # ground truth
+    manual = run(extensions.BatchGrad(subsampling=subsampling))
+
+    # same quantity with automatic support
+    ext = extensions.BatchGrad(subsampling=subsampling)
+    module_ext = BatchGradLinearAutomatic()
+    # make sure we need to turn on explicit overwriting
+    with raises(ValueError):
+        ext.set_module_extension(Linear, module_ext)
+    ext.set_module_extension(Linear, module_ext, overwrite=True)
+    automatic = run(ext)
+
+    assert len(manual) == len(automatic)
+    for expected, result in zip(manual, automatic):
+        assert allclose(expected, result)
diff --git a/test/utils/__init__.py b/test/utils/__init__.py
index 40711349..c6b0c7f2 100644
--- a/test/utils/__init__.py
+++ b/test/utils/__init__.py
@@ -1,6 +1,6 @@
"""Helper functions for tests."""
-from typing import List
+from typing import Any, List
def chunk_sizes(total_size: int, num_chunks: int) -> List[int]:
@@ -25,3 +25,18 @@ def chunk_sizes(total_size: int, num_chunks: int) -> List[int]:
sizes.append(rest)
return sizes
+
+
+def popattr(obj: Any, name: str) -> Any:
+    """Remove an attribute from an object and hand back its value.
+
+    Args:
+        obj: The object that holds the attribute.
+        name: The name of the attribute to remove.
+
+    Returns:
+        The value the attribute had before it was removed.
+    """
+    popped = getattr(obj, name)
+    delattr(obj, name)
+    return popped