Skip to content

Refactor and fix for JAX compatibility, STDP, and training workflows#772

Merged
chaoming0625 merged 7 commits into
masterfrom
fix
Aug 13, 2025
Merged

Refactor and fix for JAX compatibility, STDP, and training workflows#772
chaoming0625 merged 7 commits into
masterfrom
fix

Conversation

@chaoming0625

@chaoming0625 chaoming0625 commented Aug 7, 2025

Copy link
Copy Markdown
Member

Summary by Sourcery

Refactor and fix example scripts, core utilities, and integrators for consistency, correctness, and compatibility with JAX tree APIs. Migrate manual tree operations to JAX’s tree utilities, unwrap Variables in exponential integrator, correct STDP update signatures, and standardize training workflows with share.save. Bump braintools dependency.

Bug Fixes:

  • Fix argument ordering in STDP update functions and enforce weight as a Variable for STDP layers
  • Replace deprecated tree_flatten/tree_unflatten in Variable and correct tree_map usages in slow point finder to avoid JAX errors

Enhancements:

  • Unify indentation, import ordering, and variable naming across example scripts and fix argument mismatches in multi-scale COBA-HH, vmap_search, and other simulations
  • Migrate manual tree_map/flatten in analysis modules to use JAX tree utilities (jax.tree.map, jax.tree.flatten)
  • Ensure ODE exponential integrator unwraps bm.Variable inputs via a new _as_value helper
  • Apply share.save(fit=True) at the start of back_propagation.fit to correctly track training mode

Build:

  • Bump braintools dependency to >=0.0.9 in requirements and setup.py

@sourcery-ai

sourcery-ai Bot commented Aug 7, 2025

Copy link
Copy Markdown

Reviewer's Guide

This patch standardizes code formatting across example scripts, enforces correct argument handling in STDP and ODE integrators, updates analysis utilities to use JAX tree functions, and bumps the braintools dependency.

Class diagram for standardized class definitions in example scripts

classDiagram
    class BasicBlock {
        +int expansion
        +bool is_last
        +Conv2D conv1
        +BatchNorm2D bn1
        +Conv2D conv2
        +BatchNorm2D bn2
        +Sequential shortcut
        +update(x)
    }
    class Bottleneck {
        +int expansion
        +bool is_last
        +Conv2D conv1
        +BatchNorm2D bn1
        +Conv2D conv2
        +BatchNorm2D bn2
        +Conv2D conv3
        +BatchNorm2D bn3
        +Sequential shortcut
        +update(s, x)
    }
    class ResNet {
        +int in_planes
        +Conv2D conv1
        +BatchNorm2D bn1
        +Sequential layer1
        +Sequential layer2
        +Sequential layer3
        +Sequential layer4
        +AdaptiveAvgPool2d avgpool
        +Dense linear
        +_make_layer(block, planes, num_blocks, stride)
        +get_bn_before_relu()
        +update(s, x, is_feat, preact)
    }
    BasicBlock <|-- ResNet
    Bottleneck <|-- ResNet

    class LIFNode {
        +int size
        +float tau
        +float v_threshold
        +float v_reset
        +bool fire
        +reset_state(batch_size)
        +update(dv)
    }
    class IFNode {
        +int size
        +float v_threshold
        +float v_reset
        +reset_state(batch_size)
        +update(dv)
    }
    class ResNet11 {
        +Conv2d cnn11
        +Sequential lif11
        +AvgPool2d avgpool1
        +IFNode if1
        +Conv2d cnn21
        +Sequential lif21
        +Conv2d cnn22
        +Conv2d shortcut1
        +Sequential lif2
        +Conv2d cnn31
        +Sequential lif31
        +Conv2d cnn32
        +Conv2d shortcut2
        +Sequential lif3
        +Conv2d cnn41
        +Sequential lif41
        +Conv2d cnn42
        +Conv2d shortcut3
        +Sequential lif4
        +Conv2d cnn51
        +Sequential lif51
        +Conv2d cnn52
        +AvgPool2d shortcut4
        +Sequential lif5
        +Dense fc0
        +Sequential lif6
        +Dense fc1
        +LIFNode lif_out
        +update(x)
    }
Loading

Class diagram for ESN and NGRC in echo_state_network.py

classDiagram
    class ESN {
        +Reservoir r
        +Dense o
        +update(x)
    }
    class NGRC {
        +NVAR r
        +Dense o
        +update(x)
    }
Loading

Class diagram for ReducedTRNModel in 3d_reduced_trn_model.py

classDiagram
    class ReducedTRNModel {
        +float IT_th
        +float b
        +float g_T
        +float g_L
        +float E_L
        +float g_KL
        +float E_KL
        +float NaK_th
        +float T
        +float g_Na
        +float E_Na
        +float g_K
        +float phi_m
        +float phi_h
        +float phi_n
        +float E_T
        +float phi_p
        +float phi_q
        +float p_half
        +float p_k
        +float q_half
        +float q_k
        +float C
        +float Vth
        +float area
        +float V_factor
        +float rho_p
        +Variable V
        +Variable y
        +Variable z
        +Variable spike
        +Variable input
        +fV(V, t, y, z, Isyn)
        +fy(y, t, V)
        +fz(z, t, V)
        +derivative(V, y, z, t, Isyn)
        +update(tdi)
    }
Loading

File-Level Changes

Change Details Files
Consistent indentation and import cleanup in example scripts
  • Indent class methods consistently under their definitions
  • Move and de-duplicate imports (e.g., vmap) to top of modules
  • Remove extra blank lines for uniform formatting
examples/dynamics_simulation/multi_scale_COBAHH.py
examples/training_ann_models/mnist_ResNet.py
examples/training_snn_models/spikebased_bp_for_cifar10.py
...plus many other example scripts
Fix STDP weight updates argument ordering and enforce variable type
  • Swap spike and trace arguments in dense_on_post and csr2csc_on_post calls
  • Require self.weight to be a bm.Variable and error otherwise
brainpy/_src/dnn/linear.py
examples/dynamics_simulation/stdp.py
Unwrap Variables when calling exponential ODE integrator
  • Introduce helper function _as_value to extract .value from Variables
  • Apply _as_value to all integrator inputs before f_integral call
brainpy/_src/integrators/ode/exponential.py
Replace custom tree_map with jax.tree utilities in analysis
  • Use jax.tree.map and jax.tree.flatten instead of tree_map
  • Adjust fixed_points and loss computations to use JAX tree functions
brainpy/_src/analysis/highdim/slow_points.py
brainpy/_src/analysis/utils/others.py
brainpy/_src/analysis/utils/measurement.py
Bump braintools dependency to latest version
  • Update braintools requirement to >=0.0.9 in requirements.txt
  • Align install_requires in setup.py with updated version
requirements.txt
setup.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625

Copy link
Copy Markdown
Member Author

@sourcery-ai title

@sourcery-ai sourcery-ai Bot changed the title Fix Refactor and fix for JAX compatibility, STDP, and training workflows Aug 7, 2025

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @chaoming0625 - I've reviewed your changes - here's some feedback:

  • This PR mixes massive stylistic reformattings (indentation, spacing) with critical functional changes—consider splitting formatting-only updates from logic changes into separate commits or PRs to simplify review and isolate potential regressions.
  • The migration from tree_map to jax.tree.map should be applied consistently across all modules and may require updating minimum JAX version in dependencies—please verify compatibility and update requirements if necessary.
  • The reordering of arguments in the STDP dense_on_post calls and the stricter variable‐type checks could introduce subtle bugs—please add a small targeted test or example to validate that STDP weight updates still behave as expected.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- This PR mixes massive stylistic reformattings (indentation, spacing) with critical functional changes—consider splitting formatting-only updates from logic changes into separate commits or PRs to simplify review and isolate potential regressions.
- The migration from tree_map to jax.tree.map should be applied consistently across all modules and may require updating minimum JAX version in dependencies—please verify compatibility and update requirements if necessary.
- The reordering of arguments in the STDP `dense_on_post` calls and the stricter variable‐type checks could introduce subtle bugs—please add a small targeted test or example to validate that STDP weight updates still behave as expected.

## Individual Comments

### Comment 1
<location> `examples/dynamics_training/echo_state_network.py:7` </location>
<code_context>
 import brainpy.math as bm

 bm.set_environment(bm.batching_mode)
+bp.share.save(fit=True)


</code_context>

<issue_to_address>
Calling bp.share.save(fit=True) at the module level may have unintended side effects.

Consider moving this call inside a function or under a __main__ guard to prevent it from running on import.
</issue_to_address>

### Comment 2
<location> `examples/training_ann_models/mnist_ResNet.py:143` </location>
<code_context>
-      self.in_planes = planes * block.expansion
-    return bp.Sequential(*layers)
-
-  def update(self, s, x, is_feat=False, preact=False):
-    out = bm.relu(self.bn1(self.conv1(x)))
-    f0 = out
</code_context>

<issue_to_address>
The update method signature for ResNet and Bottleneck is inconsistent with BasicBlock.

The extra 's' argument in ResNet and Bottleneck's update methods, but not in BasicBlock, may cause confusion or errors when using these blocks interchangeably. Please align the method signatures.

Suggested implementation:

```python
  def update(self, x, is_feat=False, preact=False):
    out = bm.relu(self.bn1(self.conv1(x)))
    f0 = out

```

```python
class Bottleneck(nn.Module):
    # ...
    def update(self, x, is_feat=False, preact=False):
        # implementation
        pass

```

```python
class ResNet(nn.Module):
    # ...
    def update(self, x, is_feat=False, preact=False):
        # implementation
        pass

```

- If there are any calls to `update` within these classes or elsewhere in the file that pass the `s` argument, remove the `s` argument from those calls as well.
- If the `s` argument was used for state or step information, consider how that information should be handled or refactored (e.g., via class attributes or other means).
</issue_to_address>

### Comment 3
<location> `examples/dynamics_simulation/multi_scale_COBAHH.py:167` </location>
<code_context>
-    raise ValueError(f'The directory has been existed: {path}')
-
-  @vmap
-  def run(gE, gI):
-    bm.random.seed(123)
-    circuit = System(bm.asarray(conn_data) * gc,
+    bm.random.seed(seed)
+    circuit = System(gc * bm.asarray(conn_data),
                      bm.asarray(delay_data),
-                     gE=gE, gI=gI)
+                     gEE=gEE, gEI=gEI, gIE=gIE, gII=gII)
     f1 = lambda tdi: bm.concatenate([area.E.spike for area in circuit.areas])
     f2 = lambda tdi: bm.concatenate([area.I.spike for area in circuit.areas])
     runner = bp.DSRunner(
-      circuit,
-      fun_monitors={'exc.spike': f1, 'inh.spike': f2},
-      inputs=[circuit.areas[0].E.input, I, 'iter'],
-      numpy_mon_after_run=False
+        circuit,
+        fun_monitors={'exc.spike': f1, 'inh.spike': f2},
+        inputs=[circuit.areas[0].E.input, inputs, 'iter'],
+        numpy_mon_after_run=False
     )
     runner.run(duration)
-    runner.mon.pop('var_names')
-    return runner.mon
-
</code_context>

<issue_to_address>
The vmap_search function's run method uses bm.random.seed(123) inside a vmap, which may not provide independent randomness.

Seeding inside a vmap can cause all mapped runs to generate identical random sequences. To ensure independent randomness, pass a unique seed or random key to each instance.
</issue_to_address>

### Comment 4
<location> `brainpy/_src/analysis/utils/others.py:119` </location>
<code_context>
         return candidates, keep_ids
     if num_fps <= 1:
         return candidates, keep_ids
-    candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))
+    candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))

     # If point A and point B are within identical_tol of each other, and the
</code_context>

<issue_to_address>
tree_map replaced with jax.tree.map; check for compatibility with all input types.

Ensure jax.tree.map correctly processes any custom containers or objects in candidates, particularly regarding the is_leaf parameter.

Suggested implementation:

```python
    # Define a robust is_leaf function to handle bm.BaseArray and any other custom containers
    def _is_leaf(node):
        # Extend this check if you have other custom array-like types
        return isinstance(node, bm.BaseArray) or isinstance(node, np.ndarray) or isinstance(node, jnp.ndarray)
    candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=_is_leaf)

```

```python
    if keep_ids.shape[0] > 0:
        unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates, is_leaf=_is_leaf)
    else:
        unique_fps = np.array([], dtype=dtype)
    return unique_fps, keep_ids

```
</issue_to_address>

### Comment 5
<location> `brainpy/_src/analysis/utils/model.py:133` </location>
<code_context>
         self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False)

     def update(self):
-        all_vars = list(self.implicit_vars.values())
+        all_vars = [v.value for v in self.implicit_vars.values()]
         for key, intg in self.integrals.items():
             self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))
</code_context>

<issue_to_address>
Changed to use .value for all_vars; this may break if implicit_vars contains non-Variable objects.

Add a type check or fallback to handle cases where items in implicit_vars lack a .value attribute to prevent runtime errors.
</issue_to_address>

### Comment 6
<location> `brainpy/_src/train/back_propagation.py:234` </location>
<code_context>

         """
+
+        share.save(fit=True)
         if shared_args is None:
             shared_args = dict()
</code_context>

<issue_to_address>
Calling share.save(fit=True) at the start of fit may have side effects.

If share.save is not idempotent, repeated or concurrent calls to fit could cause issues. Please review its behavior.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

import brainpy.math as bm

bm.set_environment(bm.batching_mode)
bp.share.save(fit=True)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Calling bp.share.save(fit=True) at the module level may have unintended side effects.

Consider moving this call inside a function or under a main guard to prevent it from running on import.

return bp.Sequential(*layers)

def update(self, s, x, is_feat=False, preact=False):
out = bm.relu(self.bn1(self.conv1(x)))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: The update method signature for ResNet and Bottleneck is inconsistent with BasicBlock.

The extra 's' argument in ResNet and Bottleneck's update methods, but not in BasicBlock, may cause confusion or errors when using these blocks interchangeably. Please align the method signatures.

Suggested implementation:

  def update(self, x, is_feat=False, preact=False):
    out = bm.relu(self.bn1(self.conv1(x)))
    f0 = out
class Bottleneck(nn.Module):
    # ...
    def update(self, x, is_feat=False, preact=False):
        # implementation
        pass
class ResNet(nn.Module):
    # ...
    def update(self, x, is_feat=False, preact=False):
        # implementation
        pass
  • If there are any calls to update within these classes or elsewhere in the file that pass the s argument, remove the s argument from those calls as well.
  • If the s argument was used for state or step information, consider how that information should be handled or refactored (e.g., via class attributes or other means).

Comment on lines +167 to +176
circuit = System(bm.asarray(conn_data) * gc,
bm.asarray(delay_data),
gE=gE, gI=gI)
f1 = lambda tdi: bm.concatenate([area.E.spike for area in circuit.areas])
f2 = lambda tdi: bm.concatenate([area.I.spike for area in circuit.areas])
runner = bp.DSRunner(
circuit,
fun_monitors={'exc.spike': f1, 'inh.spike': f2},
inputs=[circuit.areas[0].E.input, I, 'iter'],
numpy_mon_after_run=False

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): The vmap_search function's run method uses bm.random.seed(123) inside a vmap, which may not provide independent randomness.

Seeding inside a vmap can cause all mapped runs to generate identical random sequences. To ensure independent randomness, pass a unique seed or random key to each instance.

@@ -116,7 +115,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]],
return candidates, keep_ids
if num_fps <= 1:
return candidates, keep_ids

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: tree_map replaced with jax.tree.map; check for compatibility with all input types.

Ensure jax.tree.map correctly processes any custom containers or objects in candidates, particularly regarding the is_leaf parameter.

Suggested implementation:

    # Define a robust is_leaf function to handle bm.BaseArray and any other custom containers
    def _is_leaf(node):
        # Extend this check if you have other custom array-like types
        return isinstance(node, bm.BaseArray) or isinstance(node, np.ndarray) or isinstance(node, jnp.ndarray)
    candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=_is_leaf)
    if keep_ids.shape[0] > 0:
        unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates, is_leaf=_is_leaf)
    else:
        unique_fps = np.array([], dtype=dtype)
    return unique_fps, keep_ids

self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False)

def update(self):
all_vars = list(self.implicit_vars.values())

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Changed to use .value for all_vars; this may break if implicit_vars contains non-Variable objects.

Add a type check or fallback to handle cases where items in implicit_vars lack a .value attribute to prevent runtime errors.

plt.plot(ts, targets + 2 * np.arange(0, targets.shape[1]), 'g')
plt.plot(ts, outs + 2 * np.arange(0, outs.shape[1]), 'r')
plt.xlim((0, duration))
plt.title('Target (green), Output (red)')

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): Extract duplicate code into function (extract-duplicate-method)

evals, _ = np.linalg.eig(bm.as_numpy(net.w_rr))
plt.subplot(223)
plt.plot(np.real(evals), np.imag(evals), 'o')
plt.plot(x_circ, y_circ, 'k')

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): Extract duplicate code into function (extract-duplicate-method)

Comment on lines +47 to +50
if self.is_last:
return out, preact
else:
return out

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): Replace if statement with if expression (assign-if-exp)

Suggested change
if self.is_last:
return out, preact
else:
return out
return (out, preact) if self.is_last else out

Comment on lines +85 to +88
if self.is_last:
return out, preact
else:
return out

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): Replace if statement with if expression (assign-if-exp)

Suggested change
if self.is_last:
return out, preact
else:
return out
return (out, preact) if self.is_last else out

Comment on lines +90 to +93
if self.fire:
return spike
else:
return self.v.value

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): Replace if statement with if expression (assign-if-exp)

Suggested change
if self.fire:
return spike
else:
return self.v.value
return spike if self.fire else self.v.value

Routhleck and others added 4 commits August 13, 2025 14:04
Updated `analysis.ipynb` and `simulation.ipynb` to reflect BrainPy version 3.0.0. Added new cell and output metadata, and included warning output for missing IProgress in Jupyter. Minor cell metadata and output changes improve notebook compatibility and documentation accuracy.
@chaoming0625 chaoming0625 merged commit 6c6a79e into master Aug 13, 2025
49 of 50 checks passed
@chaoming0625 chaoming0625 deleted the fix branch August 13, 2025 07:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants