Restore hashability and sklearn.clone() contract on detectors#34
Merged
Conversation
3262556 to
23c1c5a
Compare
Both Environment and ObservationModel previously overrode __eq__ in ways that silently subverted correctness, and lost hashability as a side effect (a dataclass with a custom __eq__ but eq=True sets __hash__ = None). - Environment.__eq__(self, other: str) returned the result of self.environment_name == other. This made env_a == env_b return False whenever they had the same name (because the right operand was treated as a string), broke symmetry of ==, and made Environment unhashable. The override existed only to allow environments.index(name) to find an env by string name; replace every such call site with a new find_environment_by_name helper. Restore identity-based __hash__ (the post-fit numpy / networkx attributes carried by Environment have no meaningful value-hash). - ObservationModel.__eq__ compared only (environment_name, encoding_group), silently merging distinct decoding states (is_local / is_no_spike) in set / np.unique / equality checks. This was inconsistent with the @DataClass(order=True) auto-__lt__ which always used all four fields. Remove the override; add unsafe_hash=True so all four fields participate in the hash (all are hashable primitives). A user-visible consequence in models.base: np.unique over a list of observation_models now treats (env, group, is_local=True) and (env, group, is_local=False) as distinct, fitting one encoding model per entry instead of collapsing. Default model configurations don't trigger this; a hand-built list that mixed Local and Non-Local observation models sharing (env, group) would. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two paired _DetectorBase refactors that bring detector internals in line with sklearn's BaseEstimator conventions, fixing sklearn.clone() round-tripping without losing the mutable-default isolation guarantees. (a) algorithm_params: lazy resolution instead of __init__-time copy. clusterless_algorithm_params and sorted_spikes_algorithm_params now default to None on every detector constructor and are stored as-is. sklearn.clone() requires constructors to store params unchanged; the previous defensive dict-copy at __init__ time violated that contract. The dict copy + default resolution moves into two new helper methods on ClusterlessDetector and SortedSpikesDetector: _resolve_clusterless_algorithm_params() and _resolve_sorted_spikes_algorithm_params(). Each returns a fresh dict at every call, so the mutable-default-kwarg sharing bug is still fixed: two detectors using defaults never alias the same dict. The two fit sites (clusterless and sorted-spikes encoding fits) now call the helpers. User-visible side effect: a user-supplied params dict is no longer defensively copied at construction. d.clusterless_algorithm_params IS the user's dict (sklearn convention). Mutations the user makes between construction and fit will propagate. To get isolated state, pass a fresh dict (dict(my_params)). Documented in CHANGELOG and in each constructor's docstring. (b) discrete_initial_conditions: fitted attribute, not param mutation. The EM loop previously overwrote self.discrete_initial_conditions (the user's constructor argument) with acausal_state_probabilities[0] on every iteration. This violated sklearn's contract that get_params() reflects the user's spec, and broke sklearn.clone(fitted_model). The fitted distribution now lives on self.discrete_initial_conditions_ (trailing-underscore convention). Initialized at the start of estimate_parameters so readers see the fitted form post-fit. Documented in the _DetectorBase Attributes docstring. Tests cover both refactors: TestAlgorithmParamsResolution (10 tests) verifies the resolver contract (defaults are None, resolver returns fresh copies, sklearn.clone() round-trips for both default and user-supplied params); test_em_predict_consistency adds three tests (constructor-arg preservation, fitted-attribute existence, sklearn.clone(fitted_model) round-trip). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…g fits Follow-up to the type-design corrections, resolving findings from the multi-agent PR review: - Environment: switch to @DataClass(eq=False) so equality and hashing are both identity-based and consistent (restores Python's a == b => hash(a) == hash(b) invariant). The dataclass value-__eq__ was latently broken: it raised ValueError comparing two fitted environments' numpy array fields, and pairing value-equality with the identity hash broke set/dict membership. A value-hash is not viable (numpy/networkx fields are unhashable); environments are addressed by name via find_environment_by_name, never by value-equality. - Encoding-fit loops (ClusterlessDetector/SortedSpikesDetector): deduplicate by (environment_name, encoding_group) so each encoding key is fit once. The new 4-field ObservationModel equality makes np.unique return multiple entries that share one key (default non-local/no-spike configs do this); the loop body only uses (env, group), so this removes redundant KDE/GLM fits with identical output. Golden + snapshot unchanged. - Docstrings: correct the ClusterlessDetector/SortedSpikesDetector __init__ docstrings that wrongly claimed the user algorithm_params dict is copied at construction (it is stored by reference; copy happens at fit time). Fix the "fit/predict time" inline comments to "fit time". - CHANGELOG: correct the ObservationModel entry that claimed default configs do not trigger the duplicate-key path (they do); the dedup keeps the encoding- model count unchanged for every configuration. Update the Environment entry for identity equality. - Tests: add test_encoding_model_sharing.py (default NonLocal clusterless and sorted each fit exactly one encoding model despite 3 unique observation models), the discrete_initial_conditions_ estimate_initial_conditions=False path, and identity eq/hash + fitted-array-no-raise tests for Environment. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
6c7c59c to
30bae04
Compare
Address findings from the second review pass: - Add test_distinct_encoding_groups_fit_separate_models: fits two distinct encoding groups on one environment and asserts both (env, group) keys survive the dedup. The count-only encoding-sharing tests did not actually guard the dedup (the default config has a single key, so dedup-vs-not both yield one model); this adversarial test fails if the dedup ever collapsed on environment_name alone and dropped a real key — and closes a suite-wide blind spot where nothing exercised multi-key encoding_model_. - environment.py: drop the redundant ``__hash__ = object.__hash__`` line. Under ``@dataclass(eq=False)`` both __eq__ and __hash__ already stay inherited from object, so the line was a no-op today and a latent footgun: a future flip to ``eq=True`` would silently re-introduce the value-eq / identity-hash mismatch the line appears to guard against. Without it, that flip instead fails loudly (instances become unhashable). Reword the comment accordingly. - Strengthen test_two_environments_with_same_name_not_eq to use identical specs, so its inequality assertion exercises identity equality rather than a field difference. Golden (4) + snapshot (8) unchanged; ruff clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Phase 4 of the type-design corrections from the post-review audit. Removes type-incorrect dunders and fixes the mutable-default / fit-mutation issues that broke
sklearn.clone().Type-correct equality
Environment.__eq__(self, other: str)removed. The override compared anEnvironmentagainst a string, breaking the symmetric/transitive contract callers reasonably expect. The eightenvironments.index(name)call sites that exploited it for name lookup are converted to a newfind_environment_by_name(environments, name)helper inenvironment.py.ObservationModel.__eq__removed. It compared only(environment_name, encoding_group), silently merging models that differed onis_local/is_no_spike. The dataclass-generated__eq__now compares all four fields, restoring__lt__/__eq__consistency. User-visible: a hand-builtobservation_modelslist whose entries share(env, group)but differ onis_local/is_no_spikewill now fit one encoding model per entry rather than collapsing.Hashability
Both classes lost their auto-generated
__hash__when their custom__eq__was defined; removing the override does not by itself restore it (a dataclass witheq=Trueand no explicit hash sets__hash__ = None). Restored explicitly:Environment: identity hash (__hash__ = object.__hash__), since post-fit numpy / networkx attributes have no meaningful value-hash.ObservationModel: value hash via@dataclass(unsafe_hash=True), since all four fields are hashable primitives.algorithm_paramsand sklearn.clone()The plan's initial approach (
__init__defensively copies the dict) fixed the mutable-default sharing bug but brokesklearn.clone()(which requiresself.param == paramafter__init__). Refactored to:None); signatures aredict | None = None._resolve_clusterless_algorithm_paramsand_resolve_sorted_spikes_algorithm_paramsonClusterlessDetector/SortedSpikesDetector. Each returns a fresh dict at every call, so the mutable-default-sharing bug stays fixed._DetectorBasecall the helpers.User-visible: a user-supplied params dict is no longer defensively copied at construction.
d.clusterless_algorithm_params is user_dictafter construction (sklearn convention). Mutations the user makes between construction and fit propagate to the fit. To get isolated state, pass a fresh dict (dict(my_params)).Fitted-attribute convention
estimate_parametersno longer overwritesself.discrete_initial_conditions(the user's constructor argument). The fitted distribution lives onself.discrete_initial_conditions_(trailing underscore, sklearn convention). This restoresget_params()correctness and letssklearn.clone(fitted_model)return an unfitted estimator with the user's original spec.Stacked on
PR #33 (
duplicate-path-cleanup, Phase 3). When #33 merges, this PR's base rebases ontomainautomatically.Test plan
test_find_environment_by_name, hashability and equality tests intests/environment/test_track_graph.py.ObservationModelfour-field equality + set membership:tests/test_observation_models.py(new file).algorithm_paramsresolver contract:TestAlgorithmParamsResolutionintests/models/test_base_initialization.py(10 tests coveringNone-default, fresh-copy, no-leak, all five subclasses, andsklearn.clone()round-trip).discrete_initial_conditions_fitted-attribute contract:tests/test_em_predict_consistency.py(3 tests covering constructor-arg preservation, fitted attribute existence,sklearn.clone()after fit).ruff check src/clean;ruff format --check src/clean.-m "not slow": 828 pass, 3 skipped, 1 xfailed (all pre-existing).🤖 Generated with Claude Code