diff --git a/backpack/hessianfree/hvp.py b/backpack/hessianfree/hvp.py index bc8d14fd..d4b21751 100644 --- a/backpack/hessianfree/hvp.py +++ b/backpack/hessianfree/hvp.py @@ -56,6 +56,12 @@ def hessian_vector_product( ) gv = sum((g_i * v_i).sum() for g_i, v_i in zip(grad_params, v)) - Hv = grad(gv, params, create_graph=True, retain_graph=True, materialize_grads=True) + Hv = grad( + gv, + params, + create_graph=not detach, + retain_graph=True, + materialize_grads=True, + ) return tuple(j.detach() for j in Hv) if detach else Hv diff --git a/backpack/hessianfree/lop.py b/backpack/hessianfree/lop.py index 7d2bd9e3..450e33fc 100644 --- a/backpack/hessianfree/lop.py +++ b/backpack/hessianfree/lop.py @@ -9,7 +9,7 @@ def L_op(ys, xs, ws, retain_graph=True, detach=True): ys, xs, grad_outputs=ws, - create_graph=True, + create_graph=not detach, retain_graph=retain_graph, allow_unused=True, materialize_grads=True, diff --git a/backpack/hessianfree/rop.py b/backpack/hessianfree/rop.py index d2888981..5e0471cd 100644 --- a/backpack/hessianfree/rop.py +++ b/backpack/hessianfree/rop.py @@ -24,7 +24,7 @@ def R_op(ys, xs, vs, retain_graph=True, detach=True): gs, ws, grad_outputs=vs, - create_graph=True, + create_graph=not detach, retain_graph=True, allow_unused=True, materialize_grads=True,