diff --git a/CHANGELOG.md b/CHANGELOG.md index f1a2a5651e..f90fb676e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ Changelog ========= +Performance Improvements + +- More efficient `ProximalProjection` jacobians especially if the `ForceBalance` constraint uses a small `jac_chunk_size`. + + +v0.17.2 +------- + New Features - Adds ``desc.objectives.DeflationOperator``, a new objective class which can be used to apply deflation techniques to equilibrium and optimization problems to find multiple local minima or multiple solutions from a single initial point, either by wrapping an existing ``desc.objectives._Objective`` object or by including as an additional penalty or constraint. Also adds a tutorial showing this functionality. diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 1041b2a1fb..5bacc75131 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -752,9 +752,11 @@ def build(self, use_jit=None, verbose=1): # noqa: C901 # we remove the R_lmn, Z_lmn, L_lmn, Ra_n, Za_n from the equilibrium params # dimc_per_thing accounts for that, don't confuse it with reduced state vector self._dimc_per_thing = [t.dim_x for t in self.things] - self._dimc_per_thing[self._eq_idx] = np.sum( - [self._eq.dimensions[arg] for arg in self._args] + self._dimc_per_thing[self._eq_idx] = int( + np.sum([self._eq.dimensions[arg] for arg in self._args]) ) + # we will need to set this static attribute, only possible if tuple + self._dimc_per_thing = tuple(self._dimc_per_thing) # equivalent matrix for A[unfixed_idx] @ D @ Z == A @ feasible_tangents self._feasible_tangents = jnp.eye(self._objective.dim_x) @@ -1055,24 +1057,30 @@ def grad(self, x, constants=None): gradient vector. """ - # We are looking for the gradient of L = 0.5 * G.T @ G - # Then, the gradient is ∇L = G.T @ J_of_G + # We are looking for the gradient of L = 0.5 * Gᵀ @ G + # Then, the gradient is ∇L = Gᵀ @ J_of_G # where J_of_G is the Jacobian of G with respect to the optimization variables # We explained getting J_of_G in the _jvp method. It is basically, - # J_of_G = ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)] + # J_of_G = ∇G @ [dc_tangents - (∇F @ dx_tangents)⁻¹ @ (∇F @ dc_tangents)] # where ∇G is the Jacobian of G with respect to full state vector # and ∇F is the Jacobian of F with respect to full state vector. Then, - # ∇L = G.T @ ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)] - # We get the part in [] using the _get_tangent method. + # ∇L = Gᵀ @ ∇G @ [dc_tangents - (∇F @ dx_tangents)⁻¹ @ (∇F @ dc_tangents)] + # We get the part in [] using the _proximal_get_tangents. v = jnp.eye(x.shape[0]) constants = setdefault(constants, [None, None]) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._get_tangent(u, xf, constants, op="scaled_error") - tangents = batched_vectorize( - jvpfun, - signature="(n)->(k)", - chunk_size=self._constraint._jac_chunk_size, - )(v) + tangents = _proximal_get_tangents( + self._constraint, + xf, + v, + constants[1], + self._eq_solve_objective._feasible_tangents, + self._dxdc, + self._feasible_tangents, + self._dimc_per_thing, + self._eq_idx, + "scaled_error", + ) g = self._objective.compute_scaled_error(xg, constants[0]) g_vjp = self._objective.vjp_scaled_error(g, xg, constants[0]) return tangents @ g_vjp @@ -1212,12 +1220,12 @@ def _jvp(self, v, x, constants=None, op="scaled_error"): # equilibrium such that # F(x+dx, c+dc) = 0 = F(x, c) + dF/dx * dx + dF/dc * dc # so that we can set F(x, c) = 0, from here we can solve for dx and get - # dx = - (dF/dx)^-1 * dF/dc * dc # noqa : E800 + # dx = - (dF/dx)⁻¹ * dF/dc * dc # noqa : E800 # We can then compute the Jacobian of the objective function with respect to c # G(x+dx, c+dc) = G(x, c) + dG/dx * dx + dG/dc * dc # substituting in dx we get - # G(x+dx, c+dc) = G(x, c) + [ dG/dc - dG/dx * (dF/dx)^-1 * dF/dc ]* dc - # and the Jacobian we want is dG/dc - dG/dx * (dF/dx)^-1 * dF/dc + # G(x+dx, c+dc) = G(x, c) + [ dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc ] * dc + # and the Jacobian we want is dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc # Note: This Jacobian can be obtained using JVPs in proper tangent directions. # First we will compute the tangent direction (see _get_tangent for details), @@ -1228,12 +1236,18 @@ def _jvp(self, v, x, constants=None, op="scaled_error"): # we don't need to divide this part into blocked and batched because # self._constraint._deriv_mode will handle it - jvpfun = lambda u: self._get_tangent(u, xf, constants, op=op) - tangents = batched_vectorize( - jvpfun, - signature="(n)->(k)", - chunk_size=self._constraint._jac_chunk_size, - )(v) + tangents = _proximal_get_tangents( + self._constraint, + xf, + v, + constants[1], + self._eq_solve_objective._feasible_tangents, + self._dxdc, + self._feasible_tangents, + self._dimc_per_thing, + self._eq_idx, + op, + ) if self._objective._deriv_mode == "batched": # objective's method already know about its jac_chunk_size @@ -1246,53 +1260,6 @@ def _jvp(self, v, x, constants=None, op="scaled_error"): op, ) - def _get_tangent(self, v, xf, constants, op): - # Note: This function is vectorized over v. So, v is expected to be 1D array - # of size self.dim_x. - - # v contains self._args DoFs from eq and other objects (like coils, surfaces - # etc), we want jvp_f to only get parts from equilibrium, not other things - vs = jnp.split(v, np.cumsum(self._dimc_per_thing)) - # This is (dF/dx)^-1 * dF/dc # noqa : E800 - dfdc = _proximal_jvp_f_pure( - self._constraint, - xf, - constants[1], - vs[self._eq_idx], - self._eq_solve_objective._feasible_tangents, - self._dxdc, - op, - ) - # broadcasting against multiple things - dfdcs = [jnp.zeros(dim) for dim in self._dimc_per_thing] - dfdcs[self._eq_idx] = dfdc - # note that dfdc.size != vs[self._eq_idx].size - # dfdc has the size of reduced state vector of the equilibrium - # but vs[self._eq_idx] has the size of self._args DoFs - dfdc = jnp.concatenate(dfdcs) - - # We try to find dG/dc - dG/dx * (dF/dx)^-1 * dF/dc - # where G is the objective function. Since DESC stores x and c in the same - # vector, instead of multiple JVP calls, we will just find a tangent direction - # that will give us the same result. - # For making the explanation clear, assume J is the Jacobian of the objective - # function with respect to the full state vector (both x and c). Then, - # dG/dc = J @ (tangent vectors in c direction) - # dG/dx = J @ (tangent vectors in x direction) - # So, dG/dc - dG/dx * (dF/dx)^-1 * dF/dc can be written as - # J @ [(tangent vectors in c direction) - (tangent vectors in x direction)@dfdc] - # Note: We will never form full Jacobian J, we will just compute the above - # expression by JVPs. - dxdcv = jnp.concatenate( - [ - *vs[: self._eq_idx], - self._dxdc @ vs[self._eq_idx], # Rb_lmn, Zb_lmn to full eq state vector - *vs[self._eq_idx + 1 :], - ] - ) - tangent = dxdcv - self._feasible_tangents @ dfdc - return tangent - @property def constants(self): """list: constant parameters for each sub-objective.""" @@ -1318,9 +1285,11 @@ def __getattr__(self, name): # define these helper functions that are stateless so we can safely jit them -def jit_if_possible(func): +def jit_if_possible(func=None, *, static_argnames=("op",)): """Jit a function if use_jit.""" - jitted_func = functools.partial(jit, static_argnames=["op"])(func) + if func is None: + return functools.partial(jit_if_possible, static_argnames=static_argnames) + jitted_func = functools.partial(jit, static_argnames=list(static_argnames))(func) @functools.wraps(func) def wrapper(*args, **kwargs): @@ -1334,15 +1303,103 @@ def wrapper(*args, **kwargs): return wrapper -@jit_if_possible +@jit_if_possible(static_argnames=("op", "dimc_per_thing", "eq_idx")) +def _proximal_get_tangents( + constraint, + xf, + v, + constants, + eq_feasible_tangents, + dxdc, + feasible_tangents, + dimc_per_thing, + eq_idx, + op="scaled_error", +): + jvpfun = lambda u: _get_tangent( + constraint, + u, + xf, + constants, + eq_feasible_tangents, + dxdc, + feasible_tangents, + dimc_per_thing, + eq_idx, + op, + ) + return batched_vectorize( + jvpfun, + signature="(n)->(k)", + chunk_size=constraint._jac_chunk_size, + )(v) + + +def _get_tangent( + constraint, + v, + xf, + constants, + eq_feasible_tangents, + dxdc, + feasible_tangents, + dimc_per_thing, + eq_idx, + op, +): + # Note: This function is vectorized over v. So, v is expected to be 1D array + # of size prox.dim_x. + + # v contains prox._args DoFs from eq and other objects (like coils, surfaces + # etc), we want jvp_f to only get parts from equilibrium, not other things + vs = jnp.split(v, np.cumsum(dimc_per_thing)) + # This is (dF/dx)⁻¹ * dF/dc # noqa : E800 + dfdc = _proximal_jvp_f_pure( + constraint, xf, constants, vs[eq_idx], eq_feasible_tangents, dxdc, op + ) + # broadcasting against multiple things + dfdcs = [jnp.zeros(dim) for dim in dimc_per_thing] + dfdcs[eq_idx] = dfdc + # note that dfdc.size != vs[eq_idx].size + # dfdc has the size of reduced state vector of the equilibrium + # but vs[eq_idx] has the size of prox._args DoFs + dfdc = jnp.concatenate(dfdcs) + + # We try to find dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc + # where G is the objective function. Since DESC stores x and c in the same + # vector, instead of multiple JVP calls, we will just find a tangent direction + # that will give us the same result. + # For making the explanation clear, assume J is the Jacobian of the objective + # function with respect to the full state vector (both x and c). Then, + # dG/dc = J @ (tangent vectors in c direction) + # dG/dx = J @ (tangent vectors in x direction) + # So, dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc can be written as + # J @ [(tangent vectors in c direction) - (tangent vectors in x direction)@dfdc] + # Note: We will never form full Jacobian J, we will just compute the above + # expression by JVPs. + dxdcv = jnp.concatenate( + [ + *vs[:eq_idx], + dxdc @ vs[eq_idx], # Rb_lmn, Zb_lmn to full eq state vector + *vs[eq_idx + 1 :], + ] + ) + tangent = dxdcv - feasible_tangents @ dfdc + return tangent + + def _proximal_jvp_f_pure(constraint, xf, constants, dc, eq_feasible_tangents, dxdc, op): # Note: This function is called by _get_tangent which is vectorized over v # (v is called dc in this function). So, dc is expected to be 1D array # of same size as full equilibrium state vector. This function returns a 1D array. - # here we are forming (dF/dx)^-1 @ dF/dc - # where Fxh is dF/dx and Fc is dF/dc - Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents.T, xf, constants).T + # here we are forming (dF/dx)⁻¹ @ dF/dc + # where Fxh is dF/dxᵀ and Fc is dF/dc. + # Note: Fxh and its SVD do not depend on dc (the vectorized argument). Since the + # whole tangent computation is jitted as one program, we rely on the compiler to + # hoist this loop-invariant SVD out of the batched scan/vmap rather than + # recomputing it for every tangent. + Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents.T, xf, constants) # Our compute functions never include variables like Rb_lmn, Zb_lmn etc. So, # taking the JVP in just dc direction will give 0. To prevent this, we use dxdc # which is the dx/dc matrix and convert the Rb_lmn to R_lmn entries etc. @@ -1354,7 +1411,7 @@ def _proximal_jvp_f_pure(constraint, xf, constants, dc, eq_feasible_tangents, dx uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) sf += sf[-1] # add a tiny bit of regularization sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) - return vtf.T @ (sfi * (uf.T @ Fc)) + return uf @ (sfi * (vtf @ Fc)) @jit_if_possible diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index b28667ada3..fea6558f41 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -346,6 +346,29 @@ def run(x, prox): benchmark.pedantic(run, args=(x, prox), rounds=20, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_proximal_jac_atf_chunked(benchmark): + """Benchmark computing jacobian of constrained proximal projection.""" + eq = desc.examples.get("ATF") + grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.linspace(0.1, 1, 10)) + objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid)) + # chunk the computation, total size is 252, so this should take at most + # 2.5x the unchunked case above + constraint = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=100) + prox = ProximalProjection( + objective, constraint, eq, solve_options={"solve_during_proximal_build": False} + ) + prox.build() + x = prox.x(eq) + prox.jac_scaled_error(x).block_until_ready() + + def run(x, prox): + prox.jac_scaled_error(x).block_until_ready() + + benchmark.pedantic(run, args=(x, prox), rounds=10, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_proximal_jac_atf_with_eq_update(benchmark): diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index fe58b2b39d..f68ce08431 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -349,6 +349,29 @@ def run(x, prox): benchmark.pedantic(run, args=(x, prox), rounds=20, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_proximal_jac_atf_chunked(benchmark): + """Benchmark computing jacobian of constrained proximal projection.""" + eq = desc.examples.get("ATF") + grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.linspace(0.1, 1, 10)) + objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid)) + # chunk the computation, total size is 252, so this should take at most + # 2.5x the unchunked case above + constraint = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=100) + prox = ProximalProjection( + objective, constraint, eq, solve_options={"solve_during_proximal_build": False} + ) + prox.build() + x = prox.x(eq) + prox.jac_scaled_error(x).block_until_ready() + + def run(x, prox): + prox.jac_scaled_error(x).block_until_ready() + + benchmark.pedantic(run, args=(x, prox), rounds=10, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_proximal_jac_atf_with_eq_update(benchmark):