
Restore hashability and sklearn.clone() contract on detectors #34

Merged
edeno merged 4 commits into main from type-design-corrections on Jun 12, 2026

@edeno edeno commented May 21, 2026

Summary

Phase 4 of the type-design corrections from the post-review audit. Removes type-incorrect dunders and fixes the mutable-default / fit-mutation issues that broke sklearn.clone().

Type-correct equality

  • Environment.__eq__(self, other: str) removed. The override compared an Environment against a string, breaking the symmetric/transitive contract callers reasonably expect. The eight environments.index(name) call sites that exploited it for name lookup are converted to a new find_environment_by_name(environments, name) helper in environment.py.
  • ObservationModel.__eq__ removed. It compared only (environment_name, encoding_group), silently merging models that differed on is_local / is_no_spike. The dataclass-generated __eq__ now compares all four fields, restoring __lt__ / __eq__ consistency. User-visible: a hand-built observation_models list whose entries share (env, group) but differ on is_local / is_no_spike will now fit one encoding model per entry rather than collapsing.
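
A minimal sketch of what a name-lookup helper of this kind could look like. The stand-in `Environment` class, the choice to return the matching object (rather than its index), and raising `ValueError` on a miss are all assumptions made for illustration; the actual helper in environment.py may differ.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Environment:
    """Stand-in carrying only the field this sketch needs."""

    environment_name: str = ""


def find_environment_by_name(environments, name):
    """Return the first environment whose environment_name equals name."""
    for env in environments:
        if env.environment_name == name:
            return env
    # Assumption: the real helper may signal a miss differently.
    raise ValueError(f"No environment named {name!r}")


environments = [Environment("linear_track"), Environment("w_track")]
found = find_environment_by_name(environments, "w_track")
```

The lookup compares strings with strings, so `Environment.__eq__` never has to accept a `str` operand.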

Hashability

Both classes lost their auto-generated __hash__ when their custom __eq__ was defined; removing the override does not by itself restore it (a dataclass with eq=True and no explicit hash sets __hash__ = None). Restored explicitly:

  • Environment: identity hash (__hash__ = object.__hash__), since post-fit numpy / networkx attributes have no meaningful value-hash.
  • ObservationModel: value hash via @dataclass(unsafe_hash=True), since all four fields are hashable primitives.
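
The three hashing outcomes described above can be reproduced with plain dataclasses. This is an illustrative sketch only: the class names and field types are hypothetical stand-ins, not the library's definitions.

```python
from dataclasses import dataclass


@dataclass  # eq=True by default and no explicit hash: __hash__ is set to None
class DefaultEq:
    name: str = ""


@dataclass
class ExplicitIdentityHash:
    name: str = ""
    # An explicit __hash__ in the class body is left alone by @dataclass.
    __hash__ = object.__hash__


@dataclass(unsafe_hash=True)  # hash computed from all four fields
class ValueHashed:
    environment_name: str = ""
    encoding_group: int = 0  # assumed type; the source only says "hashable primitives"
    is_local: bool = False
    is_no_spike: bool = False


env = ExplicitIdentityHash("track")
env_lookup = {env: "usable as a dict key"}

local = ValueHashed(is_local=True)
non_local = ValueHashed(is_local=False)
distinct_models = {local, non_local, ValueHashed(is_local=True)}
```

`distinct_models` holds two entries: the duplicate `is_local=True` instance hashes and compares equal to `local`, while `non_local` stays separate.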

algorithm_params and sklearn.clone()

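The plan's initial approach (__init__ defensively copies the dict) fixed the mutable-default sharing bug but broke sklearn.clone(), which requires the constructor to store each parameter unmodified (it checks that self.param is param after __init__, so a copy fails even though it compares equal). Refactored to: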

  • Constructors store the user's argument as-is (including None); signatures are dict | None = None.
  • The dict copy + default resolution moves into two new helpers _resolve_clusterless_algorithm_params and _resolve_sorted_spikes_algorithm_params on ClusterlessDetector / SortedSpikesDetector. Each returns a fresh dict at every call, so the mutable-default-sharing bug stays fixed.
  • The two encoding-fit sites in _DetectorBase call the helpers.

User-visible: a user-supplied params dict is no longer defensively copied at construction. d.clusterless_algorithm_params is user_dict after construction (sklearn convention). Mutations the user makes between construction and fit propagate to the fit. To get isolated state, pass a fresh dict (dict(my_params)).
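
A rough sketch of the store-as-is plus resolve-at-fit pattern, using only the standard library. `ClusterlessDetectorSketch`, `_DEFAULT_PARAMS`, and the `"bandwidth"` key are hypothetical names chosen for illustration; only the constructor argument name and the resolver name come from the description above.

```python
from __future__ import annotations


class ClusterlessDetectorSketch:
    # Hypothetical defaults, not the library's actual values.
    _DEFAULT_PARAMS = {"bandwidth": 6.0}

    def __init__(self, clusterless_algorithm_params: dict | None = None):
        # Stored exactly as passed (including None), with no copy.
        self.clusterless_algorithm_params = clusterless_algorithm_params

    def _resolve_clusterless_algorithm_params(self) -> dict:
        params = self.clusterless_algorithm_params
        if params is None:
            params = self._DEFAULT_PARAMS
        # Fresh (shallow) dict on every call, so callers never share state.
        return dict(params)


user_params = {"bandwidth": 4.0}
detector = ClusterlessDetectorSketch(user_params)

# A mutation between construction and fit is visible at resolution time.
user_params["bandwidth"] = 8.0
resolved = detector._resolve_clusterless_algorithm_params()

# Two detectors relying on defaults never alias the same dict.
first_default = ClusterlessDetectorSketch()._resolve_clusterless_algorithm_params()
second_default = ClusterlessDetectorSketch()._resolve_clusterless_algorithm_params()
```

Because the copy is shallow, nested mutable values inside a params dict would still be shared; whether the real resolvers copy more deeply is not stated here.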

Fitted-attribute convention

estimate_parameters no longer overwrites self.discrete_initial_conditions (the user's constructor argument). The fitted distribution lives on self.discrete_initial_conditions_ (trailing underscore, sklearn convention). This restores get_params() correctness and lets sklearn.clone(fitted_model) return an unfitted estimator with the user's original spec.
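
The convention can be sketched without sklearn. `DetectorSketch` is a hypothetical stand-in, and rebuilding from `get_params()` only imitates what a clone does; it is not sklearn's implementation.

```python
class DetectorSketch:
    def __init__(self, discrete_initial_conditions):
        # Constructor argument: never overwritten after this point.
        self.discrete_initial_conditions = discrete_initial_conditions

    def get_params(self):
        return {"discrete_initial_conditions": self.discrete_initial_conditions}

    def estimate_parameters(self, acausal_state_probabilities):
        # Fitted state starts from the user's spec and lives on a
        # trailing-underscore attribute.
        self.discrete_initial_conditions_ = list(self.discrete_initial_conditions)
        # Stand-in for one EM update: take the first time step's probabilities.
        self.discrete_initial_conditions_ = list(acausal_state_probabilities[0])
        return self


spec = [0.5, 0.5]
fitted = DetectorSketch(spec).estimate_parameters([[0.9, 0.1], [0.8, 0.2]])

# Rebuilding from get_params() yields an unfitted object with the original spec.
unfitted = DetectorSketch(**fitted.get_params())
```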

Stacked on

PR #33 (duplicate-path-cleanup, Phase 3). When #33 merges, this PR's base is retargeted to main automatically.

Test plan

  • test_find_environment_by_name, hashability and equality tests in tests/environment/test_track_graph.py.
  • ObservationModel four-field equality + set membership: tests/test_observation_models.py (new file).
  • algorithm_params resolver contract: TestAlgorithmParamsResolution in tests/models/test_base_initialization.py (10 tests covering None-default, fresh-copy, no-leak, all five subclasses, and sklearn.clone() round-trip).
  • discrete_initial_conditions_ fitted-attribute contract: tests/test_em_predict_consistency.py (3 tests covering constructor-arg preservation, fitted attribute existence, sklearn.clone() after fit).
  • ruff check src/ clean; ruff format --check src/ clean.
  • Snapshot suite (8 tests): pass, no diffs.
  • Golden regression (4 tests): pass, no diffs.
  • Broad suite -m "not slow": 828 pass, 3 skipped, 1 xfailed (all pre-existing).

🤖 Generated with Claude Code

edeno and others added 3 commits June 11, 2026 17:47
Both Environment and ObservationModel previously overrode __eq__ in
ways that silently subverted correctness, and lost hashability as a
side effect (a dataclass with a custom __eq__ but eq=True sets
__hash__ = None).

- Environment.__eq__(self, other: str) returned the result of
  self.environment_name == other. This made env_a == env_b return
  False whenever they had the same name (because the right operand
  was treated as a string), broke symmetry of ==, and made
  Environment unhashable. The override existed only to allow
  environments.index(name) to find an env by string name; replace
  every such call site with a new find_environment_by_name helper.
  Restore identity-based __hash__ (the post-fit numpy / networkx
  attributes carried by Environment have no meaningful value-hash).

- ObservationModel.__eq__ compared only (environment_name,
  encoding_group), silently merging distinct decoding states
  (is_local / is_no_spike) in set / np.unique / equality checks.
  This was inconsistent with the @dataclass(order=True)
  auto-__lt__ which always used all four fields. Remove the
  override; add unsafe_hash=True so all four fields participate in
  the hash (all are hashable primitives).

A user-visible consequence in models.base: np.unique over a list of
observation_models now treats (env, group, is_local=True) and
(env, group, is_local=False) as distinct, fitting one encoding
model per entry instead of collapsing. Default model configurations
don't trigger this; a hand-built list that mixed Local and Non-Local
observation models sharing (env, group) would.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two paired _DetectorBase refactors that bring detector internals in
line with sklearn's BaseEstimator conventions, fixing
sklearn.clone() round-tripping without losing the mutable-default
isolation guarantees.

(a) algorithm_params: lazy resolution instead of __init__-time copy.

clusterless_algorithm_params and sorted_spikes_algorithm_params now
default to None on every detector constructor and are stored as-is.
sklearn.clone() requires constructors to store params unchanged;
the previous defensive dict-copy at __init__ time violated that
contract.

The dict copy + default resolution moves into two new helper
methods on ClusterlessDetector and SortedSpikesDetector:
_resolve_clusterless_algorithm_params() and
_resolve_sorted_spikes_algorithm_params(). Each returns a fresh
dict at every call, so the mutable-default-kwarg sharing bug is
still fixed: two detectors using defaults never alias the same dict.
The two fit sites (clusterless and sorted-spikes encoding fits) now
call the helpers.

User-visible side effect: a user-supplied params dict is no longer
defensively copied at construction. d.clusterless_algorithm_params
IS the user's dict (sklearn convention). Mutations the user makes
between construction and fit will propagate. To get isolated state,
pass a fresh dict (dict(my_params)). Documented in CHANGELOG and
in each constructor's docstring.

(b) discrete_initial_conditions: fitted attribute, not param
mutation.

The EM loop previously overwrote self.discrete_initial_conditions
(the user's constructor argument) with acausal_state_probabilities[0]
on every iteration. This violated sklearn's contract that
get_params() reflects the user's spec, and broke
sklearn.clone(fitted_model). The fitted distribution now lives on
self.discrete_initial_conditions_ (trailing-underscore convention).
Initialized at the start of estimate_parameters so readers see the
fitted form post-fit. Documented in the _DetectorBase Attributes
docstring.

Tests cover both refactors: TestAlgorithmParamsResolution (10
tests) verifies the resolver contract (defaults are None, resolver
returns fresh copies, sklearn.clone() round-trips for both default
and user-supplied params); test_em_predict_consistency adds three
tests (constructor-arg preservation, fitted-attribute existence,
sklearn.clone(fitted_model) round-trip).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…g fits

Follow-up to the type-design corrections, resolving findings from the
multi-agent PR review:

- Environment: switch to @dataclass(eq=False) so equality and hashing are
  both identity-based and consistent (restores Python's a == b => hash(a)
  == hash(b) invariant). The dataclass value-__eq__ was latently broken: it
  raised ValueError comparing two fitted environments' numpy array fields, and
  pairing value-equality with the identity hash broke set/dict membership. A
  value-hash is not viable (numpy/networkx fields are unhashable); environments
  are addressed by name via find_environment_by_name, never by value-equality.

- Encoding-fit loops (ClusterlessDetector/SortedSpikesDetector): deduplicate by
  (environment_name, encoding_group) so each encoding key is fit once. The new
  4-field ObservationModel equality makes np.unique return multiple entries that
  share one key (default non-local/no-spike configs do this); the loop body only
  uses (env, group), so this removes redundant KDE/GLM fits with identical
  output. Golden + snapshot unchanged.

- Docstrings: correct the ClusterlessDetector/SortedSpikesDetector __init__
  docstrings that wrongly claimed the user algorithm_params dict is copied at
  construction (it is stored by reference; copy happens at fit time). Fix the
  "fit/predict time" inline comments to "fit time".

- CHANGELOG: correct the ObservationModel entry that claimed default configs do
  not trigger the duplicate-key path (they do); the dedup keeps the encoding-
  model count unchanged for every configuration. Update the Environment entry
  for identity equality.

- Tests: add test_encoding_model_sharing.py (default NonLocal clusterless and
  sorted each fit exactly one encoding model despite 3 unique observation
  models), the discrete_initial_conditions_ estimate_initial_conditions=False
  path, and identity eq/hash + fitted-array-no-raise tests for Environment.
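
The encoding-key dedup described above can be sketched as follows. The `ObservationModel` stand-in, its field defaults, and the use of `sorted(set(...))` in place of np.unique are assumptions made to keep the example standard-library only.

```python
from dataclasses import dataclass


@dataclass(order=True, unsafe_hash=True)
class ObservationModel:
    # Field names follow the PR text; types and defaults are assumed.
    environment_name: str = ""
    encoding_group: int = 0
    is_local: bool = False
    is_no_spike: bool = False


observation_models = [
    ObservationModel(is_local=True),
    ObservationModel(is_no_spike=True),
    ObservationModel(),
    ObservationModel(encoding_group=1),
]

# Four-field equality keeps all of these distinct.
unique_models = sorted(set(observation_models))

# Fit once per (environment_name, encoding_group), preserving first-seen order.
encoding_keys = []
seen = set()
for obs in unique_models:
    key = (obs.environment_name, obs.encoding_group)
    if key not in seen:
        seen.add(key)
        encoding_keys.append(key)
```

Here four unique observation models reduce to two encoding keys, so group 0 is fit once even though three entries reference it, and group 1 is not dropped.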

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@edeno edeno force-pushed the type-design-corrections branch from 6c7c59c to 30bae04 June 11, 2026 22:41
Address findings from the second review pass:

- Add test_distinct_encoding_groups_fit_separate_models: fits two distinct
  encoding groups on one environment and asserts both (env, group) keys survive
  the dedup. The count-only encoding-sharing tests did not actually guard the
  dedup (the default config has a single key, so dedup-vs-not both yield one
  model); this adversarial test fails if the dedup ever collapsed on
  environment_name alone and dropped a real key — and closes a suite-wide blind
  spot where nothing exercised multi-key encoding_model_.

- environment.py: drop the redundant ``__hash__ = object.__hash__`` line.
  Under ``@dataclass(eq=False)`` both __eq__ and __hash__ already stay inherited
  from object, so the line was a no-op today and a latent footgun: a future flip
  to ``eq=True`` would silently re-introduce the value-eq / identity-hash
  mismatch the line appears to guard against. Without it, that flip instead
  fails loudly (instances become unhashable). Reword the comment accordingly.

- Strengthen test_two_environments_with_same_name_not_eq to use identical specs,
  so its inequality assertion exercises identity equality rather than a field
  difference.
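
The `__hash__` reasoning above can be checked with three small stand-in classes. The class names are hypothetical; the behavior shown is that of the standard `dataclasses` module.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # __eq__ and __hash__ both stay inherited from object
class IdentityEnv:
    environment_name: str = ""


@dataclass  # eq=True, with the explicit line kept: value eq plus identity hash
class MismatchedEnv:
    environment_name: str = ""
    __hash__ = object.__hash__


@dataclass  # eq=True, without the explicit line: instances are unhashable
class UnhashableEnv:
    environment_name: str = ""


x, y = IdentityEnv("track"), IdentityEnv("track")
a, b = MismatchedEnv("track"), MismatchedEnv("track")

identity_consistent = (x != y) and (x == x) and (hash(x) == hash(x))
# a == b is True, yet a set treats them as two members: the invariant is broken.
mismatch_members = {a, b}

try:
    hash(UnhashableEnv("track"))
    flip_fails_loudly = False
except TypeError:
    flip_fails_loudly = True
```

The third class shows why dropping the explicit line is the safer choice: a later switch to `eq=True` produces a `TypeError` on first hash instead of silently wrong set and dict membership.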

Golden (4) + snapshot (8) unchanged; ruff clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@edeno edeno merged commit 36d616e into main Jun 12, 2026
10 checks passed
@edeno edeno deleted the type-design-corrections branch June 12, 2026 01:52