Refactor and fix for JAX compatibility, STDP, and training workflows#772
Conversation
Reviewer's GuideThis patch standardizes code formatting across example scripts, enforces correct argument handling in STDP and ODE integrators, updates analysis utilities to use JAX tree functions, and bumps the braintools dependency. Class diagram for standardized class definitions in example scriptsclassDiagram
class BasicBlock {
+int expansion
+bool is_last
+Conv2D conv1
+BatchNorm2D bn1
+Conv2D conv2
+BatchNorm2D bn2
+Sequential shortcut
+update(x)
}
class Bottleneck {
+int expansion
+bool is_last
+Conv2D conv1
+BatchNorm2D bn1
+Conv2D conv2
+BatchNorm2D bn2
+Conv2D conv3
+BatchNorm2D bn3
+Sequential shortcut
+update(s, x)
}
class ResNet {
+int in_planes
+Conv2D conv1
+BatchNorm2D bn1
+Sequential layer1
+Sequential layer2
+Sequential layer3
+Sequential layer4
+AdaptiveAvgPool2d avgpool
+Dense linear
+_make_layer(block, planes, num_blocks, stride)
+get_bn_before_relu()
+update(s, x, is_feat, preact)
}
BasicBlock <|-- ResNet
Bottleneck <|-- ResNet
class LIFNode {
+int size
+float tau
+float v_threshold
+float v_reset
+bool fire
+reset_state(batch_size)
+update(dv)
}
class IFNode {
+int size
+float v_threshold
+float v_reset
+reset_state(batch_size)
+update(dv)
}
class ResNet11 {
+Conv2d cnn11
+Sequential lif11
+AvgPool2d avgpool1
+IFNode if1
+Conv2d cnn21
+Sequential lif21
+Conv2d cnn22
+Conv2d shortcut1
+Sequential lif2
+Conv2d cnn31
+Sequential lif31
+Conv2d cnn32
+Conv2d shortcut2
+Sequential lif3
+Conv2d cnn41
+Sequential lif41
+Conv2d cnn42
+Conv2d shortcut3
+Sequential lif4
+Conv2d cnn51
+Sequential lif51
+Conv2d cnn52
+AvgPool2d shortcut4
+Sequential lif5
+Dense fc0
+Sequential lif6
+Dense fc1
+LIFNode lif_out
+update(x)
}
Class diagram for ESN and NGRC in echo_state_network.pyclassDiagram
class ESN {
+Reservoir r
+Dense o
+update(x)
}
class NGRC {
+NVAR r
+Dense o
+update(x)
}
Class diagram for ReducedTRNModel in 3d_reduced_trn_model.pyclassDiagram
class ReducedTRNModel {
+float IT_th
+float b
+float g_T
+float g_L
+float E_L
+float g_KL
+float E_KL
+float NaK_th
+float T
+float g_Na
+float E_Na
+float g_K
+float phi_m
+float phi_h
+float phi_n
+float E_T
+float phi_p
+float phi_q
+float p_half
+float p_k
+float q_half
+float q_k
+float C
+float Vth
+float area
+float V_factor
+float rho_p
+Variable V
+Variable y
+Variable z
+Variable spike
+Variable input
+fV(V, t, y, z, Isyn)
+fy(y, t, V)
+fz(z, t, V)
+derivative(V, y, z, t, Isyn)
+update(tdi)
}
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
@sourcery-ai title |
There was a problem hiding this comment.
Hey @chaoming0625 - I've reviewed your changes - here's some feedback:
- This PR mixes massive stylistic reformattings (indentation, spacing) with critical functional changes—consider splitting formatting-only updates from logic changes into separate commits or PRs to simplify review and isolate potential regressions.
- The migration from tree_map to jax.tree.map should be applied consistently across all modules and may require updating minimum JAX version in dependencies—please verify compatibility and update requirements if necessary.
- The reordering of arguments in the STDP
dense_on_postcalls and the stricter variable‐type checks could introduce subtle bugs—please add a small targeted test or example to validate that STDP weight updates still behave as expected.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- This PR mixes massive stylistic reformattings (indentation, spacing) with critical functional changes—consider splitting formatting-only updates from logic changes into separate commits or PRs to simplify review and isolate potential regressions.
- The migration from tree_map to jax.tree.map should be applied consistently across all modules and may require updating minimum JAX version in dependencies—please verify compatibility and update requirements if necessary.
- The reordering of arguments in the STDP `dense_on_post` calls and the stricter variable‐type checks could introduce subtle bugs—please add a small targeted test or example to validate that STDP weight updates still behave as expected.
## Individual Comments
### Comment 1
<location> `examples/dynamics_training/echo_state_network.py:7` </location>
<code_context>
import brainpy.math as bm
bm.set_environment(bm.batching_mode)
+bp.share.save(fit=True)
</code_context>
<issue_to_address>
Calling bp.share.save(fit=True) at the module level may have unintended side effects.
Consider moving this call inside a function or under a __main__ guard to prevent it from running on import.
</issue_to_address>
### Comment 2
<location> `examples/training_ann_models/mnist_ResNet.py:143` </location>
<code_context>
- self.in_planes = planes * block.expansion
- return bp.Sequential(*layers)
-
- def update(self, s, x, is_feat=False, preact=False):
- out = bm.relu(self.bn1(self.conv1(x)))
- f0 = out
</code_context>
<issue_to_address>
The update method signature for ResNet and Bottleneck is inconsistent with BasicBlock.
The extra 's' argument in ResNet and Bottleneck's update methods, but not in BasicBlock, may cause confusion or errors when using these blocks interchangeably. Please align the method signatures.
Suggested implementation:
```python
def update(self, x, is_feat=False, preact=False):
out = bm.relu(self.bn1(self.conv1(x)))
f0 = out
```
```python
class Bottleneck(nn.Module):
# ...
def update(self, x, is_feat=False, preact=False):
# implementation
pass
```
```python
class ResNet(nn.Module):
# ...
def update(self, x, is_feat=False, preact=False):
# implementation
pass
```
- If there are any calls to `update` within these classes or elsewhere in the file that pass the `s` argument, remove the `s` argument from those calls as well.
- If the `s` argument was used for state or step information, consider how that information should be handled or refactored (e.g., via class attributes or other means).
</issue_to_address>
### Comment 3
<location> `examples/dynamics_simulation/multi_scale_COBAHH.py:167` </location>
<code_context>
- raise ValueError(f'The directory has been existed: {path}')
-
- @vmap
- def run(gE, gI):
- bm.random.seed(123)
- circuit = System(bm.asarray(conn_data) * gc,
+ bm.random.seed(seed)
+ circuit = System(gc * bm.asarray(conn_data),
bm.asarray(delay_data),
- gE=gE, gI=gI)
+ gEE=gEE, gEI=gEI, gIE=gIE, gII=gII)
f1 = lambda tdi: bm.concatenate([area.E.spike for area in circuit.areas])
f2 = lambda tdi: bm.concatenate([area.I.spike for area in circuit.areas])
runner = bp.DSRunner(
- circuit,
- fun_monitors={'exc.spike': f1, 'inh.spike': f2},
- inputs=[circuit.areas[0].E.input, I, 'iter'],
- numpy_mon_after_run=False
+ circuit,
+ fun_monitors={'exc.spike': f1, 'inh.spike': f2},
+ inputs=[circuit.areas[0].E.input, inputs, 'iter'],
+ numpy_mon_after_run=False
)
runner.run(duration)
- runner.mon.pop('var_names')
- return runner.mon
-
</code_context>
<issue_to_address>
The vmap_search function's run method uses bm.random.seed(123) inside a vmap, which may not provide independent randomness.
Seeding inside a vmap can cause all mapped runs to generate identical random sequences. To ensure independent randomness, pass a unique seed or random key to each instance.
</issue_to_address>
### Comment 4
<location> `brainpy/_src/analysis/utils/others.py:119` </location>
<code_context>
return candidates, keep_ids
if num_fps <= 1:
return candidates, keep_ids
- candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))
+ candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))
# If point A and point B are within identical_tol of each other, and the
</code_context>
<issue_to_address>
tree_map replaced with jax.tree.map; check for compatibility with all input types.
Ensure jax.tree.map correctly processes any custom containers or objects in candidates, particularly regarding the is_leaf parameter.
Suggested implementation:
```python
# Define a robust is_leaf function to handle bm.BaseArray and any other custom containers
def _is_leaf(node):
# Extend this check if you have other custom array-like types
return isinstance(node, bm.BaseArray) or isinstance(node, np.ndarray) or isinstance(node, jnp.ndarray)
candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=_is_leaf)
```
```python
if keep_ids.shape[0] > 0:
unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates, is_leaf=_is_leaf)
else:
unique_fps = np.array([], dtype=dtype)
return unique_fps, keep_ids
```
</issue_to_address>
### Comment 5
<location> `brainpy/_src/analysis/utils/model.py:133` </location>
<code_context>
self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False)
def update(self):
- all_vars = list(self.implicit_vars.values())
+ all_vars = [v.value for v in self.implicit_vars.values()]
for key, intg in self.integrals.items():
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))
</code_context>
<issue_to_address>
Changed to use .value for all_vars; this may break if implicit_vars contains non-Variable objects.
Add a type check or fallback to handle cases where items in implicit_vars lack a .value attribute to prevent runtime errors.
</issue_to_address>
### Comment 6
<location> `brainpy/_src/train/back_propagation.py:234` </location>
<code_context>
"""
+
+ share.save(fit=True)
if shared_args is None:
shared_args = dict()
</code_context>
<issue_to_address>
Calling share.save(fit=True) at the start of fit may have side effects.
If share.save is not idempotent, repeated or concurrent calls to fit could cause issues. Please review its behavior.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| import brainpy.math as bm | ||
|
|
||
| bm.set_environment(bm.batching_mode) | ||
| bp.share.save(fit=True) |
There was a problem hiding this comment.
issue (bug_risk): Calling bp.share.save(fit=True) at the module level may have unintended side effects.
Consider moving this call inside a function or under a main guard to prevent it from running on import.
| return bp.Sequential(*layers) | ||
|
|
||
| def update(self, s, x, is_feat=False, preact=False): | ||
| out = bm.relu(self.bn1(self.conv1(x))) |
There was a problem hiding this comment.
suggestion: The update method signature for ResNet and Bottleneck is inconsistent with BasicBlock.
The extra 's' argument in ResNet and Bottleneck's update methods, but not in BasicBlock, may cause confusion or errors when using these blocks interchangeably. Please align the method signatures.
Suggested implementation:
def update(self, x, is_feat=False, preact=False):
out = bm.relu(self.bn1(self.conv1(x)))
f0 = outclass Bottleneck(nn.Module):
# ...
def update(self, x, is_feat=False, preact=False):
# implementation
passclass ResNet(nn.Module):
# ...
def update(self, x, is_feat=False, preact=False):
# implementation
pass- If there are any calls to
updatewithin these classes or elsewhere in the file that pass thesargument, remove thesargument from those calls as well. - If the
sargument was used for state or step information, consider how that information should be handled or refactored (e.g., via class attributes or other means).
| circuit = System(bm.asarray(conn_data) * gc, | ||
| bm.asarray(delay_data), | ||
| gE=gE, gI=gI) | ||
| f1 = lambda tdi: bm.concatenate([area.E.spike for area in circuit.areas]) | ||
| f2 = lambda tdi: bm.concatenate([area.I.spike for area in circuit.areas]) | ||
| runner = bp.DSRunner( | ||
| circuit, | ||
| fun_monitors={'exc.spike': f1, 'inh.spike': f2}, | ||
| inputs=[circuit.areas[0].E.input, I, 'iter'], | ||
| numpy_mon_after_run=False |
There was a problem hiding this comment.
issue (bug_risk): The vmap_search function's run method uses bm.random.seed(123) inside a vmap, which may not provide independent randomness.
Seeding inside a vmap can cause all mapped runs to generate identical random sequences. To ensure independent randomness, pass a unique seed or random key to each instance.
| @@ -116,7 +115,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]], | |||
| return candidates, keep_ids | |||
| if num_fps <= 1: | |||
| return candidates, keep_ids | |||
There was a problem hiding this comment.
suggestion: tree_map replaced with jax.tree.map; check for compatibility with all input types.
Ensure jax.tree.map correctly processes any custom containers or objects in candidates, particularly regarding the is_leaf parameter.
Suggested implementation:
# Define a robust is_leaf function to handle bm.BaseArray and any other custom containers
def _is_leaf(node):
# Extend this check if you have other custom array-like types
return isinstance(node, bm.BaseArray) or isinstance(node, np.ndarray) or isinstance(node, jnp.ndarray)
candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=_is_leaf) if keep_ids.shape[0] > 0:
unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates, is_leaf=_is_leaf)
else:
unique_fps = np.array([], dtype=dtype)
return unique_fps, keep_ids| self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False) | ||
|
|
||
| def update(self): | ||
| all_vars = list(self.implicit_vars.values()) |
There was a problem hiding this comment.
issue (bug_risk): Changed to use .value for all_vars; this may break if implicit_vars contains non-Variable objects.
Add a type check or fallback to handle cases where items in implicit_vars lack a .value attribute to prevent runtime errors.
| plt.plot(ts, targets + 2 * np.arange(0, targets.shape[1]), 'g') | ||
| plt.plot(ts, outs + 2 * np.arange(0, outs.shape[1]), 'r') | ||
| plt.xlim((0, duration)) | ||
| plt.title('Target (green), Output (red)') |
There was a problem hiding this comment.
issue (code-quality): Extract duplicate code into function (extract-duplicate-method)
| evals, _ = np.linalg.eig(bm.as_numpy(net.w_rr)) | ||
| plt.subplot(223) | ||
| plt.plot(np.real(evals), np.imag(evals), 'o') | ||
| plt.plot(x_circ, y_circ, 'k') |
There was a problem hiding this comment.
issue (code-quality): Extract duplicate code into function (extract-duplicate-method)
| if self.is_last: | ||
| return out, preact | ||
| else: | ||
| return out |
There was a problem hiding this comment.
suggestion (code-quality): Replace if statement with if expression (assign-if-exp)
| if self.is_last: | |
| return out, preact | |
| else: | |
| return out | |
| return (out, preact) if self.is_last else out |
| if self.is_last: | ||
| return out, preact | ||
| else: | ||
| return out |
There was a problem hiding this comment.
suggestion (code-quality): Replace if statement with if expression (assign-if-exp)
| if self.is_last: | |
| return out, preact | |
| else: | |
| return out | |
| return (out, preact) if self.is_last else out |
| if self.fire: | ||
| return spike | ||
| else: | ||
| return self.v.value |
There was a problem hiding this comment.
suggestion (code-quality): Replace if statement with if expression (assign-if-exp)
| if self.fire: | |
| return spike | |
| else: | |
| return self.v.value | |
| return spike if self.fire else self.v.value |
Updated `analysis.ipynb` and `simulation.ipynb` to reflect BrainPy version 3.0.0. Added new cell and output metadata, and included warning output for missing IProgress in Jupyter. Minor cell metadata and output changes improve notebook compatibility and documentation accuracy.
Summary by Sourcery
Refactor and fix example scripts, core utilities, and integrators for consistency, correctness, and compatibility with JAX tree APIs. Migrate manual tree operations to JAX’s tree utilities, unwrap Variables in exponential integrator, correct STDP update signatures, and standardize training workflows with share.save. Bump braintools dependency.
Bug Fixes:
Enhancements:
Build: