> [!IMPORTANT]
> Read `DESIGN.md` (§7) for the rationale behind this architecture.

Design rules (🏛️) are universal and apply whenever extracting shared interfaces. Toolchain rules (🔧) apply to files checked with pyrefly (currently `specs/` and new files; pyrefly is available via `pyproject.toml`).
- JAX: Stable; backward compatibility is required. Do not change public APIs (names, interfaces, type signatures) without explicit justification.
- TF: Deprecated. Ignore this directory and its conventions.
- MLX: In progress. See the porting scenarios below.

When porting or refactoring across backends, you are likely in one of these scenarios (or a Full Port combining them):
- Interface Extraction: JAX and MLX implementations exist but don't inherit from a shared interface. Find shared config fields, layer classes, methods, and arguments, and codify the shared interface in `specs/*.py`. Both backends should inherit from the spec class (in the appropriate MRO order; see the MRO rule below), provided this does not break JAX backward compatibility.
- Test Unification: JAX and MLX implementations exist but don't share tests. Given the unified interface in `specs/*.py`, refactor shared test logic into `specs/*_behaviors.py`. Prefer JAX tests as the basis when they cover equivalent features.
- Full Port / Feature Porting: A JAX-only layer needs an MLX port. Start by abstracting the tests (TDD): codify the interface in `specs/*.py` and the tests in `specs/*_behaviors.py` (preferring JAX tests as the basis), then create or update `mlx/*.py` and `mlx/*_test.py`.
- Backend-specific supersets: When a backend implements extra features beyond the shared spec, tests for common functionality go in `specs/*_behaviors.py`, while tests for backend-specific extensions stay in `<backend>/*_test.py`. Consider whether the extended features could be generalized into the shared spec.
- Up-front readability: Backend files must be self-contained. Re-declare all defaults, docstrings, function signatures, and `Config` fields. Users should never need to read `specs/` to understand a backend's API.
  - Exception: Pure functions that are part of the contract all backends must fulfill (e.g., test utilities like `zip_longest`, `named_product`) may live in `specs/` and be aliased by backends.
- Generics and specialization: Spec classes (layers and `Config`s) are generic (e.g., over `DTypeT`, `SequenceT`). Backends specialize them with concrete types.
- Rigid signatures / LSP: Match spec parameter names and signatures exactly. No `**kwargs`. Include all protocol parameters (e.g., `training: bool`) even if a particular backend does not use them; this keeps every backend compliant with the Liskov Substitution Principle.
- MRO: The abstract spec class should be the last one inherited. Example: `class StatelessEmitting(Emitting, spec.StatelessEmitting)`.
- Circular import prevention: When submodules import root-level aliases from `__init__.py`, place all root-level alias imports at the top of `__init__.py`, before importing any submodule classes.
- Decoupled instantiation: Use `Layer.from_config(config)` factory methods on the framework-specific class, not `Config.make(backend=...)`. Spec configs remain abstract.
- Deferred initialization for stateless backends: Backends without eager parameter allocation (e.g., MLX) should create submodules lazily within `_ensure_initialized` rather than maintaining a separate wrapper class. The public class accepts a `Config` and lazily creates its internal submodules on the first call to `layer()`.
- Config specs nested: In `specs/` files, `Config` classes are nested within the layer classes they configure, paralleling the structure in backend implementations.
- Behavior tests via inheritance: `specs/*_behaviors.py` defines backend-agnostic test cases, and `<backend>/*_test.py` inherits from them.
  - No `abc.ABC` in behavior test classes (they won't be discovered by pytest, since the files are named `*_behaviors.py`, not `*_test.py`).
  - No cross-importing between behavior files. Prefer duplicating small helpers or using shared bases in non-behavior modules.
  - Inherit from `test_utils.SequenceLayerTest` (or a similar shared base). In `<backend>/*_test.py`, subclass `test_utils.SequenceLayerTest` first (MRO convention).
- Backend-native syntax in tests: In `<backend>/*_test.py`, use backend-specific types (`jnp`, `sl.Sequence`, etc.). Import the backend as `sl` (e.g., `import sequence_layers.mlx as sl`).
- Avoid `super()` in diamond test hierarchies: When dealing with diamond inheritance (test base + backend-specific mock), `super()` calls can be brittle. Use explicit class delegation instead (e.g., `backend_sl.types.Stateless.step(self, ...)`).
- Capture `self.sl` before nested classes: Capture `backend_sl = self.sl` in the outer method before defining a local mock class (like `DummyLayer`) to avoid scoping issues with static analysis tools.
- Use `backend.xp`/`backend.nn`: In shared behavior tests, avoid importing backend-specific libraries directly. Use `self.sl.backend.xp` for array ops and `self.sl.backend.nn` for neural network ops.
- Collocation: Define `ModuleSpec` protocols in the specific spec module they describe (e.g., `specs/simple.py`, `specs/types.py`), not in `specs/__init__.py`.
- `__all__` from `ModuleSpec.__dict__`: Files defining `ModuleSpec` should derive `__all__` dynamically to keep exports aligned with the protocol.
- Protocol alignment: Keep protocols aligned with usage in shared tests. When exposing new modules or utilities via backend implementations, update the relevant `ModuleSpec`.

Toolchain rules (🔧) apply to files checked with pyrefly (currently `specs/` and new files).
- PEP 695 syntax: Use `class Foo[T]:` instead of `TypeVar` + `Generic[...]`. Legacy files may use the older syntax.
- Pyrefly priority: Prefer Pyrefly over Pylint for structural correctness and type safety. Use `from typing import ...` (no `import typing`). Fix warnings up front; never add `# type: ignore` without justification. If proposing disables, prefer disabling in Pylint over Pyrefly.
- `@override` mandatory: Implementations of abstract methods in backends must be decorated with `@override` (from `typing`).
- Import naming from `specs`:
  - If it is the "specification" for the current file, import it as `spec` (e.g., `test_utils.py` imports `specs/test_utils.py` as `spec`).
  - Otherwise, import it as `<module>_spec` (e.g., `test_utils.py` imports `specs/types.py` as `types_spec`).
  - Within `specs/` itself, always use the `_spec` suffix to avoid ambiguity.
- Lint disable policy: Broad-scoped disables are allowed only in these cases:
  - `specs/*_behaviors.py`: `# pylint: disable=abstract-method` and `# pyrefly: disable=bad-instantiation` at the file level (test classes inherit abstract methods that are implemented in backend test files).
  - `ModuleSpec` protocols: `# pylint: disable=invalid-name` and `# pylint: disable=missing-function-docstring` at the class level.
  - JAX layer implementations (e.g., `jax/dense.py`): `# pylint: disable=abstract-method,abstract-class-instantiated` at the file level (Pylint cannot see through Flax's metaclass wrappers; compliance is guaranteed by Pyrefly and runtime tests).

Formatting, linting, static analysis: Scope these to the files you modified only.
- Format: `pyink <file>`, `isort <file>`.
- Lint: `pylint <file>`; fix all warnings. Do not claim "false positive" without demonstrating it.
- Static analysis: `pyrefly check <file>` (pyrefly-checked files only).
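
To keep these checks scoped to the files you modified, a small helper can build the command list. `check_commands` is a hypothetical helper, not part of the repository. It assumes you supply the modified paths yourself (for example, from `git diff --name-only`) along with the set of newly added files, because both `specs/` files and new files are pyrefly-checked.

```python
from collections.abc import Iterable


def check_commands(
    modified: Iterable[str], new_files: frozenset[str] = frozenset()
) -> list[list[str]]:
  """Builds format, lint, and type-check commands for modified files only.

  Hypothetical helper: pyrefly runs only on `specs/` files and new files.
  """
  commands: list[list[str]] = []
  for path in modified:
    if not path.endswith(".py"):
      continue  # Only Python sources are formatted and checked.
    commands.append(["pyink", path])
    commands.append(["isort", path])
    commands.append(["pylint", path])
    if path.startswith("specs/") or path in new_files:
      commands.append(["pyrefly", "check", path])
  return commands
```

Each command can then be passed to `subprocess.run`. Execution is left out so the sketch stays free of side effects.
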

> [!IMPORTANT]
> Do not fix pre-existing errors in files you did not modify.

Tests: Scope depends on what you changed:

| What you changed | Test scope |
|---|---|
| `<backend>/*.py` only | That backend's `*_test.py` files |
| `specs/*.py` (protocols) | Static analysis usually suffices. Run `<backend>/*_test.py` if you added or changed abstract methods/signatures. |
| `specs/*_behaviors.py` | All inheriting `<backend>/*_test.py` files |
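
The table's decision logic can be sketched as a function. `select_tests` is hypothetical. It finds the users of a behavior file with a plain substring search over test sources; that heuristic is an assumption for illustration, not how the repository resolves inheritance.

```python
from collections.abc import Mapping
from pathlib import PurePosixPath


def select_tests(
    changed: str,
    test_sources: Mapping[str, str],
    *,
    signatures_changed: bool = False,
) -> list[str]:
  """Picks `<backend>/*_test.py` files to run for one changed file.

  `test_sources` maps each backend test path to its source text; users of a
  behavior file are found by substring search (a heuristic).
  """
  path = PurePosixPath(changed)
  if path.parts[0] != "specs":
    # <backend>/*.py only: that backend's *_test.py files.
    return sorted(
        t for t in test_sources if PurePosixPath(t).parts[0] == path.parts[0]
    )
  if path.name.endswith("_behaviors.py"):
    # specs/*_behaviors.py: every backend test that references it.
    return sorted(t for t, src in test_sources.items() if path.stem in src)
  # specs/*.py protocols: static analysis usually suffices.
  return sorted(test_sources) if signatures_changed else []
```
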