Surface silent convergence and optimizer failures during fitting #30
Pushed two follow-ups documented in the PR description:
Two paired optimizer-surfacing fixes that make previously-silent fitting failures visible to the user.

- likelihoods.sorted_spikes_glm.fit_poisson_regression: BFGS result.success was discarded after every call. Place-field coefficients were silently set to wherever the optimizer happened to be when it gave up. Emit UserWarning with the BFGS message, iteration count, and final loss when not res.success.
- likelihoods.gmm.GaussianMixtureModel: converged_ was set as the heuristic n_iter < max_iter, ignoring the real lower-bound delta convergence flag returned by _em_fit_while_loop. The plumbing discarded the flag at _fit_single. Thread the real converged through _fit_single's return tuple (now 6 elements) up to fit; emit UserWarning at end of fit when not converged. Initialize converged_=False and n_iter_=0 before the n_init loop so the warning machinery is robust to the all-inits-fail case.

Tests for each warning path: monkeypatched BFGS-failure for the GLM (real non-convergence depends on SciPy version and design conditioning, so a deterministic monkeypatch is the right shape for this single-call site); max_iter=1 + tol=1e-20 for the GMM.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Make previously-silent _DetectorBase EM failures visible.

- Capture both return values of check_converged. When is_increasing is False, log a warning describing iteration, before/after log-likelihoods, and change magnitude; append the iteration to a new em_monotonicity_violations_ list (initialized empty before the loop so re-fits reset cleanly).
- After the EM while loop exits, set converged_ and n_iter_ unconditionally. When not converged_, emit a UserWarning with max_iter, the final log-likelihood change, and the tolerance.
- The final E-step now checks that the post-M-step log-likelihood did not decrease by more than tolerance; a UserWarning flags inconsistencies between the E-step and the M-step output that the prior code would have silently propagated through the returned posterior.
- Document the three new attributes on _DetectorBase (inherited by every detector subclass): converged_, n_iter_, em_monotonicity_violations_.

Tests: monkeypatched check_converged to force a violation on iteration 2; max_iter=1, tolerance=1e-20 to force a max-iter exit; well-behaved-fit-emits-no-warning negative case using warnings.catch_warnings(record=True) filtered by message substring (tolerant of Groups 2a/2b warnings unrelated to EM monotonicity).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
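The bookkeeping described in this commit can be sketched roughly as follows. This is a toy illustration, not the project's _DetectorBase: the class name ToyEM and the log_likelihood_trace argument (a precomputed list standing in for real E/M steps) are hypothetical, and the absolute-change test used here is an assumption made for brevity. Only the attribute names converged_, n_iter_, and em_monotonicity_violations_ come from the commit message.

```python
import warnings


class ToyEM:
    """Illustrative stand-in for an EM-fitted model (hypothetical class)."""

    def fit(self, log_likelihood_trace, max_iter=20, tolerance=1e-4):
        # Reset diagnostics before the loop so a re-fit starts clean.
        self.em_monotonicity_violations_ = []
        converged = False
        n_iter = 0
        change = float("inf")
        previous = None

        # Each trace entry stands in for the log-likelihood after one EM iteration.
        for current in log_likelihood_trace[:max_iter]:
            n_iter += 1
            if previous is not None:
                change = current - previous
                if change < -tolerance:
                    warnings.warn(
                        f"EM log-likelihood decreased at iteration {n_iter}: "
                        f"{previous:.6g} -> {current:.6g} (change {change:.3g})",
                        UserWarning,
                        stacklevel=2,
                    )
                    self.em_monotonicity_violations_.append(n_iter)
                elif abs(change) <= tolerance:
                    converged = True
            previous = current
            if converged:
                break

        # Set unconditionally once the loop exits, whatever the exit reason.
        self.converged_ = converged
        self.n_iter_ = n_iter
        if not converged:
            warnings.warn(
                f"EM stopped after {n_iter} iterations without converging "
                f"(max_iter={max_iter}, final change {change:.3g}, "
                f"tolerance {tolerance:.3g})",
                UserWarning,
                stacklevel=2,
            )
        return self


# Record warnings and filter by message substring, as the tests above do.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = ToyEM().fit([-10.0, -5.0, -6.0, -5.5, -5.49999])
monotonicity_warnings = [w for w in caught if "decreased" in str(w.message)]
```

Filtering recorded warnings by substring keeps a test focused on one diagnostic even when unrelated warnings fire during the same fit.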
When every state has -inf log-likelihood at a timestep, _condition_on previously returned an invalid all-zero posterior (probs * exp(-inf) = 0, then divided by eps in _normalize) and propagated log_norm=-inf silently through the scan. Replace with a fall-back-to-predicted policy plus a host-side diagnostic count.

- _condition_on now uses jnp.where(ll_is_finite, ...) to return new_probs_normal on the normal path and probs unchanged on the all-impossible path, with log_norm=-inf marking the degenerate step in the marginal log-likelihood. Numerically byte-identical to the prior implementation on well-behaved inputs (verified by hand-derivation test plus golden regression).
- Add _count_degenerate_timesteps and _warn_if_degenerate_timesteps helpers in core.py. Each public filter (filter, filter_covariate_dependent, chunked_filter_smoother, chunked_filter_smoother_covariate_dependent) now emits a single logger.warning per invocation when any degenerate timesteps are detected. The chunked drivers aggregate the count across chunks and warn once after the forward pass.
- logger.warning (not warnings.warn) because the filter runs many times per fit and warnings.warn dedup would suppress legitimate signal across iterations; logger.warning integrates with the existing per-module logger configuration.

Tests cover: all-inf input falls back to predicted (probs and log_norm); partial-inf still normalizes correctly; well-behaved input matches legacy normalization byte-for-byte; filter on a 3-timestep input with middle timestep all-inf emits the warning and returns predicted-distribution posterior at that step.

CHANGELOG.md merges Phase 1 Changed entries with Phase 2's new Added/Changed bullets and notes the GLM EM convergence behavior surfaced by these changes as a follow-up investigation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
check_converged's is_increasing return value used a hardcoded -1e-3 slack, while the EM final-E-step consistency warning compares the change against the caller-supplied tolerance. For tight tolerances (<<1e-3) the final-E-step warning would fire while no in-loop warning did; for loose tolerances the in-loop check would warn at decreases the final-E-step check considered noise. Thread tolerance into the is_increasing comparison so both checks use the same slack.

The default tolerance=1e-4 makes is_increasing slightly stricter than the prior -1e-3 (decreases of 1e-4 to 1e-3 that were previously treated as noise are now flagged). This is the intent of the surrounding Phase 2 work: surface previously-silent inconsistencies, not normalize them away.

CHANGELOG points at issue #31 for the open follow-up on the GLM EM M-step (which uses the Local-only marginal as weights for an emission shared across states, breaking EM monotonicity in a way that this tolerance change does not address).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Critical: _condition_on detected the all-impossible fallback via ~isfinite(ll.max()), which also catches NaN. A NaN in one state was silently replaced by the predicted prior (discarding valid evidence and mislabeling the step 'all-impossible') instead of propagating as a visible NaN. The fallback now triggers only on ll.max() == -inf, so a finite state beside a -inf is still used and a NaN flows through to the posterior. _count_degenerate_timesteps counts only all--inf steps; a new _count_nan_timesteps drives a separate NaN warning so the two conditions are reported distinctly. The chunked drivers tally both.

Important: check_converged's is_increasing monotonicity check is now relative ((curr-prev)/avg >= -tolerance), matching is_converged, instead of an absolute -1e-3 slack that spuriously flagged negligible floating-point decreases on large (sum-over-time) log-likelihoods.

Well-behaved path is byte-identical (golden + snapshot, no diffs).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
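The fallback policy after this fix can be illustrated with a plain-Python sketch of a single conditioning step. It is not the JAX implementation in core.py: the function below works on Python lists, the EPS floor and the max-subtraction for numerical stability are assumptions, and only the behavior (fall back on all -inf, let NaN propagate, still use a finite state beside a -inf) is taken from the commit message.

```python
import math

EPS = 1e-15  # hypothetical floor for a zero normalizer


def condition_on(probs, log_likelihoods):
    """Condition predicted state probabilities on one timestep's log-likelihoods.

    Returns (posterior, log_norm).
    """
    # Python's max() does not propagate NaN reliably, so detect it explicitly.
    if any(math.isnan(ll) for ll in log_likelihoods):
        ll_max = math.nan
    else:
        ll_max = max(log_likelihoods)

    # Fallback triggers only when every state is impossible (max == -inf).
    # NaN == -inf is False, so a NaN takes the normal path and stays visible.
    if ll_max == -math.inf:
        return list(probs), -math.inf

    # Normal path: subtract the max before exponentiating for stability.
    weights = [p * math.exp(ll - ll_max) for p, ll in zip(probs, log_likelihoods)]
    norm = sum(weights)
    if norm == 0.0:
        norm = EPS
    posterior = [w / norm for w in weights]
    log_norm = math.log(norm) + ll_max
    return posterior, log_norm


# All states impossible: the predicted distribution is returned unchanged.
fallback_posterior, fallback_log_norm = condition_on(
    [0.25, 0.75], [-math.inf, -math.inf]
)

# One finite state beside a -inf: the finite evidence is still used.
partial_posterior, partial_log_norm = condition_on([0.5, 0.5], [0.0, -math.inf])

# A NaN in any state propagates to the posterior instead of being masked.
nan_posterior, nan_log_norm = condition_on([0.5, 0.5], [math.nan, -1.0])
```

Keeping the NaN case out of the fallback branch means the two conditions stay distinguishable downstream, which is what lets the separate NaN count exist at all.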
- Monotonicity violations now emit warnings.warn(UserWarning), matching the other EM/GMM/GLM fit-quality diagnostics (decode-time degenerate/NaN-timestep warnings stay logger.warning by design; documented).
- The final-E-step consistency check reuses check_converged so it shares the in-loop relative monotonicity slack (was an absolute -tolerance).
- The four EM-result attributes (converged_, n_iter_, em_monotonicity_violations_, and the new degenerate_timesteps_) are now initialized together after input validation, so a fit that fails validation leaves none of them set (all-or-nothing).
- degenerate_timesteps_ exposes the 0-based time indices where the final E-step fell back to the prior, so the condition is programmatically inspectable, not only logged. Documented in the class docstring and the fit-attribute annotation block.
- Breadcrumb at the encoding-update site pointing to issue #31 (Local-only weights on a shared emission can break EM monotonicity).

Adds tests: monotonicity UserWarning, final-E-step inconsistency trigger, and all-or-nothing attribute initialization.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
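A minimal sketch of a relative convergence check of the kind shared here by the in-loop and final-E-step checks. The (curr-prev)/avg >= -tolerance comparison comes from the commit messages above; the exact definition of avg (mean of absolute values plus a small epsilon) and the argument names are assumptions, so the project's check_converged may differ in detail.

```python
import sys


def check_converged(log_likelihood, previous_log_likelihood, tolerance=1e-4):
    """Return (is_converged, is_increasing) using a relative change.

    Assumed form: relative = (curr - prev) / avg, with avg the mean of the
    absolute values (plus epsilon to avoid division by zero).
    """
    delta = log_likelihood - previous_log_likelihood
    avg = (
        abs(log_likelihood) + abs(previous_log_likelihood) + sys.float_info.epsilon
    ) / 2
    relative = delta / avg

    is_increasing = relative >= -tolerance
    is_converged = abs(relative) <= tolerance
    return is_converged, is_increasing


# A 0.01 decrease on a log-likelihood of about -1e6 is negligible in relative
# terms, although it would exceed an absolute slack of 1e-3.
large_scale = check_converged(-1_000_000.01, -1_000_000.0)

# A decrease of 1 on a log-likelihood of about -100 is a real violation.
real_decrease = check_converged(-101.0, -100.0)

# A large increase is monotone but not yet converged.
still_climbing = check_converged(-50.0, -100.0)
```

Because both flags are derived from the same relative quantity, a single tolerance governs both, which is the consistency the final-E-step check gains by reusing the function.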
- GMM converged_ is now final_delta <= tol. The old final_i < max_iter
AND final_delta <= tol reported a fit that reached tolerance on the
last allowed iteration (final_i == max_iter) as non-converged, which
the new fit-time warning then surfaced as a false positive.
- fit_poisson_regression judges convergence by the final gradient
inf-norm (> GLM_CONVERGENCE_GRAD_TOL = 1e-3) instead of SciPy's
success flag. BFGS routinely returns success=False ('precision loss')
at a good minimum, so the prior warning fired near-constantly; gating
on the gradient warns only on genuine non-convergence.
- GMM non-convergence warning now reports the selected restart's actual
iteration count.
Adds tests: GMM converged_=True happy path and max_iter-boundary case
(RED against the old heuristic), GLM large-gradient warns / benign
precision-loss does not.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
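The gradient-based gate for the GLM fit might look roughly like the sketch below. GLM_CONVERGENCE_GRAD_TOL = 1e-3 is the value quoted in the commit message; the helper name glm_fit_converged, its arguments, and the choice to treat a non-finite gradient as non-convergence are hypothetical and not taken from fit_poisson_regression itself.

```python
import math
import warnings

GLM_CONVERGENCE_GRAD_TOL = 1e-3  # value quoted in the commit message


def glm_fit_converged(gradient, n_iter, loss, optimizer_message=""):
    """Judge convergence by the final gradient's infinity norm.

    Hypothetical helper: the optimizer's own success flag is ignored, and the
    optimizer message is reported only as context in the warning.
    """
    # Sketch choice: any NaN/inf component counts as not converged.
    if all(math.isfinite(g) for g in gradient):
        grad_inf_norm = max((abs(g) for g in gradient), default=0.0)
    else:
        grad_inf_norm = math.inf

    converged = grad_inf_norm <= GLM_CONVERGENCE_GRAD_TOL
    if not converged:
        warnings.warn(
            f"Poisson regression did not converge: gradient inf-norm "
            f"{grad_inf_norm:.3g} > {GLM_CONVERGENCE_GRAD_TOL:.3g} after "
            f"{n_iter} iterations (loss {loss:.6g}). Optimizer message: "
            f"{optimizer_message!r}",
            UserWarning,
            stacklevel=2,
        )
    return converged


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Benign case: the optimizer reports precision loss, but the gradient is tiny.
    benign = glm_fit_converged(
        [2e-5, -4e-4], n_iter=37, loss=123.4,
        optimizer_message="Desired error not necessarily achieved due to precision loss.",
    )
    # Genuine failure: a large gradient component remains.
    failed = glm_fit_converged([0.5, -2.0], n_iter=3, loss=987.6)
```

Gating on the gradient rather than the success flag trades one heuristic for another; the threshold is a tunable judgment about what counts as a good minimum.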
a696ce9 to c275c35
Two follow-up review findings on the EM-surfacing changes:

1. The host-side degenerate/NaN diagnostics in filter / filter_covariate_dependent did int(jnp.sum(...)), so jax.jit(filter)(...) (or calling the filter inside an outer jit/vmap/scan) raised ConcretizationTypeError -- a regression from main, where filter was a pure jitted function. _warn_if_degenerate_timesteps now skips when its input is a jax.core.Tracer (diagnostics can't run on traced values and warnings can't be emitted during tracing); the pure filtering computation still runs. Restores composability.
2. degenerate_timesteps_ was computed from the final E-step's returned log-likelihoods, which _predict returns as None for n_chunks > 1 (caching is disabled), leaving the attribute silently empty while the count warning still fired. The chunked drivers now accept an optional degenerate_indices_out collector and append the global all--inf indices during the forward pass (where per-chunk likelihoods are seen), so the attribute is reliable regardless of n_chunks/caching. _predict threads the collector and stores the result on self; its public return arity is unchanged.

Tests: jax.jit composability for both filter wrappers; collector recovers the correct global indices on the uncached path (returned log_likelihoods is None).

Byte-identical on well-behaved inputs (golden + snapshot, no diffs).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
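The collector pattern in the second finding can be sketched in plain Python. Only the parameter name degenerate_indices_out comes from the commit message; the functions classify_timesteps and scan_chunks, the list-of-rows input, and the chunking by a fixed chunk_size are hypothetical simplifications of the chunked drivers.

```python
import math


def classify_timesteps(log_likelihoods):
    """Return (degenerate_mask, nan_mask), one flag per timestep (row)."""
    degenerate_mask, nan_mask = [], []
    for row in log_likelihoods:
        has_nan = any(math.isnan(v) for v in row)
        nan_mask.append(has_nan)
        # Degenerate means every state is -inf; a NaN row is reported separately.
        degenerate_mask.append(
            not has_nan and all(v == -math.inf for v in row)
        )
    return degenerate_mask, nan_mask


def scan_chunks(log_likelihoods, chunk_size, degenerate_indices_out=None):
    """Tally degenerate and NaN timesteps chunk by chunk.

    If a collector list is supplied, append the global (whole-series) indices
    of degenerate timesteps while each chunk is in hand.
    """
    n_degenerate = 0
    n_nan = 0
    for start in range(0, len(log_likelihoods), chunk_size):
        chunk = log_likelihoods[start:start + chunk_size]
        degenerate_mask, nan_mask = classify_timesteps(chunk)
        n_degenerate += sum(degenerate_mask)
        n_nan += sum(nan_mask)
        if degenerate_indices_out is not None:
            # Offset chunk-local positions by the chunk's start index.
            degenerate_indices_out.extend(
                start + i for i, flag in enumerate(degenerate_mask) if flag
            )
    return n_degenerate, n_nan


INF = math.inf
series = [
    [0.0, -1.0],      # well-behaved
    [-INF, -INF],     # degenerate
    [math.nan, 0.0],  # NaN
    [-INF, 0.0],      # partial -inf, still informative
    [-INF, -INF],     # degenerate
]
collected = []
counts = scan_chunks(series, chunk_size=2, degenerate_indices_out=collected)
```

Collecting during the forward pass avoids depending on any cached per-timestep likelihoods, at the cost of threading one extra optional argument through the drivers.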
Quality cleanup (no behavior change; byte-identical on golden + snapshot):

- core.py: replace the three separate-pass helpers (_count_degenerate_timesteps / _count_nan_timesteps / _degenerate_timestep_indices) with one _degenerate_and_nan_masks (single max(axis=-1)/isnan pass, both masks pulled to host once) plus a shared _accumulate_chunk_degeneracy. Collapses up to 3 device->host syncs per chunk to 1 and de-duplicates the identical tally/collect block that was copy-pasted across both chunked drivers. _warn_if_degenerate_timesteps now single-passes too.
- base.py: read self._degenerate_timesteps_ directly (the preceding final E-step's _predict always sets it) instead of a getattr with a dead default; move the final-E-step log-likelihood delta into the warning branch that uses it.
- test_em_monotonicity.py: extract a sim fixture for the make_simulated_run_data setup that was duplicated five times across three classes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
a6728a8 to 620c1c3
Low-severity follow-ups from the full-PR review (no production logic change):

- core.py: correct the _degenerate_and_nan_masks docstring / chunk comments that overstated the host transfer as a single one -- it is two small reductions pulled at one host sync point per call (down from three separate helper calls), not literally one transfer.
- models/base.py: note in the degenerate_timesteps_ docstring that predict() does not refresh the attribute (it still logs a degenerate-timestep warning for its data); the attribute reflects the training data set by estimate_parameters.
- test_em_monotonicity.py: assert degenerate_timesteps_ is an empty ndarray after a well-behaved fit, pinning the _predict -> fitted-attribute wiring end to end (previously only the core-level collector was tested).
- CHANGELOG: note that filter / filter_covariate_dependent are now plain wrappers rather than jax.jit objects (semi-public module-level functions).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Summary
Make previously-silent fitting failures visible across the optimization stack:
- likelihoods.sorted_spikes_glm.fit_poisson_regression: result.success was discarded; place-field coefficients were silently set to wherever the optimizer gave up. Now emits a UserWarning when not converged.
- likelihoods.gmm.GaussianMixtureModel: converged_ was set as the heuristic n_iter < max_iter, ignoring the real lower-bound delta convergence flag. Now plumbs the real flag from _em_fit_while_loop through _fit_single to fit; emits a UserWarning on non-convergence.
- _DetectorBase EM (models.base): monotonicity violations, max-iter exits, and final-E-step inconsistencies were all silent. Now logs warnings and records violation iteration indices in a new em_monotonicity_violations_ attribute.
- core._condition_on: all--inf log-likelihoods at a timestep silently produced an all-zero posterior. Now falls back to the predicted distribution and sets log_norm=-inf to mark the degenerate step; each public filter (filter, filter_covariate_dependent, chunked_filter_smoother, chunked_filter_smoother_covariate_dependent) emits a single logger.warning per invocation summarizing the degenerate count.
- New attributes on _DetectorBase: converged_, n_iter_, em_monotonicity_violations_.

Numerically byte-identical to the prior implementation on well-behaved inputs (verified by hand-derivation test plus golden regression).
Stacked on
PR #29 (correctness-bug-bundle). The base of this PR is correctness-bug-bundle; when #29 merges, this PR's base will rebase onto main automatically.

Test plan
- Tests cover _condition_on all-inf, partial-inf, all-finite legacy regression, and the filter degenerate warning.
- Test run with -m "not slow": 817 pass, 3 skipped, 1 xfailed (all pre-existing).
- ruff check src/ and ruff format --check src/: clean.

Known issues surfaced (not regressions)
Several pre-existing sorted-spikes-GLM tests now emit the new UserWarning for non-convergent BFGS exits. The fitted coefficients still satisfy the existing test assertions, but the warning indicates these tests have been silently fitting non-converged models. test_sorted_spikes_glm_encoding_runs_end_to_end additionally trips the new max-iter and final-E-step warnings under default parameters. None are Phase 2 regressions; they are previously-silent issues now visible. Investigation of GLM EM convergence behavior is tracked as a follow-up.

Internal-threshold note
The in-loop EM monotonicity check uses is_increasing from check_converged (hardcoded -1e-3), while the final-E-step warning uses -tolerance (user-supplied). For users with very tight tolerance (<<1e-3), the final-E-step warning may fire while no in-loop warning did. Threading tolerance through check_converged is a follow-up.

🤖 Generated with Claude Code