From a07b9ec7d766f9412de734b35bea0689b825c865 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 10 Jun 2026 02:03:04 +0300 Subject: [PATCH 1/8] jit the whole get_tangents process for ProximalProjection --- desc/optimize/_constraint_wrappers.py | 214 +++++++++++++++++--------- 1 file changed, 141 insertions(+), 73 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 1041b2a1fb..e684fddc44 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -752,9 +752,11 @@ def build(self, use_jit=None, verbose=1): # noqa: C901 # we remove the R_lmn, Z_lmn, L_lmn, Ra_n, Za_n from the equilibrium params # dimc_per_thing accounts for that, don't confuse it with reduced state vector self._dimc_per_thing = [t.dim_x for t in self.things] - self._dimc_per_thing[self._eq_idx] = np.sum( - [self._eq.dimensions[arg] for arg in self._args] + self._dimc_per_thing[self._eq_idx] = int( + np.sum([self._eq.dimensions[arg] for arg in self._args]) ) + # we will need to set this static attribute, only possible if tuple + self._dimc_per_thing = tuple(self._dimc_per_thing) # equivalent matrix for A[unfixed_idx] @ D @ Z == A @ feasible_tangents self._feasible_tangents = jnp.eye(self._objective.dim_x) @@ -1063,16 +1065,22 @@ def grad(self, x, constants=None): # where ∇G is the Jacobian of G with respect to full state vector # and ∇F is the Jacobian of F with respect to full state vector. Then, # ∇L = G.T @ ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)] - # We get the part in [] using the _get_tangent method. + # We get the part in [] using the _proximal_get_tangents. v = jnp.eye(x.shape[0]) constants = setdefault(constants, [None, None]) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._get_tangent(u, xf, constants, op="scaled_error") - tangents = batched_vectorize( - jvpfun, - signature="(n)->(k)", - chunk_size=self._constraint._jac_chunk_size, - )(v) + tangents = _proximal_get_tangents( + self._constraint, + xf, + v, + constants[1], + self._eq_solve_objective._feasible_tangents.T, + self._dxdc, + self._feasible_tangents, + self._dimc_per_thing, + self._eq_idx, + "scaled_error", + ) g = self._objective.compute_scaled_error(xg, constants[0]) g_vjp = self._objective.vjp_scaled_error(g, xg, constants[0]) return tangents @ g_vjp @@ -1228,12 +1236,18 @@ def _jvp(self, v, x, constants=None, op="scaled_error"): # we don't need to divide this part into blocked and batched because # self._constraint._deriv_mode will handle it - jvpfun = lambda u: self._get_tangent(u, xf, constants, op=op) - tangents = batched_vectorize( - jvpfun, - signature="(n)->(k)", - chunk_size=self._constraint._jac_chunk_size, - )(v) + tangents = _proximal_get_tangents( + self._constraint, + xf, + v, + constants[1], + self._eq_solve_objective._feasible_tangents.T, + self._dxdc, + self._feasible_tangents, + self._dimc_per_thing, + self._eq_idx, + op, + ) if self._objective._deriv_mode == "batched": # objective's method already know about its jac_chunk_size @@ -1246,53 +1260,6 @@ def _jvp(self, v, x, constants=None, op="scaled_error"): op, ) - def _get_tangent(self, v, xf, constants, op): - # Note: This function is vectorized over v. So, v is expected to be 1D array - # of size self.dim_x. - - # v contains self._args DoFs from eq and other objects (like coils, surfaces - # etc), we want jvp_f to only get parts from equilibrium, not other things - vs = jnp.split(v, np.cumsum(self._dimc_per_thing)) - # This is (dF/dx)^-1 * dF/dc # noqa : E800 - dfdc = _proximal_jvp_f_pure( - self._constraint, - xf, - constants[1], - vs[self._eq_idx], - self._eq_solve_objective._feasible_tangents, - self._dxdc, - op, - ) - # broadcasting against multiple things - dfdcs = [jnp.zeros(dim) for dim in self._dimc_per_thing] - dfdcs[self._eq_idx] = dfdc - # note that dfdc.size != vs[self._eq_idx].size - # dfdc has the size of reduced state vector of the equilibrium - # but vs[self._eq_idx] has the size of self._args DoFs - dfdc = jnp.concatenate(dfdcs) - - # We try to find dG/dc - dG/dx * (dF/dx)^-1 * dF/dc - # where G is the objective function. Since DESC stores x and c in the same - # vector, instead of multiple JVP calls, we will just find a tangent direction - # that will give us the same result. - # For making the explanation clear, assume J is the Jacobian of the objective - # function with respect to the full state vector (both x and c). Then, - # dG/dc = J @ (tangent vectors in c direction) - # dG/dx = J @ (tangent vectors in x direction) - # So, dG/dc - dG/dx * (dF/dx)^-1 * dF/dc can be written as - # J @ [(tangent vectors in c direction) - (tangent vectors in x direction)@dfdc] - # Note: We will never form full Jacobian J, we will just compute the above - # expression by JVPs. - dxdcv = jnp.concatenate( - [ - *vs[: self._eq_idx], - self._dxdc @ vs[self._eq_idx], # Rb_lmn, Zb_lmn to full eq state vector - *vs[self._eq_idx + 1 :], - ] - ) - tangent = dxdcv - self._feasible_tangents @ dfdc - return tangent - @property def constants(self): """list: constant parameters for each sub-objective.""" @@ -1318,9 +1285,9 @@ def __getattr__(self, name): # define these helper functions that are stateless so we can safely jit them -def jit_if_possible(func): +def jit_if_possible(func, *, static_argnames=("op",)): """Jit a function if use_jit.""" - jitted_func = functools.partial(jit, static_argnames=["op"])(func) + jitted_func = functools.partial(jit, static_argnames=list(static_argnames))(func) @functools.wraps(func) def wrapper(*args, **kwargs): @@ -1334,15 +1301,116 @@ def wrapper(*args, **kwargs): return wrapper -@jit_if_possible -def _proximal_jvp_f_pure(constraint, xf, constants, dc, eq_feasible_tangents, dxdc, op): +@jit_if_possible(static_argnames=("op", "dimc_per_thing", "eq_idx")) +def _proximal_get_tangents( + constraint, + xf, + v, + constants, + eq_feasible_tangents_T, + dxdc, + feasible_tangents, + dimc_per_thing, + eq_idx, + op="scaled_error", +): + uf, sfi, vtf = _get_fxh_inverse( + constraint, xf, constants, eq_feasible_tangents_T, op + ) + jvpfun = lambda u: _get_tangent( + constraint, + u, + xf, + constants, + uf, + sfi, + vtf, + dxdc, + feasible_tangents, + dimc_per_thing, + eq_idx, + op, + ) + return batched_vectorize( + jvpfun, + signature="(n)->(k)", + chunk_size=constraint._jac_chunk_size, + )(v) + + +def _get_tangent( + constraint, + v, + xf, + constants, + uf, + sfi, + vtf, + dxdc, + feasible_tangents, + dimc_per_thing, + eq_idx, + op, +): + # Note: This function is vectorized over v. So, v is expected to be 1D array + # of size self.dim_x. + + # v contains self._args DoFs from eq and other objects (like coils, surfaces + # etc), we want jvp_f to only get parts from equilibrium, not other things + vs = jnp.split(v, np.cumsum(dimc_per_thing)) + # This is (dF/dx)^-1 * dF/dc # noqa : E800 + dfdc = _proximal_jvp_f_pure( + constraint, xf, constants, vs[eq_idx], uf, sfi, vtf, dxdc, op + ) + # broadcasting against multiple things + dfdcs = [jnp.zeros(dim) for dim in dimc_per_thing] + dfdcs[eq_idx] = dfdc + # note that dfdc.size != vs[self._eq_idx].size + # dfdc has the size of reduced state vector of the equilibrium + # but vs[self._eq_idx] has the size of self._args DoFs + dfdc = jnp.concatenate(dfdcs) + + # We try to find dG/dc - dG/dx * (dF/dx)^-1 * dF/dc + # where G is the objective function. Since DESC stores x and c in the same + # vector, instead of multiple JVP calls, we will just find a tangent direction + # that will give us the same result. + # For making the explanation clear, assume J is the Jacobian of the objective + # function with respect to the full state vector (both x and c). Then, + # dG/dc = J @ (tangent vectors in c direction) + # dG/dx = J @ (tangent vectors in x direction) + # So, dG/dc - dG/dx * (dF/dx)^-1 * dF/dc can be written as + # J @ [(tangent vectors in c direction) - (tangent vectors in x direction)@dfdc] + # Note: We will never form full Jacobian J, we will just compute the above + # expression by JVPs. + dxdcv = jnp.concatenate( + [ + *vs[:eq_idx], + dxdc @ vs[eq_idx], # Rb_lmn, Zb_lmn to full eq state vector + *vs[eq_idx + 1 :], + ] + ) + tangent = dxdcv - feasible_tangents @ dfdc + return tangent + + +def _get_fxh_inverse(constraint, xf, constants, eq_feasible_tangents_T, op): + # This is the transpose of dF/dx + Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents_T, xf, constants) + cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) + uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) + sf += sf[-1] # add a tiny bit of regularization + sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) + return uf, sfi, vtf + + +def _proximal_jvp_f_pure(constraint, xf, constants, dc, uf, sfi, vtf, dxdc, op): # Note: This function is called by _get_tangent which is vectorized over v # (v is called dc in this function). So, dc is expected to be 1D array # of same size as full equilibrium state vector. This function returns a 1D array. # here we are forming (dF/dx)^-1 @ dF/dc - # where Fxh is dF/dx and Fc is dF/dc - Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents.T, xf, constants).T + # where Fc is dF/dc and (dF/dx)^-1 is given by uft, sfi and vtft which are from the + # SVD of dF/dx computed in _get_fxh_inverse. # Our compute functions never include variables like Rb_lmn, Zb_lmn etc. So, # taking the JVP in just dc direction will give 0. To prevent this, we use dxdc # which is the dx/dc matrix and convert the Rb_lmn to R_lmn entries etc. @@ -1350,11 +1418,11 @@ def _proximal_jvp_f_pure(constraint, xf, constants, dc, eq_feasible_tangents, dx # wrt all R_lmn coefficients that contribute to Rb_023. See BoundaryRSelfConsistency # for the relation between Rb_lmn and R_lmn. Fc = getattr(constraint, "jvp_" + op)(dxdc @ dc, xf, constants) - cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) - uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) - sf += sf[-1] # add a tiny bit of regularization - sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) - return vtf.T @ (sfi * (uf.T @ Fc)) + # Note: keeping uf and vtf separate is more efficient than multiplying them to get a + # single inverse matrix that is computed once out of the batched operation for small + # batch sizes. For larger batch sizes, it can be more efficient to compute the full + # inverse matrix and do a single matmul, but this is omitted for now. + return uf @ (sfi * (vtf @ Fc)) @jit_if_possible From 8d635e83cf75b33f9d493563a028873348e2dfe8 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 10 Jun 2026 02:49:42 +0300 Subject: [PATCH 2/8] fix the decorator --- desc/optimize/_constraint_wrappers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index e684fddc44..9651f6d4e4 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1285,8 +1285,10 @@ def __getattr__(self, name): # define these helper functions that are stateless so we can safely jit them -def jit_if_possible(func, *, static_argnames=("op",)): +def jit_if_possible(func=None, *, static_argnames=("op",)): """Jit a function if use_jit.""" + if func is None: + return functools.partial(jit_if_possible, static_argnames=static_argnames) jitted_func = functools.partial(jit, static_argnames=list(static_argnames))(func) @functools.wraps(func) From 03d00cb3e7e788e79daa58e03af5708d5d517d6d Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 10 Jun 2026 03:11:48 +0300 Subject: [PATCH 3/8] minor comment fixes --- desc/optimize/_constraint_wrappers.py | 34 +++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 9651f6d4e4..f0f4d704eb 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1057,14 +1057,14 @@ def grad(self, x, constants=None): gradient vector. """ - # We are looking for the gradient of L = 0.5 * G.T @ G - # Then, the gradient is ∇L = G.T @ J_of_G + # We are looking for the gradient of L = 0.5 * Gᵀ @ G + # Then, the gradient is ∇L = Gᵀ @ J_of_G # where J_of_G is the Jacobian of G with respect to the optimization variables # We explained getting J_of_G in the _jvp method. It is basically, - # J_of_G = ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)] + # J_of_G = ∇G @ [dc_tangents - (∇F @ dx_tangents)⁻¹ @ (∇F @ dc_tangents)] # where ∇G is the Jacobian of G with respect to full state vector # and ∇F is the Jacobian of F with respect to full state vector. Then, - # ∇L = G.T @ ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)] + # ∇L = Gᵀ @ ∇G @ [dc_tangents - (∇F @ dx_tangents)⁻¹ @ (∇F @ dc_tangents)] # We get the part in [] using the _proximal_get_tangents. v = jnp.eye(x.shape[0]) constants = setdefault(constants, [None, None]) @@ -1220,12 +1220,12 @@ def _jvp(self, v, x, constants=None, op="scaled_error"): # equilibrium such that # F(x+dx, c+dc) = 0 = F(x, c) + dF/dx * dx + dF/dc * dc # so that we can set F(x, c) = 0, from here we can solve for dx and get - # dx = - (dF/dx)^-1 * dF/dc * dc # noqa : E800 + # dx = - (dF/dx)⁻¹ * dF/dc * dc # noqa : E800 # We can then compute the Jacobian of the objective function with respect to c # G(x+dx, c+dc) = G(x, c) + dG/dx * dx + dG/dc * dc # substituting in dx we get - # G(x+dx, c+dc) = G(x, c) + [ dG/dc - dG/dx * (dF/dx)^-1 * dF/dc ]* dc - # and the Jacobian we want is dG/dc - dG/dx * (dF/dx)^-1 * dF/dc + # G(x+dx, c+dc) = G(x, c) + [ dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc ] * dc + # and the Jacobian we want is dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc # Note: This Jacobian can be obtained using JVPs in proper tangent directions. # First we will compute the tangent direction (see _get_tangent for details), @@ -1355,24 +1355,24 @@ def _get_tangent( op, ): # Note: This function is vectorized over v. So, v is expected to be 1D array - # of size self.dim_x. + # of size prox.dim_x. - # v contains self._args DoFs from eq and other objects (like coils, surfaces + # v contains prox._args DoFs from eq and other objects (like coils, surfaces # etc), we want jvp_f to only get parts from equilibrium, not other things vs = jnp.split(v, np.cumsum(dimc_per_thing)) - # This is (dF/dx)^-1 * dF/dc # noqa : E800 + # This is (dF/dx)⁻¹ * dF/dc # noqa : E800 dfdc = _proximal_jvp_f_pure( constraint, xf, constants, vs[eq_idx], uf, sfi, vtf, dxdc, op ) # broadcasting against multiple things dfdcs = [jnp.zeros(dim) for dim in dimc_per_thing] dfdcs[eq_idx] = dfdc - # note that dfdc.size != vs[self._eq_idx].size + # note that dfdc.size != vs[eq_idx].size # dfdc has the size of reduced state vector of the equilibrium - # but vs[self._eq_idx] has the size of self._args DoFs + # but vs[eq_idx] has the size of prox._args DoFs dfdc = jnp.concatenate(dfdcs) - # We try to find dG/dc - dG/dx * (dF/dx)^-1 * dF/dc + # We try to find dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc # where G is the objective function. Since DESC stores x and c in the same # vector, instead of multiple JVP calls, we will just find a tangent direction # that will give us the same result. @@ -1380,7 +1380,7 @@ def _get_tangent( # function with respect to the full state vector (both x and c). Then, # dG/dc = J @ (tangent vectors in c direction) # dG/dx = J @ (tangent vectors in x direction) - # So, dG/dc - dG/dx * (dF/dx)^-1 * dF/dc can be written as + # So, dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc can be written as # J @ [(tangent vectors in c direction) - (tangent vectors in x direction)@dfdc] # Note: We will never form full Jacobian J, we will just compute the above # expression by JVPs. @@ -1410,9 +1410,9 @@ def _proximal_jvp_f_pure(constraint, xf, constants, dc, uf, sfi, vtf, dxdc, op): # (v is called dc in this function). So, dc is expected to be 1D array # of same size as full equilibrium state vector. This function returns a 1D array. - # here we are forming (dF/dx)^-1 @ dF/dc - # where Fc is dF/dc and (dF/dx)^-1 is given by uft, sfi and vtft which are from the - # SVD of dF/dx computed in _get_fxh_inverse. + # here we are forming (dF/dx)⁻¹ @ dF/dc + # where Fc is dF/dc and (dF/dx)⁻¹ is given by uf, sfi and vtf which are from the + # SVD of dF/dxᵀ computed in _get_fxh_inverse. # Our compute functions never include variables like Rb_lmn, Zb_lmn etc. So, # taking the JVP in just dc direction will give 0. To prevent this, we use dxdc # which is the dx/dc matrix and convert the Rb_lmn to R_lmn entries etc. From e6640e4103d9dd7370bcb6265c173e06c8d4ac61 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 10 Jun 2026 03:14:33 +0300 Subject: [PATCH 4/8] update changelog --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1a2a5651e..f90fb676e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ Changelog ========= +Performance Improvements + +- More efficient `ProximalProjection` jacobians especially if the `ForceBalance` constraint uses a small `jac_chunk_size`. + + +v0.17.2 +------- + New Features - Adds ``desc.objectives.DeflationOperator``, a new objective class which can be used to apply deflation techniques to equilibrium and optimization problems to find multiple local minima or multiple solutions from a single initial point, either by wrapping an existing ``desc.objectives._Objective`` object or by including as an additional penalty or constraint. Also adds a tutorial showing this functionality. From dc96f78be5401896515006d00dedf06086957638 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 10 Jun 2026 04:45:33 +0300 Subject: [PATCH 5/8] take the svd part of compile graph, maybe faster? still jitted but not fused into a giant program --- desc/optimize/_constraint_wrappers.py | 30 +++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index f0f4d704eb..4ffcddd61c 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1069,12 +1069,21 @@ def grad(self, x, constants=None): v = jnp.eye(x.shape[0]) constants = setdefault(constants, [None, None]) xg, xf = self._update_equilibrium(x, store=True) + uf, sfi, vtf = _get_fxh_inverse( + self._constraint, + xf, + constants[1], + self._eq_solve_objective._feasible_tangents.T, + "scaled_error", + ) tangents = _proximal_get_tangents( self._constraint, xf, v, constants[1], - self._eq_solve_objective._feasible_tangents.T, + uf, + sfi, + vtf, self._dxdc, self._feasible_tangents, self._dimc_per_thing, @@ -1236,12 +1245,21 @@ def _jvp(self, v, x, constants=None, op="scaled_error"): # we don't need to divide this part into blocked and batched because # self._constraint._deriv_mode will handle it + uf, sfi, vtf = _get_fxh_inverse( + self._constraint, + xf, + constants[1], + self._eq_solve_objective._feasible_tangents.T, + op, + ) tangents = _proximal_get_tangents( self._constraint, xf, v, constants[1], - self._eq_solve_objective._feasible_tangents.T, + uf, + sfi, + vtf, self._dxdc, self._feasible_tangents, self._dimc_per_thing, @@ -1309,16 +1327,15 @@ def _proximal_get_tangents( xf, v, constants, - eq_feasible_tangents_T, + uf, + sfi, + vtf, dxdc, feasible_tangents, dimc_per_thing, eq_idx, op="scaled_error", ): - uf, sfi, vtf = _get_fxh_inverse( - constraint, xf, constants, eq_feasible_tangents_T, op - ) jvpfun = lambda u: _get_tangent( constraint, u, @@ -1395,6 +1412,7 @@ def _get_tangent( return tangent +@jit_if_possible def _get_fxh_inverse(constraint, xf, constants, eq_feasible_tangents_T, op): # This is the transpose of dF/dx Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents_T, xf, constants) From 804dc8ffbd1c7ae6c6871b81e9818f85a917eac8 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 10 Jun 2026 21:31:55 +0300 Subject: [PATCH 6/8] move svd back inside but still jit and hope compiler hoist the loop invariant --- desc/optimize/_constraint_wrappers.py | 66 ++++++++------------------- 1 file changed, 18 insertions(+), 48 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 4ffcddd61c..8ecb7ff285 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1069,21 +1069,12 @@ def grad(self, x, constants=None): v = jnp.eye(x.shape[0]) constants = setdefault(constants, [None, None]) xg, xf = self._update_equilibrium(x, store=True) - uf, sfi, vtf = _get_fxh_inverse( - self._constraint, - xf, - constants[1], - self._eq_solve_objective._feasible_tangents.T, - "scaled_error", - ) tangents = _proximal_get_tangents( self._constraint, xf, v, constants[1], - uf, - sfi, - vtf, + self._eq_solve_objective._feasible_tangents, self._dxdc, self._feasible_tangents, self._dimc_per_thing, @@ -1245,21 +1236,12 @@ def _jvp(self, v, x, constants=None, op="scaled_error"): # we don't need to divide this part into blocked and batched because # self._constraint._deriv_mode will handle it - uf, sfi, vtf = _get_fxh_inverse( - self._constraint, - xf, - constants[1], - self._eq_solve_objective._feasible_tangents.T, - op, - ) tangents = _proximal_get_tangents( self._constraint, xf, v, constants[1], - uf, - sfi, - vtf, + self._eq_solve_objective._feasible_tangents, self._dxdc, self._feasible_tangents, self._dimc_per_thing, @@ -1327,9 +1309,7 @@ def _proximal_get_tangents( xf, v, constants, - uf, - sfi, - vtf, + eq_feasible_tangents, dxdc, feasible_tangents, dimc_per_thing, @@ -1341,9 +1321,7 @@ def _proximal_get_tangents( u, xf, constants, - uf, - sfi, - vtf, + eq_feasible_tangents, dxdc, feasible_tangents, dimc_per_thing, @@ -1357,14 +1335,13 @@ def _proximal_get_tangents( )(v) +@jit_if_possible(static_argnames=("op", "dimc_per_thing", "eq_idx")) def _get_tangent( constraint, v, xf, constants, - uf, - sfi, - vtf, + eq_feasible_tangents, dxdc, feasible_tangents, dimc_per_thing, @@ -1379,7 +1356,7 @@ def _get_tangent( vs = jnp.split(v, np.cumsum(dimc_per_thing)) # This is (dF/dx)⁻¹ * dF/dc # noqa : E800 dfdc = _proximal_jvp_f_pure( - constraint, xf, constants, vs[eq_idx], uf, sfi, vtf, dxdc, op + constraint, xf, constants, vs[eq_idx], eq_feasible_tangents, dxdc, op ) # broadcasting against multiple things dfdcs = [jnp.zeros(dim) for dim in dimc_per_thing] @@ -1412,25 +1389,18 @@ def _get_tangent( return tangent -@jit_if_possible -def _get_fxh_inverse(constraint, xf, constants, eq_feasible_tangents_T, op): - # This is the transpose of dF/dx - Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents_T, xf, constants) - cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) - uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) - sf += sf[-1] # add a tiny bit of regularization - sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) - return uf, sfi, vtf - - -def _proximal_jvp_f_pure(constraint, xf, constants, dc, uf, sfi, vtf, dxdc, op): +def _proximal_jvp_f_pure(constraint, xf, constants, dc, eq_feasible_tangents, dxdc, op): # Note: This function is called by _get_tangent which is vectorized over v # (v is called dc in this function). So, dc is expected to be 1D array # of same size as full equilibrium state vector. This function returns a 1D array. # here we are forming (dF/dx)⁻¹ @ dF/dc - # where Fc is dF/dc and (dF/dx)⁻¹ is given by uf, sfi and vtf which are from the - # SVD of dF/dxᵀ computed in _get_fxh_inverse. + # where Fxh is dF/dx and Fc is dF/dc. + # Note: Fxh and its SVD do not depend on dc (the vectorized argument). Since the + # whole tangent computation is jitted as one program, we rely on the compiler to + # hoist this loop-invariant SVD out of the batched scan/vmap rather than + # recomputing it for every tangent. + Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents.T, xf, constants) # Our compute functions never include variables like Rb_lmn, Zb_lmn etc. So, # taking the JVP in just dc direction will give 0. To prevent this, we use dxdc # which is the dx/dc matrix and convert the Rb_lmn to R_lmn entries etc. @@ -1438,10 +1408,10 @@ def _proximal_jvp_f_pure(constraint, xf, constants, dc, uf, sfi, vtf, dxdc, op): # wrt all R_lmn coefficients that contribute to Rb_023. See BoundaryRSelfConsistency # for the relation between Rb_lmn and R_lmn. Fc = getattr(constraint, "jvp_" + op)(dxdc @ dc, xf, constants) - # Note: keeping uf and vtf separate is more efficient than multiplying them to get a - # single inverse matrix that is computed once out of the batched operation for small - # batch sizes. For larger batch sizes, it can be more efficient to compute the full - # inverse matrix and do a single matmul, but this is omitted for now. + cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) + uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) + sf += sf[-1] # add a tiny bit of regularization + sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) return uf @ (sfi * (vtf @ Fc)) From 30f5f11f47cc9966a8f2a12ff8cbd9d41fb312df Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 10 Jun 2026 22:01:06 +0300 Subject: [PATCH 7/8] add a benchmark --- tests/benchmarks/benchmark_cpu_small.py | 23 +++++++++++++++++++++++ tests/benchmarks/benchmark_gpu_small.py | 23 +++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index b28667ada3..fea6558f41 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -346,6 +346,29 @@ def run(x, prox): benchmark.pedantic(run, args=(x, prox), rounds=20, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_proximal_jac_atf_chunked(benchmark): + """Benchmark computing jacobian of constrained proximal projection.""" + eq = desc.examples.get("ATF") + grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.linspace(0.1, 1, 10)) + objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid)) + # chunk the computation, total size is 252, so this should take at most + # 2.5x the unchunked case above + constraint = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=100) + prox = ProximalProjection( + objective, constraint, eq, solve_options={"solve_during_proximal_build": False} + ) + prox.build() + x = prox.x(eq) + prox.jac_scaled_error(x).block_until_ready() + + def run(x, prox): + prox.jac_scaled_error(x).block_until_ready() + + benchmark.pedantic(run, args=(x, prox), rounds=10, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_proximal_jac_atf_with_eq_update(benchmark): diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index fe58b2b39d..f68ce08431 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -349,6 +349,29 @@ def run(x, prox): benchmark.pedantic(run, args=(x, prox), rounds=20, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_proximal_jac_atf_chunked(benchmark): + """Benchmark computing jacobian of constrained proximal projection.""" + eq = desc.examples.get("ATF") + grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.linspace(0.1, 1, 10)) + objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid)) + # chunk the computation, total size is 252, so this should take at most + # 2.5x the unchunked case above + constraint = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=100) + prox = ProximalProjection( + objective, constraint, eq, solve_options={"solve_during_proximal_build": False} + ) + prox.build() + x = prox.x(eq) + prox.jac_scaled_error(x).block_until_ready() + + def run(x, prox): + prox.jac_scaled_error(x).block_until_ready() + + benchmark.pedantic(run, args=(x, prox), rounds=10, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_proximal_jac_atf_with_eq_update(benchmark): From dd9f654d4dca00550c2d61c94f2466bba3be50ee Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 10 Jun 2026 22:44:54 +0300 Subject: [PATCH 8/8] minor docs --- desc/optimize/_constraint_wrappers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 8ecb7ff285..5bacc75131 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1335,7 +1335,6 @@ def _proximal_get_tangents( )(v) -@jit_if_possible(static_argnames=("op", "dimc_per_thing", "eq_idx")) def _get_tangent( constraint, v, @@ -1395,7 +1394,7 @@ def _proximal_jvp_f_pure(constraint, xf, constants, dc, eq_feasible_tangents, dx # of same size as full equilibrium state vector. This function returns a 1D array. # here we are forming (dF/dx)⁻¹ @ dF/dc - # where Fxh is dF/dx and Fc is dF/dc. + # where Fxh is dF/dxᵀ and Fc is dF/dc. # Note: Fxh and its SVD do not depend on dc (the vectorized argument). Since the # whole tangent computation is jitted as one program, we rely on the compiler to # hoist this loop-invariant SVD out of the batched scan/vmap rather than