4 changes: 4 additions & 0 deletions src/torchjd/_linalg/__init__.py
@@ -1,3 +1,4 @@
from ._dual_cone import DualConeProjector, QPSolverBased, projector_or_default
from ._generalized_gramian import flatten, movedim, reshape
from ._gramian import compute_gramian, normalize, regularize
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
@@ -15,4 +16,7 @@
"flatten",
"reshape",
"movedim",
"DualConeProjector",
"QPSolverBased",
"projector_or_default",
]
72 changes: 72 additions & 0 deletions src/torchjd/_linalg/_dual_cone.py
@@ -0,0 +1,72 @@
from abc import ABC, abstractmethod
from typing import Literal, TypeAlias

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

from ._matrix import PSDMatrix


class DualConeProjector(ABC):
    @abstractmethod
    def project_weights(self, U: Tensor, G: PSDMatrix) -> Tensor:
Contributor:

Rename to __call__?
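A minimal sketch of that rename (hypothetical, not part of this PR):

    @abstractmethod
    def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: ...

Call sites would then read `self.projector(u, G)` instead of `self.projector.project_weights(u, G)`.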

r"""
Computes the weights `w` of the projection of `J^T u` onto the dual cone of
the rows of `J`, provided `G = J J^T` and `u`. In other words, this computes the `w` that
satisfies `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1].

By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic
program:
minimize v^T G v
subject to u \preceq v

Reference:
[1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.

:param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
:param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
:return: A tensor of projection weights with the same shape as `U`.
"""


def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector:
    if projector is None:
        return QPSolverBased("quadprog")
Contributor:

I think quadprog should be a subclass of QPSolverBased.

If we don't do that, we'll be unable to use solver-specific extra parameters.
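A minimal sketch of the suggested design (hypothetical; the `Quadprog` class name and the forwarding of extra solver parameters are assumptions, not part of this PR):

class QPSolverBased(DualConeProjector, ABC):
    solver: str  # Set by each concrete subclass.

    def __init__(self, **solver_kwargs) -> None:
        # Solver-specific extra parameters, forwarded to solve_qp.
        self.solver_kwargs = solver_kwargs


class Quadprog(QPSolverBased):
    solver = "quadprog"

`_project_weight_vector` would then call `solve_qp(..., solver=self.solver, **self.solver_kwargs)`, and `projector_or_default` would return `Quadprog()`.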

    return projector


class QPSolverBased(DualConeProjector):
    SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"]

    def __init__(self, solver: SUPPORTED_SOLVER) -> None:
        self.solver = solver

    def __repr__(self) -> str:
        return f"QPSolverBased({repr(self.solver)})"

    def project_weights(self, U: Tensor, G: Tensor) -> Tensor:
        G_ = _to_array(G)
        U_ = _to_array(U)

        W = np.apply_along_axis(lambda u: self._project_weight_vector(u, G_), axis=-1, arr=U_)

        return torch.as_tensor(W, device=G.device, dtype=G.dtype)

    def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray:
        m = G.shape[0]
        w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=self.solver)

        if w is None:  # This may happen when G has large values.
            raise ValueError("Failed to solve the quadratic programming problem.")

        return w


def _to_array(tensor: Tensor) -> np.ndarray:
    """Transforms a tensor into a numpy array with float64 dtype."""
    return tensor.cpu().detach().numpy().astype(np.float64)
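For reference, a small usage sketch of the projector above (assuming quadprog is installed; the Jacobian values are arbitrary, and a plain Tensor is passed where the annotation says PSDMatrix, which the runtime path accepts):

import torch

from torchjd._linalg import QPSolverBased

J = torch.tensor([[1.0, 0.0], [-1.0, 1.0]])  # Two conflicting gradient rows.
G = J @ J.T  # Gramian of the rows of J.

projector = QPSolverBased("quadprog")
u = torch.tensor([0.5, 0.5])  # Mean weighting of the two rows.
w = projector.project_weights(u, G)

# J^T w is the projection of J^T u onto the dual cone of the rows of J,
# so J (J^T w) = G w should be elementwise non-negative.
print(w, J.T @ w, G @ w)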
33 changes: 23 additions & 10 deletions src/torchjd/aggregation/_dualproj.py
@@ -1,12 +1,11 @@
from torch import Tensor

from torchjd._linalg import normalize, regularize
from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._mixins import _NonDifferentiable
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting

@@ -32,18 +31,18 @@
        pref_vector: Tensor | None = None,
        norm_eps: float = 0.0001,
        reg_eps: float = 0.0001,
        solver: SUPPORTED_SOLVER = "quadprog",
        projector: DualConeProjector | None = None,
    ) -> None:
        super().__init__()
        self.pref_vector = pref_vector
        self.norm_eps = norm_eps
        self.reg_eps = reg_eps
        self.solver: SUPPORTED_SOLVER = solver
        self.projector = projector_or_default(projector)

    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        u = self.weighting(gramian)
        G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
Contributor:

I think the regularization and normalization should become part of the projector, because the required amount of regularization or normalization may vary per solver. norm_eps and reg_eps should thus also be given to the projector directly, I think.
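A sketch of that suggestion (hypothetical; it reuses the existing normalize/regularize helpers and the current parameter names):

class QPSolverBased(DualConeProjector):
    def __init__(
        self, solver: SUPPORTED_SOLVER, norm_eps: float = 0.0001, reg_eps: float = 0.0001
    ) -> None:
        self.solver = solver
        self.norm_eps = norm_eps
        self.reg_eps = reg_eps

    def project_weights(self, U: Tensor, G: PSDMatrix) -> Tensor:
        # Conditioning moves inside the projector, so each solver can choose
        # how much normalization/regularization it needs.
        G = regularize(normalize(G, self.norm_eps), self.reg_eps)
        ...  # Solve the QP with the conditioned Gramian, as above.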

        w = project_weights(u, G, self.solver)
        w = self.projector.project_weights(u, G)
        return w

    @property
@@ -77,6 +76,14 @@

        self._reg_eps = value

    @property
    def projector(self) -> DualConeProjector:
        return self._projector

    @projector.setter
    def projector(self, value: DualConeProjector | None) -> None:
        self._projector = projector_or_default(value)


class DualProj(_NonDifferentiable, GramianWeightedAggregator):
    r"""
@@ -102,12 +109,10 @@
        pref_vector: Tensor | None = None,
        norm_eps: float = 0.0001,
        reg_eps: float = 0.0001,
        solver: SUPPORTED_SOLVER = "quadprog",
        projector: DualConeProjector | None = None,
    ) -> None:
        self._solver: SUPPORTED_SOLVER = solver

        super().__init__(
            DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
            DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector),
        )

    @property
@@ -134,10 +139,18 @@
    def reg_eps(self, value: float) -> None:
        self.gramian_weighting.reg_eps = value

    @property
    def projector(self) -> DualConeProjector:
        return self.gramian_weighting.projector

    @projector.setter
    def projector(self, value: DualConeProjector | None) -> None:
        self.gramian_weighting.projector = value

Codecov warning: added line src/torchjd/aggregation/_dualproj.py#L148 was not covered by tests.

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
            f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
            f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})"
        )

    def __str__(self) -> str:
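A quick usage sketch of the new parameter (assuming the aggregator is called directly on a Jacobian matrix, as elsewhere in torchjd):

import torch

from torchjd._linalg import QPSolverBased
from torchjd.aggregation import DualProj

A = DualProj(projector=QPSolverBased("quadprog"))  # Equivalent to DualProj(): same default.
J = torch.tensor([[1.0, 0.0], [-1.0, 1.0]])
print(A(J))  # Aggregated update that conflicts with no row of J.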
33 changes: 23 additions & 10 deletions src/torchjd/aggregation/_upgrad.py
@@ -1,13 +1,12 @@
import torch
from torch import Tensor

from torchjd._linalg import normalize, regularize
from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._mixins import _NonDifferentiable
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting

@@ -33,18 +32,18 @@
        pref_vector: Tensor | None = None,
        norm_eps: float = 0.0001,
        reg_eps: float = 0.0001,
        solver: SUPPORTED_SOLVER = "quadprog",
        projector: DualConeProjector | None = None,
    ) -> None:
        super().__init__()
        self.pref_vector = pref_vector
        self.norm_eps = norm_eps
        self.reg_eps = reg_eps
        self.solver: SUPPORTED_SOLVER = solver
        self.projector = projector_or_default(projector)

    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        U = torch.diag(self.weighting(gramian))
        G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
        W = project_weights(U, G, self.solver)
        W = self.projector.project_weights(U, G)
        return torch.sum(W, dim=0)

    @property
@@ -80,6 +79,14 @@

        self._reg_eps = value

    @property
    def projector(self) -> DualConeProjector:
        return self._projector

    @projector.setter
    def projector(self, value: DualConeProjector | None) -> None:
        self._projector = projector_or_default(value)


class UPGrad(_NonDifferentiable, GramianWeightedAggregator):
    r"""
@@ -105,12 +112,10 @@
        pref_vector: Tensor | None = None,
        norm_eps: float = 0.0001,
        reg_eps: float = 0.0001,
        solver: SUPPORTED_SOLVER = "quadprog",
        projector: DualConeProjector | None = None,
    ) -> None:
        self._solver: SUPPORTED_SOLVER = solver

        super().__init__(
            UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
            UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector),
        )

    @property
@@ -137,10 +142,18 @@
    def reg_eps(self, value: float) -> None:
        self.gramian_weighting.reg_eps = value

    @property
    def projector(self) -> DualConeProjector:
        return self.gramian_weighting.projector

    @projector.setter
    def projector(self, value: DualConeProjector | None) -> None:
        self.gramian_weighting.projector = value

Codecov warning: added line src/torchjd/aggregation/_upgrad.py#L151 was not covered by tests.

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
            f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
            f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})"
        )

    def __str__(self) -> str:
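For context, a sketch of the mechanics of UPGradWeighting.forward above: U = diag(u) carries one weight row per objective, each row is projected onto the dual cone separately, and the projected rows are summed (example values hypothetical):

import torch

from torchjd._linalg import QPSolverBased

projector = QPSolverBased("quadprog")
J = torch.tensor([[1.0, 0.0], [-1.0, 1.0]])  # Two conflicting gradient rows.
G = J @ J.T

U = torch.diag(torch.tensor([0.5, 0.5]))  # One weighted row per objective.
W = projector.project_weights(U, G)  # Projects each row independently.
w = W.sum(dim=0)  # UPGrad's final aggregation weights.
print(w, G @ w)  # G @ w should be elementwise non-negative.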
62 changes: 0 additions & 62 deletions src/torchjd/aggregation/_utils/dual_cone.py

This file was deleted.

12 changes: 8 additions & 4 deletions tests/unit/aggregation/test_dualproj.py
@@ -3,6 +3,7 @@
from torch import Tensor
from utils.tensors import ones_

from torchjd._linalg import QPSolverBased
from torchjd.aggregation import ConstantWeighting, DualProj
from torchjd.aggregation._dualproj import DualProjWeighting

@@ -47,21 +48,24 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None:


def test_representations() -> None:
    A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
    A = DualProj(
        pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog")
    )
    assert (
        repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"
        repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector="
        "QPSolverBased('quadprog'))"
    )
    assert str(A) == "DualProj"

    A = DualProj(
        pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"),
        norm_eps=0.0001,
        reg_eps=0.0001,
        solver="quadprog",
        projector=QPSolverBased("quadprog"),
    )
    assert (
        repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, "
        "solver='quadprog')"
        "projector=QPSolverBased('quadprog'))"
    )
    assert str(A) == "DualProj([1., 2., 3.])"

1 change: 0 additions & 1 deletion tests/unit/aggregation/test_pcgrad.py
@@ -55,7 +55,6 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None:
        ones_((2,)),
        norm_eps=0.0,
        reg_eps=0.0,
        solver="quadprog",
    )

    result = pc_grad_weighting(gramian)