Refactor STDP weight update logic and improve error handling, requiring brainevent>=0.0.4#771
Conversation
…ng brainevent>=0.0.4
Reviewer's GuideRefactors and centralizes STDP weight update logic by removing taichi-based kernels and enforcing variable type checks; updates plasticity projections and STDP tests to use JAX-compatible data access and TrainingMode; bumps minimum versions for brainevent and braintools; and deprecates tracing_variable support. Class diagram for updated STDP weight update logicclassDiagram
class STDPPlasticity {
+update()
}
class Comm {
+stdp_update(on_pre, on_post, w_min, w_max)
}
STDPPlasticity --> Comm : uses
class Variable {
}
Comm --> Variable : checks W is Variable
Class diagram for deprecated tracing_variable and error handlingclassDiagram
class ObjectTransformBase {
+tracing_variable(name, var, shape)
}
ObjectTransformBase : tracing_variable() now raises NotImplementedError
Class diagram for AllToAll and CSRLinear STDP update changesclassDiagram
class AllToAll {
+stdp_update(on_pre, on_post, w_min, w_max)
}
class CSRLinear {
+stdp_update(on_pre, on_post, w_min, w_max)
}
AllToAll --> CSRLinear : similar STDP update logic
AllToAll : uses csr_on_pre, csr2csc_on_post
CSRLinear : uses csr_on_pre, csr2csc_on_post
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
requiring brainevent>=0.0.4 |
There was a problem hiding this comment.
Hey @chaoming0625 - I've reviewed your changes - here's some feedback:
- Consider defaulting the
modeparameter of your communication layers toTrainingModeso existing code (and tests) won’t have to be updated everywhere to explicitly passmode=bm.TrainingMode(). - Before turning
tracing_variableinto a hardNotImplementedError, it would be helpful to emit a deprecation warning in an earlier release to give users time to migrate. - Since you’ve removed the Taichi‐based STDP kernels entirely, you may want to provide a pure‐Python fallback or guard so users without Taichi/Braintaichi support still get CPU updates rather than a runtime error.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- Consider defaulting the `mode` parameter of your communication layers to `TrainingMode` so existing code (and tests) won’t have to be updated everywhere to explicitly pass `mode=bm.TrainingMode()`.
- Before turning `tracing_variable` into a hard `NotImplementedError`, it would be helpful to emit a deprecation warning in an earlier release to give users time to migrate.
- Since you’ve removed the Taichi‐based STDP kernels entirely, you may want to provide a pure‐Python fallback or guard so users without Taichi/Braintaichi support still get CPU updates rather than a runtime error.
## Individual Comments
### Comment 1
<location> `brainpy/_src/dyn/projections/plasticity.py:212` </location>
<code_context>
# post spikes
if not hasattr(self.refs['post'], 'spike'):
raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.')
- post_spike = self.refs['post'].spike
+ post_spike = self.refs['post'].spike.value
# weight updates
</code_context>
<issue_to_address>
Accessing .value on spike and trace is now required; ensure all code paths provide compatible objects.
If spike or trace are not always Variable-like, this change could raise AttributeError. Consider adding type checks or conversions to handle such cases.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| # post spikes | ||
| if not hasattr(self.refs['post'], 'spike'): | ||
| raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.') | ||
| post_spike = self.refs['post'].spike |
There was a problem hiding this comment.
issue (bug_risk): Accessing .value on spike and trace is now required; ensure all code paths provide compatible objects.
If spike or trace are not always Variable-like, this change could raise AttributeError. Consider adding type checks or conversions to handle such cases.
| if comm_method == 'all2all': | ||
| comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) | ||
| comm = bp.dnn.AllToAll( | ||
| self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1), | ||
| mode=bm.TrainingMode() | ||
| ) | ||
| elif comm_method == 'csr': | ||
| if syn_model == 'exp': | ||
| comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1)) | ||
| comm = bp.dnn.EventCSRLinear( | ||
| bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1), | ||
| mode=bm.TrainingMode() | ||
| ) | ||
| else: | ||
| comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1)) | ||
| comm = bp.dnn.CSRLinear( | ||
| bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1), | ||
| mode=bm.TrainingMode() | ||
| ) | ||
| elif comm_method == 'masked_linear': | ||
| comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1)) | ||
| comm = bp.dnn.MaskedLinear( | ||
| bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1), | ||
| mode=bm.TrainingMode() | ||
| ) | ||
| elif comm_method == 'dense': | ||
| comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) | ||
| comm = bp.dnn.Dense( | ||
| self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1), | ||
| mode=bm.TrainingMode() | ||
| ) | ||
| elif comm_method == 'one2one': | ||
| comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) | ||
| comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1), mode=bm.TrainingMode()) | ||
| else: | ||
| raise ValueError |
There was a problem hiding this comment.
issue (code-quality): Avoid conditionals in tests. (no-conditionals-in-tests)
Explanation
Avoid complex code, like conditionals, in test functions.Google's software engineering guidelines says:
"Clear tests are trivially correct upon inspection"
To reach that avoid complex code in tests:
- loops
- conditionals
Some ways to fix this:
- Use parametrized tests to get rid of the loop.
- Move the complex logic into helpers.
- Move the complex part into pytest fixtures.
Complexity is most often introduced in the form of logic. Logic is defined via the imperative parts of programming languages such as operators, loops, and conditionals. When a piece of code contains logic, you need to do a bit of mental computation to determine its result instead of just reading it off of the screen. It doesn't take much logic to make a test more difficult to reason about.
Software Engineering at Google / Don't Put Logic in Tests
| if syn_model == 'exp': | ||
| comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1)) | ||
| comm = bp.dnn.EventCSRLinear( | ||
| bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1), | ||
| mode=bm.TrainingMode() | ||
| ) | ||
| else: | ||
| comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1)) | ||
| comm = bp.dnn.CSRLinear( | ||
| bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), | ||
| weight=bp.init.Uniform(0., 0.1), | ||
| mode=bm.TrainingMode() | ||
| ) |
There was a problem hiding this comment.
issue (code-quality): Avoid conditionals in tests. (no-conditionals-in-tests)
Explanation
Avoid complex code, like conditionals, in test functions.Google's software engineering guidelines says:
"Clear tests are trivially correct upon inspection"
To reach that avoid complex code in tests:
- loops
- conditionals
Some ways to fix this:
- Use parametrized tests to get rid of the loop.
- Move the complex logic into helpers.
- Move the complex part into pytest fixtures.
Complexity is most often introduced in the form of logic. Logic is defined via the imperative parts of programming languages such as operators, loops, and conditionals. When a piece of code contains logic, you need to do a bit of mental computation to determine its result instead of just reading it off of the screen. It doesn't take much logic to make a test more difficult to reason about.
Software Engineering at Google / Don't Put Logic in Tests
| raise ValueError(f'Cannot update the weight of a constant node.') | ||
| if not isinstance(self.W, bm.Variable): | ||
| self.tracing_variable('W', self.W, self.W.shape) | ||
| raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.') |
There was a problem hiding this comment.
suggestion (code-quality): Replace f-string with no interpolated values with string (remove-redundant-fstring)
| raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.') | |
| raise ValueError( | |
| 'When using STDP to update synaptic weights, the weight must be a variable.' | |
| ) |
Summary by Sourcery
Refactor STDP weight update logic to use new JAX-based primitives, improve error handling for weight updates, update plasticity pipeline and tests for TrainingMode, and bump minimum versions of brainevent and braintools.
Enhancements:
Build:
Tests:
Chores:
tracing_variable()is no longer supported