
Surface silent convergence and optimizer failures during fitting #30

Merged: edeno merged 11 commits into main from em-convergence-surfacing on Jun 11, 2026

@edeno edeno commented May 21, 2026


Summary

Make previously-silent fitting failures visible across the optimization stack:

  • GLM Poisson BFGS (likelihoods.sorted_spikes_glm.fit_poisson_regression): result.success was discarded; place-field coefficients silently set to wherever the optimizer gave up. Now emits a UserWarning when not converged.
  • GMM EM (likelihoods.gmm.GaussianMixtureModel): converged_ was set as the heuristic n_iter < max_iter, ignoring the real lower-bound delta convergence flag. Now plumbs the real flag from _em_fit_while_loop through _fit_single to fit; emits a UserWarning on non-convergence.
  • _DetectorBase EM (models.base): monotonicity violations, max-iter exits, and final-E-step inconsistencies were all silent. Now log warnings and record violation iteration indices in a new em_monotonicity_violations_ attribute.
  • HMM filtering (core._condition_on): all--inf log-likelihoods at a timestep silently produced an all-zero posterior. Now falls back to the predicted distribution and sets log_norm=-inf to mark the degenerate step; each public filter (filter, filter_covariate_dependent, chunked_filter_smoother, chunked_filter_smoother_covariate_dependent) emits a single logger.warning per invocation summarizing the degenerate count.
  • New fitted attributes on every detector: converged_, n_iter_, em_monotonicity_violations_.

Numerically byte-identical to the prior implementation on well-behaved inputs (verified by hand-derivation test plus golden regression).

Stacked on

PR #29 (correctness-bug-bundle). The base of this PR is correctness-bug-bundle; when #29 merges, this PR's base will switch to main automatically.

Test plan

  • New tests for each surfacing path (GLM BFGS, GMM EM, EM monotonicity violation, max-iter exit, well-behaved no-warning, _condition_on all-inf, partial-inf, all-finite legacy regression, filter degenerate warning).
  • Snapshot suite (8 tests) — pass, no diffs.
  • Golden regression suite (4 tests) — pass, no diffs.
  • Broad suite -m "not slow" — 817 pass, 3 skipped, 1 xfailed (all pre-existing).
  • ruff check src/ and ruff format --check src/ — clean.

Known issues surfaced (not regressions)

Several pre-existing sorted-spikes-GLM tests now emit the new UserWarning for non-convergent BFGS exits. The fitted coefficients still satisfy the existing test assertions, but the warning indicates these tests have been silently fitting non-converged models. test_sorted_spikes_glm_encoding_runs_end_to_end additionally trips the new max-iter and final-E-step warnings under default parameters. None are Phase 2 regressions — they are previously-silent issues now visible. Investigation of GLM EM convergence behavior tracked as a follow-up.

Internal-threshold note

The in-loop EM monotonicity check uses is_increasing from check_converged (hardcoded -1e-3), while the final-E-step warning uses -tolerance (user-supplied). For users with very tight tolerance (<<1e-3), the final-E-step warning may fire while no in-loop warning did. Threading tolerance through check_converged is a follow-up.
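
The mismatch described in this note can be reproduced with two small stand-in checks. This is a minimal sketch, not the library's code: `in_loop_ok` and `final_step_ok` are hypothetical names, and the only values taken from the note are the hardcoded `-1e-3` slack and the user-supplied `tolerance`.

```python
def in_loop_ok(delta, slack=-1e-3):
    """Hypothetical in-loop check: accept any change no worse than the fixed slack."""
    return delta >= slack


def final_step_ok(delta, tolerance):
    """Hypothetical final-E-step check: accept any change no worse than -tolerance."""
    return delta >= -tolerance


# A small decrease in log-likelihood, combined with a very tight user tolerance.
delta = -5e-5
tolerance = 1e-6

in_loop = in_loop_ok(delta)               # passes: -5e-5 is above -1e-3
final = final_step_ok(delta, tolerance)   # fails: -5e-5 is below -1e-6
```

With these values the in-loop check stays quiet while the final-E-step check would warn, which is the asymmetry the note describes.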

🤖 Generated with Claude Code

edeno commented May 21, 2026


Pushed two follow-ups documented in the PR description:

  1. Tolerance unification (commit a696ce9): check_converged's is_increasing slack now uses the caller-supplied tolerance instead of a hardcoded -1e-3. In-loop and final-E-step checks now share one threshold. 46 affected tests stay green.

  2. GLM EM non-convergence root cause: investigated and documented in issue #31 ("GLM EM M-step uses Local-only marginal for shared-state emission, breaking EM monotonicity"). The M-step uses the Local-state marginal as weights for an emission shared across all states, violating EM's monotonicity guarantee on the joint posterior. Three fix options are scoped (smallest: aggregate weights across all sharing states; largest: redesign the M-step for joint state/bin integration). Not addressed in this PR: the warnings correctly surface the bug, and #31 tracks the fix.

edeno and others added 8 commits June 11, 2026 09:32
Two paired optimizer-surfacing fixes that make previously-silent
fitting failures visible to the user.

- likelihoods.sorted_spikes_glm.fit_poisson_regression: BFGS
  result.success was discarded after every call. Place-field
  coefficients were silently set to wherever the optimizer happened
  to be when it gave up. Emit UserWarning with the BFGS message,
  iteration count, and final loss when not res.success.

- likelihoods.gmm.GaussianMixtureModel: converged_ was set as the
  heuristic n_iter < max_iter, ignoring the real lower-bound delta
  convergence flag returned by _em_fit_while_loop. The plumbing
  discarded the flag at _fit_single. Thread the real converged
  through _fit_single's return tuple (now 6 elements) up to fit;
  emit UserWarning at end of fit when not converged. Initialize
  converged_=False and n_iter_=0 before the n_init loop so the
  warning machinery is robust to the all-inits-fail case.

Tests for each warning path: monkeypatched BFGS-failure for the GLM
(real non-convergence depends on SciPy version and design
conditioning, so a deterministic monkeypatch is the right shape for
this single-call site); max_iter=1 + tol=1e-20 for the GMM.
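
The GLM warning path and its deterministic test can be sketched as follows. This is an illustration under assumptions, not the code in `fit_poisson_regression`: `warn_if_bfgs_failed` is a hypothetical helper, and the fake result object only mimics the `success`, `message`, `nit`, `fun`, and `x` fields that a SciPy `OptimizeResult` carries.

```python
import warnings
from types import SimpleNamespace


def warn_if_bfgs_failed(res):
    """Hypothetical helper: surface a failed optimizer exit instead of dropping it."""
    if not res.success:
        warnings.warn(
            f"Poisson GLM BFGS did not converge: {res.message} "
            f"(iterations={res.nit}, final loss={res.fun:.6g})",
            UserWarning,
            stacklevel=2,
        )
    return res.x


# Stand-in for a monkeypatched optimizer result that reports failure.
fake_failure = SimpleNamespace(
    success=False,
    message="Desired error not necessarily achieved due to precision loss.",
    nit=7,
    fun=12.5,
    x=[0.1, -0.2],
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    coefficients = warn_if_bfgs_failed(fake_failure)
```

The coefficients are still returned, so callers that ignore the warning behave as before; only the visibility of the failure changes.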

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Make previously-silent _DetectorBase EM failures visible.

- Capture both return values of check_converged. When is_increasing
  is False, log a warning describing iteration, before/after
  log-likelihoods, and change magnitude; append the iteration to a
  new em_monotonicity_violations_ list (initialized empty before the
  loop so re-fits reset cleanly).

- After the EM while loop exits, set converged_ and n_iter_
  unconditionally. When not converged_, emit a UserWarning with
  max_iter, the final log-likelihood change, and the tolerance.

- The final E-step now checks that the post-M-step log-likelihood
  did not decrease by more than tolerance; a UserWarning flags
  inconsistencies between the E-step and the M-step output that the
  prior code would have silently propagated through the returned
  posterior.

- Document the three new attributes on _DetectorBase (inherited by
  every detector subclass): converged_, n_iter_,
  em_monotonicity_violations_.

Tests: monkeypatched check_converged to force a violation on
iteration 2; max_iter=1, tolerance=1e-20 to force a max-iter exit;
well-behaved-fit-emits-no-warning negative case using
warnings.catch_warnings(record=True) filtered by message substring
(tolerant of Groups 2a/2b warnings unrelated to EM monotonicity).
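
The bookkeeping described above can be sketched on a precomputed list of log-likelihoods. This is a simplified stand-in, not the `_DetectorBase` loop: `run_em`, the `trace` argument, and the absolute-change convergence rule are assumptions made for illustration. The last lines show the message-substring filtering used by the negative test.

```python
import logging
import warnings

logger = logging.getLogger("hypothetical.em")


def run_em(trace, tolerance=1e-4):
    """Walk per-iteration log-likelihoods; trace[0] is the starting value."""
    violations = []   # iterations at which the log-likelihood dropped
    converged = False
    n_iter = 0
    for n_iter in range(1, len(trace)):
        previous, current = trace[n_iter - 1], trace[n_iter]
        change = current - previous
        if change < -tolerance:
            violations.append(n_iter)
            logger.warning(
                "EM log-likelihood decreased at iteration %d: %g -> %g (change %g)",
                n_iter, previous, current, change,
            )
        elif change <= tolerance:
            converged = True
            break
    if not converged:
        warnings.warn(
            f"EM did not converge after {n_iter} iterations (tolerance={tolerance})",
            UserWarning,
            stacklevel=2,
        )
    return converged, n_iter, violations


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    good = run_em([-100.0, -90.0, -95.0, -85.0, -85.00001])  # one drop, then converges
    bad = run_em([-100.0, -90.0, -80.0])                     # runs out of iterations

# Keep only the EM max-iter warnings, ignoring unrelated ones.
em_warnings = [w for w in caught if "did not converge" in str(w.message)]
```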

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When every state has -inf log-likelihood at a timestep, _condition_on
previously returned an invalid all-zero posterior (probs * exp(-inf)
= 0, then divided by eps in _normalize) and propagated log_norm=-inf
silently through the scan. Replace with a fall-back-to-predicted
policy plus a host-side diagnostic count.

- _condition_on now uses jnp.where(ll_is_finite, ...) to return
  new_probs_normal on the normal path and probs unchanged on the
  all-impossible path, with log_norm=-inf marking the degenerate
  step in the marginal log-likelihood. Numerically byte-identical to
  the prior implementation on well-behaved inputs (verified by
  hand-derivation test plus golden regression).

- Add _count_degenerate_timesteps and _warn_if_degenerate_timesteps
  helpers in core.py. Each public filter (filter,
  filter_covariate_dependent, chunked_filter_smoother,
  chunked_filter_smoother_covariate_dependent) now emits a single
  logger.warning per invocation when any degenerate timesteps are
  detected. The chunked drivers aggregate the count across chunks
  and warn once after the forward pass.

- logger.warning (not warnings.warn) because the filter runs many
  times per fit and warnings.warn dedup would suppress legitimate
  signal across iterations; logger.warning integrates with the
  existing per-module logger configuration.
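
The deduplication concern in the last bullet can be demonstrated with the standard library alone. This is a small sketch with hypothetical function and logger names; it relies on Python's "default" warning action, which shows a given warning once per source location, whereas a logger emits every call.

```python
import logging
import warnings

logger = logging.getLogger("hypothetical.filter")


class ListHandler(logging.Handler):
    """Collect formatted log messages in memory."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def filter_with_warn():
    warnings.warn("2 degenerate timesteps", UserWarning)


def filter_with_logger():
    logger.warning("%d degenerate timesteps", 2)


handler = ListHandler()
logger.addHandler(handler)
logger.propagate = False  # keep the demo output out of the root logger

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("default")  # once per (message, category, line)
    for _ in range(3):                # e.g. three EM iterations calling the filter
        filter_with_warn()
        filter_with_logger()

n_warned = len(caught)
n_logged = len(handler.messages)
```

Here the repeated `warnings.warn` call is recorded once while the logger records all three calls, which is the behavior the bullet cites as the reason for choosing `logger.warning` at decode time.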

Tests cover: all-inf input falls back to predicted (probs and
log_norm); partial-inf still normalizes correctly; well-behaved
input matches legacy normalization byte-for-byte; filter on a
3-timestep input with middle timestep all-inf emits the warning and
returns predicted-distribution posterior at that step.
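
The fall-back-to-predicted policy can be sketched in plain Python for a single timestep. This is an illustration only: the real `_condition_on` works on JAX arrays and selects with `jnp.where` rather than branching, and the sketch assumes NaN-free inputs and a nonzero normalizer on the normal path.

```python
import math


def condition_on(probs, log_likelihoods):
    """Return (posterior, log_norm) for one timestep."""
    ll_max = max(log_likelihoods)
    if ll_max == -math.inf:
        # Every state is impossible: keep the predicted distribution and
        # mark the step as degenerate with log_norm = -inf.
        return list(probs), -math.inf
    # Normal path: weight by the likelihood (shifted by ll_max for stability).
    weighted = [p * math.exp(ll - ll_max) for p, ll in zip(probs, log_likelihoods)]
    norm = sum(weighted)
    posterior = [w / norm for w in weighted]
    return posterior, math.log(norm) + ll_max


predicted = [0.25, 0.75]
all_inf = condition_on(predicted, [-math.inf, -math.inf])
partial_inf = condition_on(predicted, [-math.inf, 0.0])
```

In the all `-inf` case the posterior equals the predicted distribution instead of collapsing to zeros; in the partial case the impossible state gets zero mass and the rest renormalizes as usual.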

CHANGELOG.md merges Phase 1 Changed entries with Phase 2's new
Added/Changed bullets and notes the GLM EM convergence behavior
surfaced by these changes as a follow-up investigation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
check_converged's is_increasing return value used a hardcoded -1e-3
slack, while the EM final-E-step consistency warning compares the
change against the caller-supplied tolerance. For tight tolerances
(<<1e-3) the final-E-step warning would fire while no in-loop
warning did; for loose tolerances the in-loop check would warn at
decreases the final-E-step check considered noise.

Thread tolerance into the is_increasing comparison so both checks
use the same slack. The default tolerance=1e-4 makes is_increasing
slightly stricter than the prior -1e-3 (decreases of 1e-4 to 1e-3
that were previously treated as noise now flag). This is the intent
of the surrounding Phase 2 work — surface previously-silent
inconsistencies, not normalize them away.

CHANGELOG points at issue #31 for the open follow-up on the GLM EM
M-step (which uses the Local-only marginal as weights for an
emission shared across states, breaking EM monotonicity in a way
that this tolerance change does not address).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Critical: _condition_on detected the all-impossible fallback via
~isfinite(ll.max()), which also catches NaN. A NaN in one state was
silently replaced by the predicted prior (discarding valid evidence and
mislabeling the step 'all-impossible') instead of propagating as a
visible NaN. The fallback now triggers only on ll.max() == -inf, so a
finite state beside a -inf is still used and a NaN flows through to the
posterior. _count_degenerate_timesteps counts only all--inf steps; a new
_count_nan_timesteps drives a separate NaN warning so the two conditions
are reported distinctly. The chunked drivers tally both.
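
The difference between the two detection rules can be checked on a few inputs. This sketch uses hypothetical helper names and a NaN-propagating `max`, because array reductions such as `jnp.max` return NaN when any element is NaN, while Python's built-in `max` does not do so reliably.

```python
import math


def nan_aware_max(values):
    """Mimic array max semantics: any NaN makes the result NaN."""
    if any(math.isnan(v) for v in values):
        return math.nan
    return max(values)


def old_is_fallback(ll):
    # Old rule: not finite, which is true for -inf and also for NaN.
    return not math.isfinite(nan_aware_max(ll))


def new_is_fallback(ll):
    # New rule: only an all -inf step triggers the fallback.
    return nan_aware_max(ll) == -math.inf


all_impossible = [-math.inf, -math.inf]
has_nan = [math.nan, -1.0]
one_finite = [-math.inf, -2.0]

old_results = [old_is_fallback(ll) for ll in (all_impossible, has_nan, one_finite)]
new_results = [new_is_fallback(ll) for ll in (all_impossible, has_nan, one_finite)]
```

Only the NaN case differs: the old rule treats it as a fallback step, while the new rule leaves it on the normal path so the NaN stays visible in the posterior.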

Important: check_converged's is_increasing monotonicity check is now
relative ((curr-prev)/avg >= -tolerance), matching is_converged, instead
of an absolute -1e-3 slack that spuriously flagged negligible
floating-point decreases on large (sum-over-time) log-likelihoods.
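
A short numerical comparison shows why the relative form matters on large log-likelihoods. This is a sketch under assumptions: the function names are hypothetical, and `avg` is taken here to be the mean of the two absolute values, which the commit message does not spell out.

```python
def is_increasing_absolute(curr, prev, slack=1e-3):
    """Old-style check: the decrease must not exceed a fixed absolute slack."""
    return (curr - prev) >= -slack


def is_increasing_relative(curr, prev, tolerance=1e-4):
    """Relative check: (curr - prev) / avg >= -tolerance."""
    avg = (abs(curr) + abs(prev)) / 2.0  # assumed definition of avg
    if avg == 0.0:
        return True  # both values are zero, so there is no decrease
    return (curr - prev) / avg >= -tolerance


# A sum-over-time log-likelihood of about -1e7 that drops by 0.01.
prev = -1.0e7
curr = prev - 0.01

flagged_absolute = not is_increasing_absolute(curr, prev)
flagged_relative = not is_increasing_relative(curr, prev)

# A genuine 10% drop is still caught by the relative check.
flagged_large_drop = not is_increasing_relative(-1.1e7, -1.0e7)
```

The 0.01 decrease is about one part in a billion of the log-likelihood, so the relative check treats it as noise while the absolute slack flags it.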

Well-behaved path is byte-identical (golden + snapshot, no diffs).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- Monotonicity violations now emit warnings.warn(UserWarning), matching
  the other EM/GMM/GLM fit-quality diagnostics (decode-time degenerate/
  NaN-timestep warnings stay logger.warning by design; documented).
- The final-E-step consistency check reuses check_converged so it shares
  the in-loop relative monotonicity slack (was an absolute -tolerance).
- The four EM-result attributes (converged_, n_iter_,
  em_monotonicity_violations_, and the new degenerate_timesteps_) are now
  initialized together after input validation, so a fit that fails
  validation leaves none of them set (all-or-nothing).
- degenerate_timesteps_ exposes the 0-based time indices where the final
  E-step fell back to the prior, so the condition is programmatically
  inspectable, not only logged. Documented in the class docstring and the
  fit-attribute annotation block.
- Breadcrumb at the encoding-update site pointing to issue #31 (Local-only
  weights on a shared emission can break EM monotonicity).

Adds tests: monotonicity UserWarning, final-E-step inconsistency trigger,
and all-or-nothing attribute initialization.
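
The all-or-nothing initialization can be illustrated with a toy class. This is a hypothetical stand-in for a detector, with a placeholder in place of the EM loop; only the four attribute names come from the commit message.

```python
class ToyDetector:
    """Hypothetical detector showing where fitted attributes get initialized."""

    RESULT_ATTRS = (
        "converged_",
        "n_iter_",
        "em_monotonicity_violations_",
        "degenerate_timesteps_",
    )

    def fit(self, data):
        # Validate first: a failure here must leave no fitted attribute behind.
        if len(data) == 0:
            raise ValueError("data must be non-empty")
        # Initialize all result attributes together, after validation.
        self.converged_ = False
        self.n_iter_ = 0
        self.em_monotonicity_violations_ = []
        self.degenerate_timesteps_ = []
        # Placeholder for the real EM loop.
        self.n_iter_ = 1
        self.converged_ = True
        return self


failed = ToyDetector()
try:
    failed.fit([])
except ValueError:
    pass
attrs_after_failure = [name for name in ToyDetector.RESULT_ATTRS if hasattr(failed, name)]

fitted = ToyDetector().fit([1.0, 2.0])
attrs_after_fit = [name for name in ToyDetector.RESULT_ATTRS if hasattr(fitted, name)]
```

A caller can then rely on a single `hasattr` check to tell whether a fit got past validation.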

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- GMM converged_ is now final_delta <= tol. The old final_i < max_iter
  AND final_delta <= tol reported a fit that reached tolerance on the
  last allowed iteration (final_i == max_iter) as non-converged, which
  the new fit-time warning then surfaced as a false positive.
- fit_poisson_regression judges convergence by the final gradient
  inf-norm (> GLM_CONVERGENCE_GRAD_TOL = 1e-3) instead of SciPy's
  success flag. BFGS routinely returns success=False ('precision loss')
  at a good minimum, so the prior warning fired near-constantly; gating
  on the gradient warns only on genuine non-convergence.
- GMM non-convergence warning now reports the selected restart's actual
  iteration count.

Adds tests: GMM converged_=True happy path and max_iter-boundary case
(RED against the old heuristic), GLM large-gradient warns / benign
precision-loss does not.
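
The two revised criteria can be sketched as small predicates. The function names are hypothetical; the `1e-3` gradient tolerance, the old GMM rule, and the new GMM rule are taken from the bullets above.

```python
GLM_CONVERGENCE_GRAD_TOL = 1e-3  # value quoted in the commit message


def glm_converged(final_gradient, grad_tol=GLM_CONVERGENCE_GRAD_TOL):
    """Converged when the gradient's inf-norm is within the tolerance."""
    return max(abs(g) for g in final_gradient) <= grad_tol


def gmm_converged_old(final_i, final_delta, max_iter, tol):
    """Old rule: also required stopping before the last allowed iteration."""
    return final_i < max_iter and final_delta <= tol


def gmm_converged_new(final_delta, tol):
    """New rule: only the final lower-bound change matters."""
    return final_delta <= tol


# GLM: a benign exit with a tiny gradient versus a genuinely unfinished fit.
benign = glm_converged([1e-5, -2e-4])
unfinished = glm_converged([0.5, 1e-6])

# GMM: tolerance reached exactly on the last allowed iteration.
boundary_old = gmm_converged_old(final_i=100, final_delta=1e-6, max_iter=100, tol=1e-3)
boundary_new = gmm_converged_new(final_delta=1e-6, tol=1e-3)
```

The boundary case is the false positive the first bullet describes: the old rule reports non-convergence even though the tolerance was met.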

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@edeno edeno force-pushed the em-convergence-surfacing branch from a696ce9 to c275c35 Compare June 11, 2026 14:22
@edeno edeno changed the base branch from correctness-bug-bundle to main June 11, 2026 14:22
edeno and others added 2 commits June 11, 2026 13:07
Two follow-up review findings on the EM-surfacing changes:

1. The host-side degenerate/NaN diagnostics in filter / filter_covariate_dependent
   did int(jnp.sum(...)), so jax.jit(filter)(...) (or calling the filter inside an
   outer jit/vmap/scan) raised ConcretizationTypeError -- a regression from main,
   where filter was a pure jitted function. _warn_if_degenerate_timesteps now
   skips when its input is a jax.core.Tracer (diagnostics can't run on traced
   values and warnings can't be emitted during tracing); the pure filtering
   computation still runs. Restores composability.

2. degenerate_timesteps_ was computed from the final E-step's returned
   log-likelihoods, which _predict returns as None for n_chunks > 1 (caching
   is disabled), leaving the attribute silently empty while the count warning
   still fired. The chunked drivers now accept an optional degenerate_indices_out
   collector and append the global all--inf indices during the forward pass
   (where per-chunk likelihoods are seen), so the attribute is reliable
   regardless of n_chunks/caching. _predict threads the collector and stores the
   result on self; its public return arity is unchanged.

Tests: jax.jit composability for both filter wrappers; collector recovers the
correct global indices on the uncached path (returned log_likelihoods is None).
Byte-identical on well-behaved inputs (golden + snapshot, no diffs).
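
The collector pattern from item 2 can be sketched with plain lists. This is a simplified stand-in for the chunked drivers: `chunked_forward` is a hypothetical function, and only the `degenerate_indices_out` parameter name comes from the commit message.

```python
import math


def chunked_forward(log_likelihoods, chunk_size, degenerate_indices_out=None):
    """Scan timesteps chunk by chunk and count the all -inf steps.

    If a collector list is passed, append the global (not per-chunk)
    index of each degenerate timestep to it.
    """
    n_degenerate = 0
    for start in range(0, len(log_likelihoods), chunk_size):
        chunk = log_likelihoods[start:start + chunk_size]
        for offset, row in enumerate(chunk):
            if max(row) == -math.inf:
                n_degenerate += 1
                if degenerate_indices_out is not None:
                    degenerate_indices_out.append(start + offset)
    return n_degenerate


inf = math.inf
log_likelihoods = [
    [0.0, -1.0],
    [-inf, -inf],   # degenerate, global index 1
    [-2.0, -inf],
    [-inf, -inf],   # degenerate, global index 3
    [-1.0, -1.0],
]

collected = []
count = chunked_forward(log_likelihoods, chunk_size=2, degenerate_indices_out=collected)
count_without_collector = chunked_forward(log_likelihoods, chunk_size=2)
```

Because the indices are gathered during the forward pass, they do not depend on the per-chunk likelihoods being cached or returned afterwards.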

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Quality cleanup (no behavior change; byte-identical on golden + snapshot):

- core.py: replace the three separate-pass helpers
  (_count_degenerate_timesteps / _count_nan_timesteps /
  _degenerate_timestep_indices) with one _degenerate_and_nan_masks (single
  max(axis=-1)/isnan pass, both masks pulled to host once) plus a shared
  _accumulate_chunk_degeneracy. Collapses up to 3 device->host syncs per chunk
  to 1 and de-duplicates the identical tally/collect block that was copy-pasted
  across both chunked drivers. _warn_if_degenerate_timesteps now single-passes
  too.
- base.py: read self._degenerate_timesteps_ directly (the preceding final
  E-step's _predict always sets it) instead of a getattr with a dead default;
  move the final-E-step log-likelihood delta into the warning branch that uses
  it.
- test_em_monotonicity.py: extract a sim fixture for the
  make_simulated_run_data setup that was duplicated five times across three
  classes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@edeno edeno force-pushed the em-convergence-surfacing branch from a6728a8 to 620c1c3 Compare June 11, 2026 18:00
Low-severity follow-ups from the full-PR review (no production logic change):

- core.py: correct the _degenerate_and_nan_masks docstring / chunk comments
  that overstated the host transfer as a single one -- it is two small
  reductions pulled at one host sync point per call (down from three separate
  helper calls), not literally one transfer.
- models/base.py: note in the degenerate_timesteps_ docstring that predict()
  does not refresh the attribute (it still logs a degenerate-timestep warning
  for its data); the attribute reflects the training data set by
  estimate_parameters.
- test_em_monotonicity.py: assert degenerate_timesteps_ is an empty ndarray
  after a well-behaved fit, pinning the _predict -> fitted-attribute wiring
  end to end (previously only the core-level collector was tested).
- CHANGELOG: note that filter / filter_covariate_dependent are now plain
  wrappers rather than jax.jit objects (semi-public module-level functions).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@edeno edeno merged commit 7854e71 into main Jun 11, 2026
9 checks passed
@edeno edeno deleted the em-convergence-surfacing branch June 11, 2026 19:53