Sparse pullback for big performance gain#2170
Conversation
Co-authored-by: Kaya Unalmis <kayaunalmis@proton.me>
Co-authored-by: Kaya Unalmis <kayaunalmis@proton.me>
| Nemov=True, | ||
| **kwargs, | ||
| ): | ||
| errorif( |
There was a problem hiding this comment.
The other issue here is that using vjp prevents using standard jax.hessian which is forward over reverse. If reverse is really always significantly faster than that may be ok, but it's worth getting other folks input.
If we do end up only supporting reverse mode it would be good to add to the Notes section of the docstring to remind users (and maybe eventually having a reverse only version of objective.hess)
| Nemov=True, | ||
| **kwargs, | ||
| ): | ||
| errorif( |
There was a problem hiding this comment.
It looks like in theory forward mode could still be possible by passing sparse=False when computing the bounce integrals? (assuming #1489 is fixed)
ddudt
left a comment
There was a problem hiding this comment.
I haven't reviewed all the code yet, but I think this introduces some breaking changes. I'm not sure what the best solution is: ideally we would keep support for forward mode, but we at least need to add more error handling.
| Nemov=True, | ||
| **kwargs, | ||
| ): | ||
| errorif( |
There was a problem hiding this comment.
Forward mode is actually very common. Individual objectives often prefer reverse mode, but typically we optimize with a list of many different objectives. If you set ObjectiveFunction(objectives, deriv_mode="batched") then it will use deriv_mode="fwd" for all of the sub-objectives. From the docs:
deriv_mode : {"auto", "batched", "blocked"}
Method for computing Jacobian matrices.batcheduses forward mode, applied
to the entire objective at once, and is generally the fastest for vector
valued objectives. ...
So unfortunately I think we do need to make this PR contingent on a solution to using forward mode, because otherwise these changes will break user code. Also I think we have been very careful to always support both forward and reverse mode for all objectives as a design principal, so at the very least we should discuss the consequences in more detail before we abandon it.
| normalize_target=True, | ||
| loss_function=None, | ||
| deriv_mode="auto", | ||
| jac_chunk_size=None, |
sparse_pullbackandsparse_pullback_mapbounce1doptimization.is_reshaped,is_fourier) that users said were confusing (backwards compatible) as well as the developer flagsBref,Lrefthat should not be there.pitch_batch_sizewas getting ignored. This fixes that by addingstrip_dim0flag tobatch_map.notes