From 299d90c4f2908950599b4092a1e92574ba5c97a2 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Sun, 29 Oct 2023 21:51:33 -0400 Subject: [PATCH 1/9] [ADD] Implement autograd-based derivatives for automatic support --- backpack/core/derivatives/automatic.py | 172 +++++++++++++++++++++ backpack/utils/subsampling.py | 6 +- fully_documented.txt | 2 + test/test_automatic_support.py | 203 +++++++++++++++++++++++++ 4 files changed, 381 insertions(+), 2 deletions(-) create mode 100644 backpack/core/derivatives/automatic.py create mode 100644 test/test_automatic_support.py diff --git a/backpack/core/derivatives/automatic.py b/backpack/core/derivatives/automatic.py new file mode 100644 index 000000000..69eb2a82b --- /dev/null +++ b/backpack/core/derivatives/automatic.py @@ -0,0 +1,172 @@ +"""Automatic derivative implementation via ``torch.autograd``.""" + +from abc import abstractmethod +from typing import Dict, List, Optional, Protocol, Tuple, Union + +from torch import Tensor, allclose, enable_grad, stack +from torch.autograd import grad +from torch.nn import Module, Parameter + +from backpack.core.derivatives import shape_check +from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils.subsampling import subsample + + +class ForwardCallable(Protocol): + """Type-annotation for functions performing a forward pass.""" + + def __call__( + self, + x: Tensor, + *params_args: Union[Parameter, Tensor], + **params_kwargs: Union[Parameter, Tensor, None] + ) -> Tensor: + ... + + +class AutomaticDerivatives(BaseParameterDerivatives): + """Implements derivatives for an arbitrary layer using ``torch.autograd``. + + This class can be used to support new layers without implementing their + derivatives. However, this comes at the cost of performance, since the + autograd-based implementation is often not as efficient as a hand-crafted one. + + Attributes: + BATCH_AXIS: Index of the layer input's batch axis. Default: ``0``. + """ + + BATCH_AXIS: int = 0 + + @staticmethod + @abstractmethod + def as_functional(module: Module) -> ForwardCallable: + """Return a function that performs the layer's forward pass. + + Args: + module: Layer for which to return the forward function. + + Returns: + Function that performs the forward pass of the layer and returns a tensor + representing the result. First argument must be the input tensor, and + subsequent keyword arguments must be the layer's parameters. + + Note: + One way to automate this procedure would be via + ``torch.func.functional_call``. However, this does not work at the moment + because the passed layer has hooks. For now, this function must thus + be specified explicitly. + """ + raise NotImplementedError("Must be implemented by a child class.") + + @classmethod + def forward_pass( + cls, module: Module, subsampling: Optional[List[int]] = None + ) -> Tuple[Tensor, Dict[str, Tensor], Tensor]: + """Perform a forward pass through the layer. + + Args: + module: Layer for which to perform the forward pass. + subsampling: Indices of the batch axis to keep. If ``None``, all indices + are kept.Default: ``None``. + + Returns: + The sub-sampled tensor used as input to the forward pass, the parameters, + and the output. + """ + # Create an independent copy of the layer's input and parameters + input0 = module.input0.clone().detach() + input0 = subsample(input0, dim=cls.BATCH_AXIS, subsampling=subsampling) + params = { + name: param.clone().detach() for name, param in module.named_parameters() + } + + # turn on autograd for input and parameters + input0.requires_grad_(True) + for param in params.values(): + param.requires_grad_(True) + + forward_fn = cls.as_functional(module) + output = forward_fn(input0, **params) + + # make sure the layer's re-created output matches the output from the + # initial forward pass + if not allclose( + output, + subsample(module.output, dim=cls.BATCH_AXIS, subsampling=subsampling), + ): + raise RuntimeError( + "Forward function used inside `AutogradDerivatives` produced a " + + "different output than the module's forward pass. This indicates " + + "1) the layer is non-deterministic and cannot be supported by " + + "`AutogradDerivatives`, or 2) `.as_functional` is incorrect." + ) + + return input0, params, output + + # NOTE Explicitly turn on autodiff as this function is called during a + # backward pass. + @enable_grad() + def _jac_t_mat_prod( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + subsampling: Optional[List[int]] = None, + ) -> Tensor: + # regenerate computation graph for differentiation + input0, _, output = self.forward_pass(module, subsampling=subsampling) + + # ``mat`` consists of ``V`` vectors of shape ``[*module.output.shape]`` + vjps = [ + grad(output, input0, v, retain_graph=idx != mat.shape[0] - 1)[0] + for idx, v in enumerate(mat) + ] + + return stack(vjps) # shape [V, *module.input0.shape] + + # NOTE Explicitly turn on autodiff as this function is called during a + # backward pass. + @enable_grad() + @shape_check.param_mjp_accept_vectors + def param_mjp( + self, + param_str: str, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mat: Tensor, + sum_batch: bool = True, + subsampling: Optional[List[int]] = None, + ) -> Tensor: + batch_size = module.input0.shape[self.BATCH_AXIS] + subsampling = list(range(batch_size)) if subsampling is None else subsampling + + # contains the MJPs for each sample along the batch dimension + sample_vjps = [] + + # ``mat`` consists of ``V`` vectors of shape ``[*module.output.shape]`` + num_vecs = mat.shape[0] + + for sample_idx, sample in enumerate(subsampling): + # regenerate computation graph for differentiation + _, params, output = self.forward_pass(module, subsampling=[sample]) + + vjps = [ + grad( + output, + params[param_str], + v, + retain_graph=v_idx != num_vecs - 1, + )[0] + for v_idx, v in enumerate(mat[:, [sample_idx]]) + ] + + sample_vjps.append(stack(vjps)) # shape [V, *module.param_str.shape] + + sample_vjps = stack(sample_vjps, dim=1) + + if sum_batch: + sample_vjps = sample_vjps.sum(1) + + return sample_vjps # shape [V, B, *module.param_str.shape] or [V, B, *module.param_str.shape] diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py index 62d399f4c..0ce67589d 100644 --- a/backpack/utils/subsampling.py +++ b/backpack/utils/subsampling.py @@ -1,10 +1,12 @@ """Utility functions to enable mini-batch subsampling in extensions.""" -from typing import List +from typing import List, Optional from torch import Tensor -def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor: +def subsample( + tensor: Tensor, dim: int = 0, subsampling: Optional[List[int]] = None +) -> Tensor: """Select samples from a tensor along a dimension. Args: diff --git a/fully_documented.txt b/fully_documented.txt index f05271763..6e6e33049 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -21,6 +21,7 @@ backpack/core/derivatives/sum_module.py backpack/core/derivatives/dropout.py backpack/core/derivatives/slicing.py backpack/core/derivatives/bcewithlogitsloss.py +backpack/core/derivatives/automatic.py backpack/extensions/__init__.py backpack/extensions/backprop_extension.py @@ -132,3 +133,4 @@ test/utils/conv_transpose.py test/custom_module/ test/test_retain_graph.py test/test_batch_first.py +test/test_automatic_support.py diff --git a/test/test_automatic_support.py b/test/test_automatic_support.py new file mode 100644 index 000000000..97b99f661 --- /dev/null +++ b/test/test_automatic_support.py @@ -0,0 +1,203 @@ +"""Test automatic support of new layers.""" + +from test.test___init__ import DEVICES, DEVICES_ID +from typing import Callable, List, Optional, Union + +from pytest import mark +from torch import Tensor, allclose, device, manual_seed, rand +from torch.nn import Linear, MSELoss, ReLU, Sequential +from torch.nn.functional import linear, relu + +from backpack import backpack, extend, extensions +from backpack.core.derivatives.automatic import AutomaticDerivatives +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule + + +class LinearAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.Linear.""" + + @staticmethod + def as_functional( + module: Linear, + ) -> Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]: + """Return the linear layer's forward pass function. + + Args: + module: A linear layer. + + Returns: + The linear layer's forward pass function. + """ + return linear + + +class ReLUAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.ReLU.""" + + @staticmethod + def as_functional(module: ReLU) -> Callable[[Tensor], Tensor]: + """Return the ReLU layer's forward pass function. + + Args: + module: A ReLU layer. + + Returns: + The ReLU layer's forward pass function. + """ + return relu + + +@mark.parametrize("dev", DEVICES, ids=DEVICES_ID) +def test_automatic_support_diag_ggn_exact(dev: device): + """Test GGN diagonal computation via automatic derivatives. + + Args: + dev: The device on which to run the test. + """ + + class DiagGGNExactReLUAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.ReLU`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(ReLUAutomaticDerivatives(), sum_batch=True) + + class DiagGGNExactLinearAutomatic(DiagGGNBaseModule): + """GGN diag. computation for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__( + LinearAutomaticDerivatives(), params=["weight", "bias"], sum_batch=True + ) + + manual_seed(0) + X, y = rand(10, 5, device=dev), rand(10, 3, device=dev) + + model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3)).to(dev)) + loss_func = extend(MSELoss().to(dev)) + + # ground truth + with backpack(extensions.DiagGGNExact()): + loss = loss_func(model(X), y) + loss.backward() + diag_ggn = [p.diag_ggn_exact for p in model.parameters()] + + # same quantity with automatic support + ext = extensions.DiagGGNExact() + for layer_cls, extension in zip( + [ReLU, Linear], [DiagGGNExactReLUAutomatic(), DiagGGNExactLinearAutomatic()] + ): + ext.set_module_extension(layer_cls, extension, overwrite=True) + + with backpack(ext): + loss = loss_func(model(X), y) + loss.backward() + diag_ggn_automatic = [p.diag_ggn_exact for p in model.parameters()] + + assert len(diag_ggn) == len(diag_ggn_automatic) + for diag, diag_auto in zip(diag_ggn, diag_ggn_automatic): + assert allclose(diag, diag_auto) + + +@mark.parametrize("dev", DEVICES, ids=DEVICES_ID) +def test_automatic_support_batch_diag_ggn_exact(dev: device): + """Test batch GGN diagonal computation via automatic derivatives. + + Args: + dev: The device on which to run the test. + """ + + class BatchDiagGGNExactReLUAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.ReLU`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(ReLUAutomaticDerivatives(), sum_batch=False) + + class BatchDiagGGNExactLinearAutomatic(DiagGGNBaseModule): + """GGN diag. computation for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__( + LinearAutomaticDerivatives(), params=["weight", "bias"], sum_batch=False + ) + + manual_seed(0) + X, y = rand(10, 5, device=dev), rand(10, 3, device=dev) + + model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3)).to(dev)) + loss_func = extend(MSELoss().to(dev)) + + # ground truth + with backpack(extensions.BatchDiagGGNExact()): + loss = loss_func(model(X), y) + loss.backward() + batch_diag_ggn = [p.diag_ggn_exact_batch for p in model.parameters()] + + # same quantity with automatic support + ext = extensions.BatchDiagGGNExact() + for layer_cls, extension in zip( + [ReLU, Linear], + [BatchDiagGGNExactReLUAutomatic(), BatchDiagGGNExactLinearAutomatic()], + ): + ext.set_module_extension(layer_cls, extension, overwrite=True) + + with backpack(ext): + loss = loss_func(model(X), y) + loss.backward() + batch_diag_ggn_automatic = [p.diag_ggn_exact_batch for p in model.parameters()] + + assert len(batch_diag_ggn) == len(batch_diag_ggn_automatic) + for batch_diag, batch_diag_auto in zip(batch_diag_ggn, batch_diag_ggn_automatic): + assert allclose(batch_diag, batch_diag_auto) + + +SUBSAMPLINGS = [None, [7, 2, 4]] +SUBSAMPLING_IDS = [f"subsampling={subsampling}" for subsampling in SUBSAMPLINGS] + + +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("dev", DEVICES, ids=DEVICES_ID) +def test_automatic_support_batch_grad(dev: device, subsampling: Union[None, List[int]]): + """Test per-example gradient computation via automatic derivatives. + + Args: + dev: The device on which to run the test. + subsampling: Indices of active samples. ``None`` means full batch. + """ + + class BatchGradLinearAutomatic(BatchGradBase): + """Batch gradients for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(LinearAutomaticDerivatives(), params=["weight", "bias"]) + + manual_seed(0) + X, y = rand(10, 5, device=dev), rand(10, 3, device=dev) + + model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3)).to(dev)) + loss_func = extend(MSELoss().to(dev)) + + # ground truth + with backpack(extensions.BatchGrad(subsampling=subsampling)): + loss = loss_func(model(X), y) + loss.backward() + batch_grad = [p.grad_batch for p in model.parameters()] + + # same quantity with automatic support + ext = extensions.BatchGrad(subsampling=subsampling) + for layer_cls, extension in zip([Linear], [BatchGradLinearAutomatic()]): + ext.set_module_extension(layer_cls, extension, overwrite=True) + + with backpack(ext): + loss = loss_func(model(X), y) + loss.backward() + batch_grad_automatic = [p.grad_batch for p in model.parameters()] + + assert len(batch_grad) == len(batch_grad_automatic) + for bg, bg_auto in zip(batch_grad, batch_grad_automatic): + assert allclose(bg, bg_auto) From 4a4b31f4504b6a2d3f40c433e7a6700e54852dc4 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 12 Nov 2024 11:54:46 -0500 Subject: [PATCH 2/9] [REF] Improve module extension tests based on automatic derivatives --- test/automatic_derivatives.py | 59 ++++++++++ test/automatic_extensions.py | 70 ++++++++++++ test/test_automatic_support.py | 196 ++++++++++----------------------- test/utils/__init__.py | 17 ++- 4 files changed, 202 insertions(+), 140 deletions(-) create mode 100644 test/automatic_derivatives.py create mode 100644 test/automatic_extensions.py diff --git a/test/automatic_derivatives.py b/test/automatic_derivatives.py new file mode 100644 index 000000000..2d2721d2d --- /dev/null +++ b/test/automatic_derivatives.py @@ -0,0 +1,59 @@ +"""Define derivatives of layers computed via autodiff.""" + +from typing import Callable, Optional + +from torch import Tensor +from torch.nn import Linear, ReLU, Sigmoid +from torch.nn.functional import linear, relu, sigmoid + +from backpack.core.derivatives.automatic import AutomaticDerivatives + + +class LinearAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.Linear``.""" + + @staticmethod + def as_functional( + module: Linear, + ) -> Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]: + """Return the linear layer's forward pass function. + + Args: + module: A linear layer. + + Returns: + The linear layer's forward pass function. + """ + return linear + + +class ReLUAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.ReLU``.""" + + @staticmethod + def as_functional(module: ReLU) -> Callable[[Tensor], Tensor]: + """Return the ReLU layer's forward pass function. + + Args: + module: A ReLU layer. + + Returns: + The ReLU layer's forward pass function. + """ + return relu + + +class SigmoidAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.Sigmoid``.""" + + @staticmethod + def as_functional(module: Sigmoid) -> Callable[[Tensor], Tensor]: + """Return the Sigmoid layer's forward pass function. + + Args: + module: A Sigmoid layer. + + Returns: + The Sigmoid layer's forward pass function. + """ + return sigmoid diff --git a/test/automatic_extensions.py b/test/automatic_extensions.py new file mode 100644 index 000000000..3cc46234d --- /dev/null +++ b/test/automatic_extensions.py @@ -0,0 +1,70 @@ +"""Define layer extensions with derivatives based on autodiff.""" + +from test.automatic_derivatives import ( + LinearAutomaticDerivatives, + ReLUAutomaticDerivatives, + SigmoidAutomaticDerivatives, +) + +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule + + +class DiagGGNExactReLUAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.ReLU`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(ReLUAutomaticDerivatives(), sum_batch=True) + + +class DiagGGNExactSigmoidAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.Sigmoid`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(SigmoidAutomaticDerivatives(), sum_batch=True) + + +class DiagGGNExactLinearAutomatic(DiagGGNBaseModule): + """GGN diag. computation for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__( + LinearAutomaticDerivatives(), params=["weight", "bias"], sum_batch=True + ) + + +class BatchDiagGGNExactReLUAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.ReLU`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(ReLUAutomaticDerivatives(), sum_batch=False) + + +class BatchDiagGGNExactSigmoidAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.Sigmoid`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(SigmoidAutomaticDerivatives(), sum_batch=False) + + +class BatchDiagGGNExactLinearAutomatic(DiagGGNBaseModule): + """GGN diag. computation for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__( + LinearAutomaticDerivatives(), params=["weight", "bias"], sum_batch=False + ) + + +class BatchGradLinearAutomatic(BatchGradBase): + """Batch gradients for ``torch.nn.Linear`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__(LinearAutomaticDerivatives(), params=["weight", "bias"]) diff --git a/test/test_automatic_support.py b/test/test_automatic_support.py index 97b99f661..4a0b7926c 100644 --- a/test/test_automatic_support.py +++ b/test/test_automatic_support.py @@ -1,158 +1,79 @@ """Test automatic support of new layers.""" +from test.automatic_extensions import ( + BatchDiagGGNExactLinearAutomatic, + BatchDiagGGNExactReLUAutomatic, + BatchDiagGGNExactSigmoidAutomatic, + BatchGradLinearAutomatic, + DiagGGNExactLinearAutomatic, + DiagGGNExactReLUAutomatic, + DiagGGNExactSigmoidAutomatic, +) from test.test___init__ import DEVICES, DEVICES_ID -from typing import Callable, List, Optional, Union +from test.utils import popattr +from typing import List, Union -from pytest import mark -from torch import Tensor, allclose, device, manual_seed, rand -from torch.nn import Linear, MSELoss, ReLU, Sequential -from torch.nn.functional import linear, relu +from pytest import mark, raises +from torch import allclose, device, manual_seed, rand +from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid from backpack import backpack, extend, extensions -from backpack.core.derivatives.automatic import AutomaticDerivatives -from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase -from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule - - -class LinearAutomaticDerivatives(AutomaticDerivatives): - """Automatic derivatives for ``torch.nn.Linear.""" - - @staticmethod - def as_functional( - module: Linear, - ) -> Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]: - """Return the linear layer's forward pass function. - - Args: - module: A linear layer. - - Returns: - The linear layer's forward pass function. - """ - return linear - - -class ReLUAutomaticDerivatives(AutomaticDerivatives): - """Automatic derivatives for ``torch.nn.ReLU.""" - - @staticmethod - def as_functional(module: ReLU) -> Callable[[Tensor], Tensor]: - """Return the ReLU layer's forward pass function. - - Args: - module: A ReLU layer. - - Returns: - The ReLU layer's forward pass function. - """ - return relu +@mark.parametrize("batched", [False, True], ids=["DiagGGNExact", "BatchDiagGGNExact"]) @mark.parametrize("dev", DEVICES, ids=DEVICES_ID) -def test_automatic_support_diag_ggn_exact(dev: device): +def test_automatic_support_diag_ggn_exact(dev: device, batched: bool): """Test GGN diagonal computation via automatic derivatives. Args: dev: The device on which to run the test. + batched: Whether to compute the batched or summed GGN diagonal. """ - - class DiagGGNExactReLUAutomatic(DiagGGNBaseModule): - """GGN diagonal computation for ``torch.nn.ReLU`` via automatic derivatives.""" - - def __init__(self): - """Set up the derivatives.""" - super().__init__(ReLUAutomaticDerivatives(), sum_batch=True) - - class DiagGGNExactLinearAutomatic(DiagGGNBaseModule): - """GGN diag. computation for ``torch.nn.Linear`` via automatic derivatives.""" - - def __init__(self): - """Set up the derivatives.""" - super().__init__( - LinearAutomaticDerivatives(), params=["weight", "bias"], sum_batch=True - ) - manual_seed(0) X, y = rand(10, 5, device=dev), rand(10, 3, device=dev) - model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3)).to(dev)) + model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3), Sigmoid()).to(dev)) loss_func = extend(MSELoss().to(dev)) + savefield = "diag_ggn_exact_batch" if batched else "diag_ggn_exact" # ground truth - with backpack(extensions.DiagGGNExact()): - loss = loss_func(model(X), y) - loss.backward() - diag_ggn = [p.diag_ggn_exact for p in model.parameters()] - - # same quantity with automatic support - ext = extensions.DiagGGNExact() - for layer_cls, extension in zip( - [ReLU, Linear], [DiagGGNExactReLUAutomatic(), DiagGGNExactLinearAutomatic()] - ): - ext.set_module_extension(layer_cls, extension, overwrite=True) - + ext = extensions.BatchDiagGGNExact() if batched else extensions.DiagGGNExact() with backpack(ext): loss = loss_func(model(X), y) loss.backward() - diag_ggn_automatic = [p.diag_ggn_exact for p in model.parameters()] - - assert len(diag_ggn) == len(diag_ggn_automatic) - for diag, diag_auto in zip(diag_ggn, diag_ggn_automatic): - assert allclose(diag, diag_auto) - - -@mark.parametrize("dev", DEVICES, ids=DEVICES_ID) -def test_automatic_support_batch_diag_ggn_exact(dev: device): - """Test batch GGN diagonal computation via automatic derivatives. - - Args: - dev: The device on which to run the test. - """ - - class BatchDiagGGNExactReLUAutomatic(DiagGGNBaseModule): - """GGN diagonal computation for ``torch.nn.ReLU`` via automatic derivatives.""" - - def __init__(self): - """Set up the derivatives.""" - super().__init__(ReLUAutomaticDerivatives(), sum_batch=False) - - class BatchDiagGGNExactLinearAutomatic(DiagGGNBaseModule): - """GGN diag. computation for ``torch.nn.Linear`` via automatic derivatives.""" - - def __init__(self): - """Set up the derivatives.""" - super().__init__( - LinearAutomaticDerivatives(), params=["weight", "bias"], sum_batch=False - ) - - manual_seed(0) - X, y = rand(10, 5, device=dev), rand(10, 3, device=dev) - - model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3)).to(dev)) - loss_func = extend(MSELoss().to(dev)) - - # ground truth - with backpack(extensions.BatchDiagGGNExact()): - loss = loss_func(model(X), y) - loss.backward() - batch_diag_ggn = [p.diag_ggn_exact_batch for p in model.parameters()] + manual = [popattr(p, savefield) for p in model.parameters()] # same quantity with automatic support - ext = extensions.BatchDiagGGNExact() - for layer_cls, extension in zip( - [ReLU, Linear], - [BatchDiagGGNExactReLUAutomatic(), BatchDiagGGNExactLinearAutomatic()], - ): + ext = extensions.BatchDiagGGNExact() if batched else extensions.DiagGGNExact() + new_mappings = { + ReLU: ( + BatchDiagGGNExactReLUAutomatic() if batched else DiagGGNExactReLUAutomatic() + ), + Linear: ( + BatchDiagGGNExactLinearAutomatic() + if batched + else DiagGGNExactLinearAutomatic() + ), + Sigmoid: ( + BatchDiagGGNExactSigmoidAutomatic() + if batched + else DiagGGNExactSigmoidAutomatic() + ), + } + for layer_cls, extension in new_mappings.items(): + # make sure we need to turn on explicit overwriting + with raises(ValueError): + ext.set_module_extension(layer_cls, extension) ext.set_module_extension(layer_cls, extension, overwrite=True) with backpack(ext): loss = loss_func(model(X), y) loss.backward() - batch_diag_ggn_automatic = [p.diag_ggn_exact_batch for p in model.parameters()] + automatic = [popattr(p, savefield) for p in model.parameters()] - assert len(batch_diag_ggn) == len(batch_diag_ggn_automatic) - for batch_diag, batch_diag_auto in zip(batch_diag_ggn, batch_diag_ggn_automatic): - assert allclose(batch_diag, batch_diag_auto) + assert len(manual) == len(automatic) + for m, a in zip(manual, automatic): + assert allclose(m, a) SUBSAMPLINGS = [None, [7, 2, 4]] @@ -168,36 +89,33 @@ def test_automatic_support_batch_grad(dev: device, subsampling: Union[None, List dev: The device on which to run the test. subsampling: Indices of active samples. ``None`` means full batch. """ - - class BatchGradLinearAutomatic(BatchGradBase): - """Batch gradients for ``torch.nn.Linear`` via automatic derivatives.""" - - def __init__(self): - """Set up the derivatives.""" - super().__init__(LinearAutomaticDerivatives(), params=["weight", "bias"]) - manual_seed(0) X, y = rand(10, 5, device=dev), rand(10, 3, device=dev) - model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3)).to(dev)) + model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 3), Sigmoid()).to(dev)) loss_func = extend(MSELoss().to(dev)) + savefield = "grad_batch" # ground truth with backpack(extensions.BatchGrad(subsampling=subsampling)): loss = loss_func(model(X), y) loss.backward() - batch_grad = [p.grad_batch for p in model.parameters()] + manual = [popattr(p, savefield) for p in model.parameters()] # same quantity with automatic support ext = extensions.BatchGrad(subsampling=subsampling) - for layer_cls, extension in zip([Linear], [BatchGradLinearAutomatic()]): + new_mappings = {Linear: BatchGradLinearAutomatic()} + for layer_cls, extension in new_mappings.items(): + # make sure we need to turn on explicit overwriting + with raises(ValueError): + ext.set_module_extension(layer_cls, extension) ext.set_module_extension(layer_cls, extension, overwrite=True) with backpack(ext): loss = loss_func(model(X), y) loss.backward() - batch_grad_automatic = [p.grad_batch for p in model.parameters()] + automatic = [popattr(p, savefield) for p in model.parameters()] - assert len(batch_grad) == len(batch_grad_automatic) - for bg, bg_auto in zip(batch_grad, batch_grad_automatic): - assert allclose(bg, bg_auto) + assert len(manual) == len(automatic) + for m, a in zip(manual, automatic): + assert allclose(m, a) diff --git a/test/utils/__init__.py b/test/utils/__init__.py index 40711349c..c6b0c7f21 100644 --- a/test/utils/__init__.py +++ b/test/utils/__init__.py @@ -1,6 +1,6 @@ """Helper functions for tests.""" -from typing import List +from typing import Any, List def chunk_sizes(total_size: int, num_chunks: int) -> List[int]: @@ -25,3 +25,18 @@ def chunk_sizes(total_size: int, num_chunks: int) -> List[int]: sizes.append(rest) return sizes + + +def popattr(obj: Any, name: str) -> Any: + """Pop an attribute from an object. + + Args: + obj: The object from which to pop the attribute. + name: The name of the attribute to pop. + + Returns: + The attribute's value. + """ + value = getattr(obj, name) + delattr(obj, name) + return value From edbc93345b8ac864084d0602da656bf23fe1fec1 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 12 Nov 2024 16:17:56 -0500 Subject: [PATCH 3/9] [REF] Use `is_grads_batched` to multiply with multiple vectors in parallel --- backpack/core/derivatives/automatic.py | 81 +++++++++++++++----------- 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/backpack/core/derivatives/automatic.py b/backpack/core/derivatives/automatic.py index 69eb2a82b..b5f5c47b1 100644 --- a/backpack/core/derivatives/automatic.py +++ b/backpack/core/derivatives/automatic.py @@ -3,7 +3,7 @@ from abc import abstractmethod from typing import Dict, List, Optional, Protocol, Tuple, Union -from torch import Tensor, allclose, enable_grad, stack +from torch import Tensor, allclose, cat, enable_grad, stack from torch.autograd import grad from torch.nn import Module, Parameter @@ -19,9 +19,8 @@ def __call__( self, x: Tensor, *params_args: Union[Parameter, Tensor], - **params_kwargs: Union[Parameter, Tensor, None] - ) -> Tensor: - ... + **params_kwargs: Union[Parameter, Tensor, None], + ) -> Tensor: ... # noqa: D102 class AutomaticDerivatives(BaseParameterDerivatives): @@ -72,6 +71,10 @@ def forward_pass( Returns: The sub-sampled tensor used as input to the forward pass, the parameters, and the output. + + Raises: + RuntimeError: If the forward function produces a different output than the + layer's forward pass. """ # Create an independent copy of the layer's input and parameters input0 = module.input0.clone().detach() @@ -116,14 +119,7 @@ def _jac_t_mat_prod( ) -> Tensor: # regenerate computation graph for differentiation input0, _, output = self.forward_pass(module, subsampling=subsampling) - - # ``mat`` consists of ``V`` vectors of shape ``[*module.output.shape]`` - vjps = [ - grad(output, input0, v, retain_graph=idx != mat.shape[0] - 1)[0] - for idx, v in enumerate(mat) - ] - - return stack(vjps) # shape [V, *module.input0.shape] + return grad(output, input0, grad_outputs=mat, is_grads_batched=True)[0] # NOTE Explicitly turn on autodiff as this function is called during a # backward pass. @@ -139,34 +135,51 @@ def param_mjp( sum_batch: bool = True, subsampling: Optional[List[int]] = None, ) -> Tensor: + """Compute matrix-Jacobian products (MJPs) of the module w.r.t. a parameter. + + Handles both vector and matrix inputs. Preserves input format in output. + + Args: + param_str: Attribute name under which the parameter is stored in the module. + module: Module whose Jacobian will be applied. Must provide access to IO. + g_inp: Gradients w.r.t. module input. + g_out: Gradients w.r.t. module output. + mat: Matrix the Jacobian will be applied to. Has shape + ``[V, *module.output.shape]`` (matrix case) or same shape as + ``module.output`` (vector case). If used with subsampling, has dimension + len(subsampling) instead of batch size along the batch axis. + sum_batch: Sum out the MJP's batch axis. Default: ``True``. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). + + Returns: + Matrix-Jacobian products. Has shape ``[V, *param_shape]`` when batch + summation is enabled (same shape as parameter in the vector case). Without + batch summation, the result has shape ``[V, N, *param_shape]`` (vector case + has shape ``[N, *param_shape]``). If used with subsampling, the batch size N + is replaced by len(subsampling). + """ batch_size = module.input0.shape[self.BATCH_AXIS] subsampling = list(range(batch_size)) if subsampling is None else subsampling # contains the MJPs for each sample along the batch dimension sample_vjps = [] - # ``mat`` consists of ``V`` vectors of shape ``[*module.output.shape]`` - num_vecs = mat.shape[0] - for sample_idx, sample in enumerate(subsampling): # regenerate computation graph for differentiation _, params, output = self.forward_pass(module, subsampling=[sample]) - - vjps = [ - grad( - output, - params[param_str], - v, - retain_graph=v_idx != num_vecs - 1, - )[0] - for v_idx, v in enumerate(mat[:, [sample_idx]]) - ] - - sample_vjps.append(stack(vjps)) # shape [V, *module.param_str.shape] - - sample_vjps = stack(sample_vjps, dim=1) - - if sum_batch: - sample_vjps = sample_vjps.sum(1) - - return sample_vjps # shape [V, B, *module.param_str.shape] or [V, B, *module.param_str.shape] + # shape [V, *module.param_str.shape] + vjps = grad( + output, + params[param_str], + grad_outputs=mat[:, [sample_idx]], + is_grads_batched=True, + )[0] + sample_vjps.append(vjps.sum(self.BATCH_AXIS) if sum_batch else vjps) + + # shape [V, B, *module.param_str.shape] or [V, *module.param_str.shape] + return ( + cat(sample_vjps, dim=self.BATCH_AXIS) + if sum_batch + else stack(sample_vjps, dim=self.BATCH_AXIS + 1) + ) From 45f6a4c9c6d10d9442cea2f72f11897ab487c02d Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 12 Nov 2024 16:19:45 -0500 Subject: [PATCH 4/9] [DOC] Update README --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 25956cf77..478dbe03c 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,7 @@ # BackPACK BackPACK: Packing more into backprop -[![Travis](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack) [![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack) -[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-370/) +[![Python 3.9+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-390/) BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient. From 0952fb05bee73606e44b9fb9478b8eb20538ae5e Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 12 Nov 2024 16:38:05 -0500 Subject: [PATCH 5/9] [FIX] Linters --- backpack/core/derivatives/automatic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backpack/core/derivatives/automatic.py b/backpack/core/derivatives/automatic.py index b5f5c47b1..4d35b6838 100644 --- a/backpack/core/derivatives/automatic.py +++ b/backpack/core/derivatives/automatic.py @@ -20,7 +20,8 @@ def __call__( x: Tensor, *params_args: Union[Parameter, Tensor], **params_kwargs: Union[Parameter, Tensor, None], - ) -> Tensor: ... # noqa: D102 + ) -> Tensor: # noqa: D102 + pass class AutomaticDerivatives(BaseParameterDerivatives): From 2bc5cea55c81055b4bf3a2fba935880eda5e2791 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 12 Nov 2024 16:53:11 -0500 Subject: [PATCH 6/9] [DOC] Update badges --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 478dbe03c..1ee56bcf1 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ # BackPACK BackPACK: Packing more into backprop +[![RTD](https://readthedocs.org/projects/backpack/badge/?version=master)] [![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack) -[![Python 3.9+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-390/) +[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/) BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient. From d8fe0bdba216bc96ea442dbf147fe7f29d2d8d51 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 12 Nov 2024 16:54:35 -0500 Subject: [PATCH 7/9] [FIX] Badge --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ee56bcf1..ecee838cb 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # BackPACK BackPACK: Packing more into backprop -[![RTD](https://readthedocs.org/projects/backpack/badge/?version=master)] +[![RTD](https://readthedocs.org/projects/backpack/badge/?version=master)]() [![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack) [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/) From 73d6f273ea59f5b6c80fc7f8f095ca3fe0a9279b Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 12 Nov 2024 17:09:09 -0500 Subject: [PATCH 8/9] [CI] Add relevant files to `fully_documented.txt` --- fully_documented.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fully_documented.txt b/fully_documented.txt index eca8f4308..425b36971 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -132,3 +132,5 @@ test/custom_module/ test/test_retain_graph.py test/test_batch_first.py test/test_automatic_support.py +test/automatic_derivatives.py +test/automatic_extensions.py From ab1385766356f151991c1794d96ac501dea17912 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 13 Nov 2024 11:52:19 -0500 Subject: [PATCH 9/9] [DOC] Add tutorial for automatic support of new layers --- .../use_cases/example_automatic_support.py | 324 ++++++++++++++++++ 1 file changed, 324 insertions(+) create mode 100644 docs_src/examples/use_cases/example_automatic_support.py diff --git a/docs_src/examples/use_cases/example_automatic_support.py b/docs_src/examples/use_cases/example_automatic_support.py new file mode 100644 index 000000000..231136728 --- /dev/null +++ b/docs_src/examples/use_cases/example_automatic_support.py @@ -0,0 +1,324 @@ +"""Automatic support +==================== + +This tutorial explains how to support new layers in BackPACK without knowledge +about derivatives or autodiff internals. + +This is possible through a new, and experimental, :class:`AutomaticDerivatives` +class, which uses PyTorch's :mod:`torch.autograd` under the hood. +It makes it easy to quickly support new layers. However, this comes at the cost +of performance, because the autograd-based solution simply cannot avoid internal +re-computation and for loops. + +If you want to support a new layer efficiently, please check out the +:ref:`Custom module example`. + +The automatic support we describe in this tutorial works as follows: + +1. Define a derivative class for your layer. All you have to do is specify the forward + pass. The derivatives required by BackPACK will be derived from that. + +2. Define a module extension for the BackPACK extension you wish to compute, and + feed the above derivatives into it. + +3. Register the mapping between module and module extension. + +We will demonstrate these steps for the group normalization layer +(:class:`torch.nn.GroupNorm`), which currently has no efficient support in +BackPACK. + +Let's get the imports out of our way. + +""" # noqa: B950 + +from typing import Optional + +from torch import Tensor, cuda, device, manual_seed, rand, zeros +from torch.autograd import grad +from torch.nn import ( + Conv2d, + Flatten, + GroupNorm, + Linear, + MSELoss, + Parameter, + ReLU, + Sequential, + Sigmoid, +) +from torch.nn.functional import group_norm +from torch.nn.utils.convert_parameters import parameters_to_vector + +from backpack import backpack, extend, extensions +from backpack.core.derivatives.automatic import AutomaticDerivatives, ForwardCallable +from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase +from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule +from backpack.hessianfree.ggnvp import ggn_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list + +# make deterministic +manual_seed(0) + +dev = device("cuda" if cuda.is_available() else "cpu") + +# %% +# +# Define a derivative class +# ------------------------- +# +# The heavy lifting inside BackPACK is abstracted into a class that implements +# all kinds of derivatives. BackPACK's core provides a class called +# :class:`AutomaticDerivatives` that can be used to support new layers without +# implementing their derivatives. The derivatives are simply implemented using +# :mod:`torch.autograd`. This is less efficient than hand-crafted derivatives, but +# requires less human time and autodiff expertise. +# +# +# To create a new derivatives class, inherit from :class:`AutomaticDerivatives` +# and implement its abstract method :func:`as_functional` which returns a +# function that performs the layer's forward pass. +# +# Let's create such a class for the group normalization layer: + + +class GroupNormAutomaticDerivatives(AutomaticDerivatives): + """Automatic derivatives for ``torch.nn.GroupNorm``.""" + + @staticmethod + def as_functional(module: GroupNorm) -> ForwardCallable: + """Return the ``GroupNorm`` layer's forward pass function. + + Args: + module: The ``GroupNorm`` layer whose forward pass function is returned. + + Returns: + The ``GroupNorm`` layer's forward pass function which consumes the layer + input and parameters and produces the output. + """ + + def forward( + x: Tensor, weight: Optional[Parameter], bias: Optional[Parameter] + ) -> Tensor: + """Map layer input and parameters to layer output.""" + return group_norm(x, module.num_groups, weight, bias, module.eps) + + return forward + + +# %% +# +# Define a module extension +# ------------------------- +# +# Module extensions in BackPACK define what computations are carried out for a specific +# layer and extensions. +# +# Let's support per-datum gradients (i.e. BackPACK's :class:`BatchGrad +# ` extension) for the group normalization layer. To that, +# we have to define a module extension that uses the derivatives class we just created: + + +class BatchGradGroupNormAutomatic(BatchGradBase): + """BatchGrad extension for ``torch.nn.GroupNorm`` using automatic derivatives.""" + + def __init__(self): + """Initialize the extension.""" + super().__init__( + derivatives=GroupNormAutomaticDerivatives(), + # ``params`` only needs to be specified if the layer has learnable params + params=["weight", "bias"], + ) + + +# %% +# +# Register the module extension +# ----------------------------- +# +# To tell BackPACK to execute the above module extension when it encounters a group +# normalization layer during a backward pass with the +# :class:`BatchGrad ` the model and loss function, +# then call the :class:`with backpack(...) ` context manager with +# the new extension, and finally compare the computation's results. + +model = extend(model) +lossfunc = extend(lossfunc) + +# Remember that we registered the mapping between the group normalization layer and +# our module extension in ``ext`` earlier +with backpack(ext): + loss = lossfunc(model(X), y) + loss.backward() + +grad_batch = [p.grad_batch for p in params] + +# compare +if len(grad_batch) != len(grad_batch_true): + raise AssertionError("Parameter list structure does not match.") + +if not all(g.allclose(g_true) for g, g_true in zip(grad_batch, grad_batch_true)): + raise AssertionError("Per-datum gradients do not match.") +else: + print("Per-datum gradients match.") + +# %% +# +# It works! +# +# We can now compute per-datum gradients for the parameters of a group norm layer. + +# %% +# +# Repeat for other extensions +# --------------------------- +# +# So far, we demonstrated everything for one extension. Other extensions follow the same +# process. +# +# We illustrate this here for the exact GGN diagonal +# (:class:`DiagGGNExact `). +# +# Let's re-create the synthetic data and model to avoid side effects from before. + +BATCH_SIZE = 10 +X = rand(BATCH_SIZE, 3, 28, 28, device=dev) +y = rand(BATCH_SIZE, 4, device=dev) + +model = Sequential( + Conv2d(3, 4, 5, stride=3), + ReLU(), + GroupNorm(2, 4), + Conv2d(4, 2, 3, stride=2), + Sigmoid(), + Flatten(), + Linear(18, 4), +).to(dev) +lossfunc = MSELoss() + +# %% +# +# First, let's compute our ground truth using PyTorch's autodiff. + +params = [p for p in model.parameters() if p.requires_grad] +ggn_dim = sum(p.numel() for p in params) +diag_ggn_flat = zeros(ggn_dim, device=X.device, dtype=X.dtype) + +outputs = model(X) +loss = lossfunc(outputs, y) + +# compute GGN-vector products with all one-hot vectors +for d in range(ggn_dim): + # create unit vector d + e_d = zeros(ggn_dim, device=X.device, dtype=X.dtype) + e_d[d] = 1.0 + # convert to list format + e_d = vector_to_parameter_list(e_d, params) + + # multiply GGN onto the unit vector -> get back column d of the GGN + ggn_e_d = ggn_vector_product(loss, outputs, model, e_d) + # flatten + ggn_e_d = parameters_to_vector(ggn_e_d) + + # extract the d-th entry (which is on the GGN's diagonal) + diag_ggn_flat[d] = ggn_e_d[d] + +print(f"Tr(GGN, autograd): {diag_ggn_flat.sum():.3f}") + +# %% +# +# Next, let's define a module extension that tells +# :class:`DiagGGNExact ` what to do if it encounters +# a group normalization layer. + + +class DiagGGNExactGroupNormAutomatic(DiagGGNBaseModule): + """GGN diagonal computation for ``torch.nn.GroupNorm`` via automatic derivatives.""" + + def __init__(self): + """Set up the derivatives.""" + super().__init__( + GroupNormAutomaticDerivatives(), params=["weight", "bias"], sum_batch=True + ) + + +# %% +# +# Register the mapping ... + +ext = extensions.DiagGGNExact() +ext.set_module_extension(GroupNorm, DiagGGNExactGroupNormAutomatic()) + +# %% +# +# ... and run a backward pass with BackPACK + +model = extend(model) +lossfunc = extend(lossfunc) + +with backpack(ext): + loss = lossfunc(model(X), y) + loss.backward() + +# %% +# +# Finally, let's collect the result and compare it with autograd. + +diag_ggn_flat_backpack = parameters_to_vector([p.diag_ggn_exact for p in params]) +print(f"Tr(GGN, BackPACK): {diag_ggn_flat_backpack.sum():.3f}") + +if not diag_ggn_flat.allclose(diag_ggn_flat_backpack): + raise AssertionError("Exact GGN diagonals do not match.") +else: + print("Exact GGN diagonals match.") + +# %% +# +# Works as well, great!