Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changelog
=========

Performance Improvements

- More efficient `ProximalProjection` Jacobian computation, especially when the `ForceBalance` constraint uses a small `jac_chunk_size`.


v0.17.2
-------

New Features

- Adds ``desc.objectives.DeflationOperator``, a new objective class which can be used to apply deflation techniques to equilibrium and optimization problems to find multiple local minima or multiple solutions from a single initial point, either by wrapping an existing ``desc.objectives._Objective`` object or by including as an additional penalty or constraint. Also adds a tutorial showing this functionality.
Expand Down
209 changes: 133 additions & 76 deletions desc/optimize/_constraint_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,11 @@
# we remove the R_lmn, Z_lmn, L_lmn, Ra_n, Za_n from the equilibrium params
# dimc_per_thing accounts for that, don't confuse it with reduced state vector
self._dimc_per_thing = [t.dim_x for t in self.things]
self._dimc_per_thing[self._eq_idx] = np.sum(
[self._eq.dimensions[arg] for arg in self._args]
self._dimc_per_thing[self._eq_idx] = int(
np.sum([self._eq.dimensions[arg] for arg in self._args])
)
# we will need to set this static attribute, only possible if tuple
self._dimc_per_thing = tuple(self._dimc_per_thing)

# equivalent matrix for A[unfixed_idx] @ D @ Z == A @ feasible_tangents
self._feasible_tangents = jnp.eye(self._objective.dim_x)
Expand Down Expand Up @@ -1055,24 +1057,30 @@
gradient vector.

"""
# We are looking for the gradient of L = 0.5 * G.T @ G
# Then, the gradient is ∇L = G.T @ J_of_G
# We are looking for the gradient of L = 0.5 * Gᵀ @ G
# Then, the gradient is ∇L = Gᵀ @ J_of_G
# where J_of_G is the Jacobian of G with respect to the optimization variables
# We explained getting J_of_G in the _jvp method. It is basically,
# J_of_G = ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)]
# J_of_G = ∇G @ [dc_tangents - (∇F @ dx_tangents)⁻¹ @ (∇F @ dc_tangents)]
# where ∇G is the Jacobian of G with respect to full state vector
# and ∇F is the Jacobian of F with respect to full state vector. Then,
# ∇L = G.T @ ∇G @ [dc_tangents - (∇F @ dx_tangents) ^ -1 @ (∇F @ dc_tangents)]
# We get the part in [] using the _get_tangent method.
# ∇L = Gᵀ @ ∇G @ [dc_tangents - (∇F @ dx_tangents)⁻¹ @ (∇F @ dc_tangents)]
# We get the part in [] using the _proximal_get_tangents function.
v = jnp.eye(x.shape[0])
constants = setdefault(constants, [None, None])
xg, xf = self._update_equilibrium(x, store=True)
jvpfun = lambda u: self._get_tangent(u, xf, constants, op="scaled_error")
tangents = batched_vectorize(
jvpfun,
signature="(n)->(k)",
chunk_size=self._constraint._jac_chunk_size,
)(v)
tangents = _proximal_get_tangents(

Check warning on line 1072 in desc/optimize/_constraint_wrappers.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/_constraint_wrappers.py#L1072

Added line #L1072 was not covered by tests
self._constraint,
xf,
v,
constants[1],
self._eq_solve_objective._feasible_tangents,
self._dxdc,
self._feasible_tangents,
self._dimc_per_thing,
self._eq_idx,
"scaled_error",
)
g = self._objective.compute_scaled_error(xg, constants[0])
g_vjp = self._objective.vjp_scaled_error(g, xg, constants[0])
return tangents @ g_vjp
Expand Down Expand Up @@ -1212,12 +1220,12 @@
# equilibrium such that
# F(x+dx, c+dc) = 0 = F(x, c) + dF/dx * dx + dF/dc * dc
# so that we can set F(x, c) = 0, from here we can solve for dx and get
# dx = - (dF/dx)^-1 * dF/dc * dc # noqa : E800
# dx = - (dF/dx)⁻¹ * dF/dc * dc # noqa : E800
# We can then compute the Jacobian of the objective function with respect to c
# G(x+dx, c+dc) = G(x, c) + dG/dx * dx + dG/dc * dc
# substituting in dx we get
# G(x+dx, c+dc) = G(x, c) + [ dG/dc - dG/dx * (dF/dx)^-1 * dF/dc ]* dc
# and the Jacobian we want is dG/dc - dG/dx * (dF/dx)^-1 * dF/dc
# G(x+dx, c+dc) = G(x, c) + [ dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc ] * dc
# and the Jacobian we want is dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc

# Note: This Jacobian can be obtained using JVPs in proper tangent directions.
# First we will compute the tangent direction (see _get_tangent for details),
Expand All @@ -1228,12 +1236,18 @@

# we don't need to divide this part into blocked and batched because
# self._constraint._deriv_mode will handle it
jvpfun = lambda u: self._get_tangent(u, xf, constants, op=op)
tangents = batched_vectorize(
jvpfun,
signature="(n)->(k)",
chunk_size=self._constraint._jac_chunk_size,
)(v)
tangents = _proximal_get_tangents(
self._constraint,
xf,
v,
constants[1],
self._eq_solve_objective._feasible_tangents,
self._dxdc,
self._feasible_tangents,
self._dimc_per_thing,
self._eq_idx,
op,
)

if self._objective._deriv_mode == "batched":
# objective's method already know about its jac_chunk_size
Expand All @@ -1246,53 +1260,6 @@
op,
)

def _get_tangent(self, v, xf, constants, op):
# Note: This function is vectorized over v. So, v is expected to be 1D array
# of size self.dim_x.

# v contains self._args DoFs from eq and other objects (like coils, surfaces
# etc), we want jvp_f to only get parts from equilibrium, not other things
vs = jnp.split(v, np.cumsum(self._dimc_per_thing))
# This is (dF/dx)^-1 * dF/dc # noqa : E800
dfdc = _proximal_jvp_f_pure(
self._constraint,
xf,
constants[1],
vs[self._eq_idx],
self._eq_solve_objective._feasible_tangents,
self._dxdc,
op,
)
# broadcasting against multiple things
dfdcs = [jnp.zeros(dim) for dim in self._dimc_per_thing]
dfdcs[self._eq_idx] = dfdc
# note that dfdc.size != vs[self._eq_idx].size
# dfdc has the size of reduced state vector of the equilibrium
# but vs[self._eq_idx] has the size of self._args DoFs
dfdc = jnp.concatenate(dfdcs)

# We try to find dG/dc - dG/dx * (dF/dx)^-1 * dF/dc
# where G is the objective function. Since DESC stores x and c in the same
# vector, instead of multiple JVP calls, we will just find a tangent direction
# that will give us the same result.
# For making the explanation clear, assume J is the Jacobian of the objective
# function with respect to the full state vector (both x and c). Then,
# dG/dc = J @ (tangent vectors in c direction)
# dG/dx = J @ (tangent vectors in x direction)
# So, dG/dc - dG/dx * (dF/dx)^-1 * dF/dc can be written as
# J @ [(tangent vectors in c direction) - (tangent vectors in x direction)@dfdc]
# Note: We will never form full Jacobian J, we will just compute the above
# expression by JVPs.
dxdcv = jnp.concatenate(
[
*vs[: self._eq_idx],
self._dxdc @ vs[self._eq_idx], # Rb_lmn, Zb_lmn to full eq state vector
*vs[self._eq_idx + 1 :],
]
)
tangent = dxdcv - self._feasible_tangents @ dfdc
return tangent

@property
def constants(self):
"""list: constant parameters for each sub-objective."""
Expand All @@ -1318,9 +1285,11 @@
# define these helper functions that are stateless so we can safely jit them


def jit_if_possible(func):
def jit_if_possible(func=None, *, static_argnames=("op",)):
"""Jit a function if use_jit."""
jitted_func = functools.partial(jit, static_argnames=["op"])(func)
if func is None:
return functools.partial(jit_if_possible, static_argnames=static_argnames)
jitted_func = functools.partial(jit, static_argnames=list(static_argnames))(func)

@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -1334,15 +1303,103 @@
return wrapper


@jit_if_possible
@jit_if_possible(static_argnames=("op", "dimc_per_thing", "eq_idx"))
def _proximal_get_tangents(
    constraint,
    xf,
    v,
    constants,
    eq_feasible_tangents,
    dxdc,
    feasible_tangents,
    dimc_per_thing,
    eq_idx,
    op="scaled_error",
):
    """Compute the proximal tangent direction for every row of ``v``.

    Stateless helper: ``_get_tangent`` is applied to ``v`` through
    ``batched_vectorize`` with signature ``"(n)->(k)"``, in chunks of
    ``constraint._jac_chunk_size``.

    Parameters
    ----------
    constraint : object
        Constraint objective; must expose ``_jac_chunk_size`` and is forwarded
        to ``_get_tangent``.
    xf : array
        Forwarded unchanged to ``_get_tangent``.
    v : array
        Directions to process; vectorized over with signature ``"(n)->(k)"``.
    constants : object
        Forwarded unchanged to ``_get_tangent``.
    eq_feasible_tangents, dxdc, feasible_tangents : array
        Matrices forwarded unchanged to ``_get_tangent``.
    dimc_per_thing : tuple of int
        Per-thing sizes used to split each direction. Declared static for jit,
        so it has to be hashable (a tuple, not a list).
    eq_idx : int
        Index of the equilibrium entry in ``dimc_per_thing``. Static for jit.
    op : str
        Operation name forwarded to ``_get_tangent``. Static for jit.

    Returns
    -------
    tangents : array
        Output of the vectorized ``_get_tangent`` evaluated on ``v``.

    """

    def tangent_of(direction):
        # every argument except the direction is shared by all evaluations
        return _get_tangent(
            constraint,
            direction,
            xf,
            constants,
            eq_feasible_tangents,
            dxdc,
            feasible_tangents,
            dimc_per_thing,
            eq_idx,
            op,
        )

    vectorized_tangent = batched_vectorize(
        tangent_of,
        signature="(n)->(k)",
        chunk_size=constraint._jac_chunk_size,
    )
    return vectorized_tangent(v)


def _get_tangent(
    constraint,
    v,
    xf,
    constants,
    eq_feasible_tangents,
    dxdc,
    feasible_tangents,
    dimc_per_thing,
    eq_idx,
    op,
):
    """Build the tangent direction associated with a single direction ``v``.

    The returned vector is ``dxdcv - feasible_tangents @ dfdc`` where ``dxdcv``
    is ``v`` with its equilibrium slice mapped through ``dxdc`` and ``dfdc`` is
    the output of ``_proximal_jvp_f_pure`` placed in the equilibrium slot, with
    zero blocks for every other thing.

    Parameters
    ----------
    constraint : object
        Constraint objective, forwarded to ``_proximal_jvp_f_pure``.
    v : array
        1D direction (the caller vectorizes over it), of size prox.dim_x.
    xf : array
        Forwarded unchanged to ``_proximal_jvp_f_pure``.
    constants : object
        Forwarded unchanged to ``_proximal_jvp_f_pure``.
    eq_feasible_tangents : array
        Forwarded unchanged to ``_proximal_jvp_f_pure``.
    dxdc : array
        Matrix applied to the equilibrium slice of ``v``.
    feasible_tangents : array
        Matrix applied to the zero-padded ``dfdc`` vector.
    dimc_per_thing : sequence of int
        Size of each thing's slice inside ``v``.
    eq_idx : int
        Position of the equilibrium slice.
    op : str
        Operation name forwarded to ``_proximal_jvp_f_pure``.

    Returns
    -------
    tangent : array
        1D tangent direction.

    """
    # v holds prox._args DoFs from eq followed/preceded by DoFs of other objects
    # (like coils, surfaces etc). Only the equilibrium slice goes through jvp_f.
    split_points = np.cumsum(dimc_per_thing)
    pieces = jnp.split(v, split_points)
    v_eq = pieces[eq_idx]

    # This is (dF/dx)⁻¹ * dF/dc  # noqa : E800
    eq_response = _proximal_jvp_f_pure(
        constraint, xf, constants, v_eq, eq_feasible_tangents, dxdc, op
    )

    # Broadcast against multiple things: zeros everywhere except the eq slot.
    # Note that eq_response.size != v_eq.size: eq_response has the size of the
    # reduced state vector of the equilibrium, while v_eq has the size of
    # prox._args DoFs.
    padded = [jnp.zeros(size) for size in dimc_per_thing]
    padded[eq_idx] = eq_response
    dfdc = jnp.concatenate(padded)

    # We want dG/dc - dG/dx * (dF/dx)⁻¹ * dF/dc, where G is the objective
    # function. DESC stores x and c in the same vector, so rather than issuing
    # several JVP calls we construct one tangent direction giving the same
    # result. With J the Jacobian of the objective with respect to the full
    # state vector (both x and c):
    #   dG/dc = J @ (tangent vectors in c direction)
    #   dG/dx = J @ (tangent vectors in x direction)
    # hence the target expression equals
    #   J @ [(tangent vectors in c direction)
    #        - (tangent vectors in x direction) @ dfdc]
    # The full Jacobian J is never formed; the expression is evaluated by JVPs.
    before_eq = pieces[:eq_idx]
    after_eq = pieces[eq_idx + 1 :]
    # Rb_lmn, Zb_lmn to full eq state vector
    eq_full = dxdc @ v_eq
    dxdcv = jnp.concatenate([*before_eq, eq_full, *after_eq])

    return dxdcv - feasible_tangents @ dfdc


def _proximal_jvp_f_pure(constraint, xf, constants, dc, eq_feasible_tangents, dxdc, op):
# Note: This function is called by _get_tangent which is vectorized over v
# (v is called dc in this function). So, dc is expected to be 1D array
# of same size as full equilibrium state vector. This function returns a 1D array.

# here we are forming (dF/dx)^-1 @ dF/dc
# where Fxh is dF/dx and Fc is dF/dc
Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents.T, xf, constants).T
# here we are forming (dF/dx)⁻¹ @ dF/dc
# where Fxh is (dF/dx)ᵀ and Fc is dF/dc.
# Note: Fxh and its SVD do not depend on dc (the vectorized argument). Since the
# whole tangent computation is jitted as one program, we rely on the compiler to
# hoist this loop-invariant SVD out of the batched scan/vmap rather than
# recomputing it for every tangent.
Fxh = getattr(constraint, "jvp_" + op)(eq_feasible_tangents.T, xf, constants)
# Our compute functions never include variables like Rb_lmn, Zb_lmn etc. So,
# taking the JVP in just dc direction will give 0. To prevent this, we use dxdc
# which is the dx/dc matrix and convert the Rb_lmn to R_lmn entries etc.
Expand All @@ -1354,7 +1411,7 @@
uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False)
sf += sf[-1] # add a tiny bit of regularization
sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf)
return vtf.T @ (sfi * (uf.T @ Fc))
return uf @ (sfi * (vtf @ Fc))


@jit_if_possible
Expand Down
23 changes: 23 additions & 0 deletions tests/benchmarks/benchmark_cpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,29 @@ def run(x, prox):
benchmark.pedantic(run, args=(x, prox), rounds=20, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_proximal_jac_atf_chunked(benchmark):
    """Benchmark the proximal projection jacobian with a chunked constraint."""
    eq = desc.examples.get("ATF")
    rho = np.linspace(0.1, 1, 10)
    grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=rho)
    objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid))
    # chunk the computation, total size is 252, so this should take at most
    # 2.5x the unchunked case above
    constraint = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=100)
    solve_options = {"solve_during_proximal_build": False}
    prox = ProximalProjection(objective, constraint, eq, solve_options=solve_options)
    prox.build()
    x = prox.x(eq)

    def run(x, prox):
        prox.jac_scaled_error(x).block_until_ready()

    # one untimed evaluation before the measured rounds
    # (presumably so one-time compilation cost is not measured)
    run(x, prox)

    benchmark.pedantic(run, args=(x, prox), rounds=10, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_proximal_jac_atf_with_eq_update(benchmark):
Expand Down
23 changes: 23 additions & 0 deletions tests/benchmarks/benchmark_gpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,29 @@ def run(x, prox):
benchmark.pedantic(run, args=(x, prox), rounds=20, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_proximal_jac_atf_chunked(benchmark):
    """Benchmark the proximal projection jacobian with a chunked constraint."""
    eq = desc.examples.get("ATF")
    rho = np.linspace(0.1, 1, 10)
    grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=rho)
    objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid))
    # chunk the computation, total size is 252, so this should take at most
    # 2.5x the unchunked case above
    constraint = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=100)
    solve_options = {"solve_during_proximal_build": False}
    prox = ProximalProjection(objective, constraint, eq, solve_options=solve_options)
    prox.build()
    x = prox.x(eq)

    def run(x, prox):
        prox.jac_scaled_error(x).block_until_ready()

    # one untimed evaluation before the measured rounds
    # (presumably so one-time compilation cost is not measured)
    run(x, prox)

    benchmark.pedantic(run, args=(x, prox), rounds=10, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_proximal_jac_atf_with_eq_update(benchmark):
Expand Down
Loading