Skip to content

fix: make model __call__ backend-agnostic under CasADi (#2)#39

Open
MaximilianB2 wants to merge 1 commit into
mainfrom
fix/casadi-tuple-unpacking
Open

fix: make model __call__ backend-agnostic under CasADi (#2)#39
MaximilianB2 wants to merge 1 commit into
mainfrom
fix/casadi-tuple-unpacking

Conversation

@MaximilianB2

Copy link
Copy Markdown
Owner

Summary

Fixes the CasADi tuple-unpacking bug (item #2 of the reported batch). Initializing models such as distillation_column with the default casadi integrator raised:

Exception: CasADi matrices are not iterable by design.

Cause

The MPC oracle (oracle.setup_mpc) evaluates each model with CasADi SX symbols, which are not iterable. Models that destructured their state/input vectors with Python tuple unpacking (a, b, c = x) worked under NumPy/JAX arrays but blew up under CasADi.

Fix

Switch the affected models to explicit indexing so the equations are backend-agnostic:

  • distillation_column
  • cstr_series_recycle
  • multistage_extraction_reactive
  • heat_exchanger
  • polymerisation_reactor
  • invariant_batch

Also replaced complex_cstr's u.reshape(-1)[0] with u[0] (CasADi rejects the NumPy-style reshape(-1)).

Tests

Adds tests/models/test_casadi_backend.py, which calls every model with CasADi symbols of the correct shape — mirroring the oracle's evaluation path — and asserts the derivatives vertcat cleanly. Verified all fixed models still produce correct output under the JAX backend too.

19 passed, 1 xfailed

coupled_oscillators is xfailed: it relies on numpy.concatenate over symbolic entries, which is a separate CasADi limitation (not tuple unpacking) and out of scope here.

Notes

While investigating the reported batch, I found items #1 (int_method TypeError) and #3 (cstr u.shape == (1,)) are already fixed on main — every model is now a @dataclass(kw_only=True) with an int_method field, and cstr already uses u.shape[0] == 1. Those were likely reproduced against the older PyPI release (ties into the "PyPI is behind main" report).

🤖 Generated with Claude Code

The MPC oracle evaluates each model with CasADi SX symbols, which are not
iterable. Models that destructured their state/input vectors via Python
tuple unpacking (`a, b, c = x`) therefore raised "CasADi matrices are not
iterable by design" when run with the default casadi integrator (e.g.
distillation_column, cstr_series_recycle, multistage_extraction_reactive,
heat_exchanger, polymerisation_reactor, invariant_batch).

Switch those models to explicit indexing so the equations work under both
NumPy/JAX arrays and CasADi symbolics. Also replace complex_cstr's
`u.reshape(-1)[0]` with `u[0]`, which CasADi rejects.

Add tests/models/test_casadi_backend.py, which calls every model with
CasADi symbols of the correct shape (mirroring oracle.setup_mpc) to guard
against regressions. coupled_oscillators is xfailed: it relies on
numpy.concatenate over symbolic entries, a separate CasADi limitation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant