fix: make model __call__ backend-agnostic under CasADi (#2)#39
Open
MaximilianB2 wants to merge 1 commit into
Open
fix: make model __call__ backend-agnostic under CasADi (#2)#39MaximilianB2 wants to merge 1 commit into
MaximilianB2 wants to merge 1 commit into
Conversation
The MPC oracle evaluates each model with CasADi SX symbols, which are not iterable. Models that destructured their state/input vectors via Python tuple unpacking (`a, b, c = x`) therefore raised "CasADi matrices are not iterable by design" when run with the default casadi integrator (e.g. distillation_column, cstr_series_recycle, multistage_extraction_reactive, heat_exchanger, polymerisation_reactor, invariant_batch). Switch those models to explicit indexing so the equations work under both NumPy/JAX arrays and CasADi symbolics. Also replace complex_cstr's `u.reshape(-1)[0]` with `u[0]`, which CasADi rejects. Add tests/models/test_casadi_backend.py, which calls every model with CasADi symbols of the correct shape (mirroring oracle.setup_mpc) to guard against regressions. coupled_oscillators is xfailed: it relies on numpy.concatenate over symbolic entries, a separate CasADi limitation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes the CasADi tuple-unpacking bug (item #2 of the reported batch). Initializing models such as
distillation_columnwith the defaultcasadiintegrator raised:Cause
The MPC oracle (
oracle.setup_mpc) evaluates each model with CasADiSXsymbols, which are not iterable. Models that destructured their state/input vectors with Python tuple unpacking (a, b, c = x) worked under NumPy/JAX arrays but blew up under CasADi.Fix
Switch the affected models to explicit indexing so the equations are backend-agnostic:
distillation_columncstr_series_recyclemultistage_extraction_reactiveheat_exchangerpolymerisation_reactorinvariant_batchAlso replaced
complex_cstr'su.reshape(-1)[0]withu[0](CasADi rejects the NumPy-stylereshape(-1)).Tests
Adds
tests/models/test_casadi_backend.py, which calls every model with CasADi symbols of the correct shape — mirroring the oracle's evaluation path — and asserts the derivativesvertcatcleanly. Verified all fixed models still produce correct output under the JAX backend too.coupled_oscillatorsisxfailed: it relies onnumpy.concatenateover symbolic entries, which is a separate CasADi limitation (not tuple unpacking) and out of scope here.Notes
While investigating the reported batch, I found items #1 (
int_methodTypeError) and #3 (cstru.shape == (1,)) are already fixed onmain— every model is now a@dataclass(kw_only=True)with anint_methodfield, andcstralready usesu.shape[0] == 1. Those were likely reproduced against the older PyPI release (ties into the "PyPI is behind main" report).🤖 Generated with Claude Code