Skip to content

Commit 6c6a79e

Browse files
Refactor and fix for JAX compatibility, STDP, and training workflows (#772)
* Replace deprecated jax.tree_util functions with jax.tree functions for improved compatibility * Refactor code for improved readability and consistency across multiple files * Fix STDP weight update logic and improve dataset loading paths * Update quickstart notebooks for BrainPy 3.0.0 Updated `analysis.ipynb` and `simulation.ipynb` to reflect BrainPy version 3.0.0. Added new cell and output metadata, and included warning output for missing IProgress in Jupyter. Minor cell metadata and output changes improve notebook compatibility and documentation accuracy. * Fix test_controls.py to access var_a.value for correct assertions --------- Co-authored-by: routhleck <sichaohe@stu.pku.edu.cn>
1 parent ca84aa1 commit 6c6a79e

42 files changed

Lines changed: 3232 additions & 3296 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import jax.numpy as jnp
1111
import numpy as np
1212
from jax.scipy.optimize import minimize
13-
from jax.tree_util import tree_flatten, tree_map
1413

1514
import brainpy._src.math as bm
1615
from brainpy import optim, losses
@@ -265,7 +264,7 @@ def opt_losses(self, val):
265264
@property
266265
def fixed_points(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
267266
"""The final fixed points found."""
268-
return tree_map(lambda a: np.asarray(a), self._fixed_points)
267+
return jax.tree.map(lambda a: np.asarray(a), self._fixed_points)
269268

270269
@fixed_points.setter
271270
def fixed_points(self, val):
@@ -339,11 +338,11 @@ def find_fps_with_gd_method(
339338
num_candidate = self._check_candidates(candidates)
340339
if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)):
341340
raise ValueError('Candidates must be instance of ArrayType or dict of ArrayType.')
342-
fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.BaseArray))
341+
fixed_points = jax.tree.map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.BaseArray))
343342
f_eval_loss = self._get_f_eval_loss()
344343

345344
def f_loss():
346-
return f_eval_loss(tree_map(lambda a: bm.as_jax(a),
345+
return f_eval_loss(jax.tree.map(lambda a: bm.as_jax(a),
347346
fixed_points,
348347
is_leaf=lambda x: isinstance(x, bm.BaseArray))).mean()
349348

@@ -387,10 +386,10 @@ def batch_train(start_i, n_batch):
387386
f'is below tolerance {tolerance:0.10f}.')
388387

389388
self._opt_losses = jnp.concatenate(opt_losses)
390-
self._losses = f_eval_loss(tree_map(lambda a: bm.as_jax(a),
389+
self._losses = f_eval_loss(jax.tree.map(lambda a: bm.as_jax(a),
391390
fixed_points,
392391
is_leaf=lambda x: isinstance(x, bm.BaseArray)))
393-
self._fixed_points = tree_map(lambda a: bm.as_jax(a),
392+
self._fixed_points = jax.tree.map(lambda a: bm.as_jax(a),
394393
fixed_points,
395394
is_leaf=lambda x: isinstance(x, bm.BaseArray))
396395
self._selected_ids = jnp.arange(num_candidate)
@@ -429,7 +428,7 @@ def find_fps_with_opt_solver(
429428
print(f"Optimizing with {opt_solver} to find fixed points:")
430429

431430
# optimizing
432-
res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)))
431+
res = f_opt(jax.tree.map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)))
433432

434433
# results
435434
valid_ids = jnp.where(res.success)[0]
@@ -467,7 +466,7 @@ def filter_loss(self, tolerance: float = 1e-5):
467466
num_fps = self._fixed_points.shape[0]
468467
ids = self._losses < tolerance
469468
keep_ids = bm.as_jax(bm.where(ids)[0])
470-
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
469+
self._fixed_points = jax.tree.map(lambda a: a[keep_ids], self._fixed_points)
471470
self._losses = self._losses[keep_ids]
472471
self._selected_ids = self._selected_ids[keep_ids]
473472
if self.verbose:
@@ -490,7 +489,7 @@ def keep_unique(self, tolerance: float = 2.5e-2):
490489
else:
491490
num_fps = self._fixed_points.shape[0]
492491
fps, keep_ids = utils.keep_unique(self.fixed_points, tolerance=tolerance)
493-
self._fixed_points = tree_map(lambda a: jnp.asarray(a), fps)
492+
self._fixed_points = jax.tree.map(lambda a: jnp.asarray(a), fps)
494493
self._losses = self._losses[keep_ids]
495494
self._selected_ids = self._selected_ids[keep_ids]
496495
if self.verbose:
@@ -525,7 +524,7 @@ def exclude_outliers(self, tolerance: float = 1e0):
525524

526525
# Return data with outliers removed and indices of kept datapoints.
527526
keep_ids = np.where(closest_neighbor < tolerance)[0]
528-
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
527+
self._fixed_points = jax.tree.map(lambda a: a[keep_ids], self._fixed_points)
529528
self._selected_ids = self._selected_ids[keep_ids]
530529
self._losses = self._losses[keep_ids]
531530

@@ -562,11 +561,11 @@ def compute_jacobians(
562561
"""
563562
# check data
564563
info = np.asarray([(l.ndim, l.shape[0])
565-
for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.BaseArray))[0]])
564+
for l in jax.tree.flatten(points, is_leaf=lambda a: isinstance(a, bm.BaseArray))[0]])
566565
ndim = np.unique(info[:, 0])
567566
if len(ndim) != 1: raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}')
568567
if ndim[0] == 1:
569-
points = tree_map(lambda a: bm.asarray([a]), points)
568+
points = jax.tree.map(lambda a: bm.asarray([a]), points)
570569
num_point = 1
571570
elif ndim[0] == 2:
572571
nsize = np.unique(info[:, 1])

brainpy/_src/analysis/utils/measurement.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import jax
77
import jax.numpy as jnp
88
import numpy as np
9-
from jax.tree_util import tree_flatten
109

1110
import brainpy._src.math as bm
1211
from brainpy.tools import numba_jit
@@ -112,7 +111,7 @@ def euclidean_distance_jax(points: Union[jnp.ndarray, bm.ndarray], num_point=Non
112111
num_point = points.shape[0]
113112
indices = jnp.triu_indices(num_point)
114113
dist_mat = bm.zeros((num_point, num_point))
115-
leaves, _ = tree_flatten(points)
114+
leaves, _ = jax.tree.flatten(points)
116115
dist_mat[indices] = _ed(*indices, leaves)
117116
dist_mat = jnp.maximum(dist_mat.value, dist_mat.value.T)
118117
return dist_mat

brainpy/_src/analysis/utils/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
130130
self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False)
131131

132132
def update(self):
133-
all_vars = list(self.implicit_vars.values())
133+
all_vars = [v.value for v in self.implicit_vars.values()]
134134
for key, intg in self.integrals.items():
135135
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))
136136

brainpy/_src/analysis/utils/others.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import jax
66
import jax.numpy as jnp
77
import numpy as np
8-
from jax.tree_util import tree_map
98

109
import brainpy.math as bm
1110
from .function import f_without_jaxarray_return
@@ -116,7 +115,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]],
116115
return candidates, keep_ids
117116
if num_fps <= 1:
118117
return candidates, keep_ids
119-
candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))
118+
candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))
120119

121120
# If point A and point B are within identical_tol of each other, and the
122121
# A is first in the list, we keep A.
@@ -129,7 +128,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]],
129128
all_drop_idxs += list(drop_idxs)
130129
keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs))
131130
if keep_ids.shape[0] > 0:
132-
unique_fps = tree_map(lambda a: a[keep_ids], candidates)
131+
unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates)
133132
else:
134133
unique_fps = np.array([], dtype=dtype)
135134
return unique_fps, keep_ids

brainpy/_src/dnn/linear.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def stdp_update(
217217
if on_post is not None:
218218
spike = on_post['spike']
219219
trace = on_post['trace']
220-
self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max)
220+
self.W.value = dense_on_post(self.W.value, trace, spike, w_min, w_max)
221221

222222

223223
Linear = Dense
@@ -303,18 +303,16 @@ def stdp_update(
303303
w_min: numbers.Number = None,
304304
w_max: numbers.Number = None
305305
):
306-
if isinstance(self.weight, float):
307-
raise ValueError(f'Cannot update the weight of a constant node.')
308306
if not isinstance(self.weight, bm.Variable):
309-
self.tracing_variable('weight', self.weight, self.weight.shape)
307+
raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')
310308
if on_pre is not None:
311309
spike = on_pre['spike']
312310
trace = on_pre['trace']
313311
self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
314312
if on_post is not None:
315313
spike = on_post['spike']
316314
trace = on_post['trace']
317-
self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
315+
self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max)
318316

319317

320318
class OneToOne(Layer, SupportSTDP):
@@ -442,7 +440,7 @@ def stdp_update(
442440
if on_post is not None:
443441
spike = on_post['spike']
444442
trace = on_post['trace']
445-
self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
443+
self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max)
446444

447445

448446
class _CSRLayer(Layer, SupportSTDP):
@@ -503,7 +501,7 @@ def stdp_update(
503501
trace = on_post['trace']
504502
self.weight.value = csr2csc_on_post(
505503
self.weight.value, self._pre_ids, self._post_indptr,
506-
self.w_indices, spike, trace, w_min, w_max,
504+
self.w_indices, trace, spike, w_min, w_max,
507505
shape=(trace.shape[0], spike.shape[0]),
508506
)
509507

brainpy/_src/integrators/ode/exponential.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,10 @@ def integral_func(*args, **kwargs):
330330
for i, parse in enumerate(parses):
331331
f_integral, vars_, pars_ = parse
332332
vps = vars_ + pars_ + [C.DT]
333-
r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in})
333+
r = f_integral(
334+
_as_value(params_in[vps[0]]),
335+
**{arg: _as_value(params_in[arg]) for arg in vps[1:] if arg in params_in}
336+
)
334337
results.append(r)
335338
return results if len(self.variables) > 1 else results[0]
336339

@@ -370,3 +373,10 @@ def integral(*args, **kwargs):
370373
register_ode_integrator('exp_euler', ExponentialEuler)
371374
register_ode_integrator('exp_euler_auto', ExponentialEuler)
372375
register_ode_integrator('exp_auto', ExponentialEuler)
376+
377+
378+
def _as_value(x):
379+
if isinstance(x, bm.Variable):
380+
return x.value
381+
else:
382+
return x

brainpy/_src/math/object_transform/tests/test_controls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def f1():
169169
lambda: 4, lambda: 5])
170170

171171
self.assertTrue(f(11) == 1)
172-
print(var_a)
173-
self.assertTrue(bm.all(var_a == 1))
172+
print(var_a.value)
173+
self.assertTrue(bm.all(var_a.value == 1))
174174
self.assertTrue(f(1) == 4)
175175
self.assertTrue(f(-1) == 5)
176176

brainpy/_src/math/object_transform/variables.py

Lines changed: 57 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -349,63 +349,63 @@ def value(self, v):
349349
self._been_writen = True # set the flag
350350
self._write_value(v) # write the value
351351

352-
def tree_flatten(self):
353-
"""Flattens this variable.
354-
355-
Returns:
356-
A pair where the first element is a list of leaf values
357-
and the second element is a treedef representing the
358-
structure of the flattened tree.
359-
"""
360-
return (self._value,), None
361-
362-
@classmethod
363-
def tree_unflatten(cls, aux_data, flat_contents):
364-
"""Reconstructs a variable from the aux_data and the leaves.
365-
366-
Args:
367-
aux_data:
368-
flat_contents:
369-
370-
Returns:
371-
The variable.
372-
"""
373-
return cls(*flat_contents)
374-
375-
def clone(self) -> 'Variable':
376-
"""Clone the variable. """
377-
r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis)
378-
return r
379-
380-
def __eq__(self, other):
381-
"""Override State's __eq__ to use BaseArray behavior for element-wise comparison."""
382-
from brainpy._src.math.ndarray import _check_input_array, _return
383-
return _return(self.value == _check_input_array(other))
384-
385-
def __ne__(self, other):
386-
"""Override State's __ne__ to use BaseArray behavior for element-wise comparison."""
387-
from brainpy._src.math.ndarray import _check_input_array, _return
388-
return _return(self.value != _check_input_array(other))
389-
390-
def __lt__(self, other):
391-
"""Override State's __lt__ to use BaseArray behavior for element-wise comparison."""
392-
from brainpy._src.math.ndarray import _check_input_array, _return
393-
return _return(self.value < _check_input_array(other))
394-
395-
def __le__(self, other):
396-
"""Override State's __le__ to use BaseArray behavior for element-wise comparison."""
397-
from brainpy._src.math.ndarray import _check_input_array, _return
398-
return _return(self.value <= _check_input_array(other))
399-
400-
def __gt__(self, other):
401-
"""Override State's __gt__ to use BaseArray behavior for element-wise comparison."""
402-
from brainpy._src.math.ndarray import _check_input_array, _return
403-
return _return(self.value > _check_input_array(other))
404-
405-
def __ge__(self, other):
406-
"""Override State's __ge__ to use BaseArray behavior for element-wise comparison."""
407-
from brainpy._src.math.ndarray import _check_input_array, _return
408-
return _return(self.value >= _check_input_array(other))
352+
# def tree_flatten(self):
353+
# """Flattens this variable.
354+
#
355+
# Returns:
356+
# A pair where the first element is a list of leaf values
357+
# and the second element is a treedef representing the
358+
# structure of the flattened tree.
359+
# """
360+
# return (self._value,), None
361+
#
362+
# @classmethod
363+
# def tree_unflatten(cls, aux_data, flat_contents):
364+
# """Reconstructs a variable from the aux_data and the leaves.
365+
#
366+
# Args:
367+
# aux_data:
368+
# flat_contents:
369+
#
370+
# Returns:
371+
# The variable.
372+
# """
373+
# return cls(*flat_contents)
374+
375+
# def clone(self) -> 'Variable':
376+
# """Clone the variable. """
377+
# r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis)
378+
# return r
379+
380+
# def __eq__(self, other):
381+
# """Override State's __eq__ to use BaseArray behavior for element-wise comparison."""
382+
# from brainpy._src.math.ndarray import _check_input_array, _return
383+
# return _return(self.value == _check_input_array(other))
384+
#
385+
# def __ne__(self, other):
386+
# """Override State's __ne__ to use BaseArray behavior for element-wise comparison."""
387+
# from brainpy._src.math.ndarray import _check_input_array, _return
388+
# return _return(self.value != _check_input_array(other))
389+
#
390+
# def __lt__(self, other):
391+
# """Override State's __lt__ to use BaseArray behavior for element-wise comparison."""
392+
# from brainpy._src.math.ndarray import _check_input_array, _return
393+
# return _return(self.value < _check_input_array(other))
394+
#
395+
# def __le__(self, other):
396+
# """Override State's __le__ to use BaseArray behavior for element-wise comparison."""
397+
# from brainpy._src.math.ndarray import _check_input_array, _return
398+
# return _return(self.value <= _check_input_array(other))
399+
#
400+
# def __gt__(self, other):
401+
# """Override State's __gt__ to use BaseArray behavior for element-wise comparison."""
402+
# from brainpy._src.math.ndarray import _check_input_array, _return
403+
# return _return(self.value > _check_input_array(other))
404+
#
405+
# def __ge__(self, other):
406+
# """Override State's __ge__ to use BaseArray behavior for element-wise comparison."""
407+
# from brainpy._src.math.ndarray import _check_input_array, _return
408+
# return _return(self.value >= _check_input_array(other))
409409

410410

411411
def _get_dtype(v):

brainpy/_src/train/back_propagation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def fit(
230230
Please set batch size in your dataset.
231231
232232
"""
233+
234+
share.save(fit=True)
233235
if shared_args is None:
234236
shared_args = dict()
235237
shared_args['fit'] = shared_args.get('fit', True)

0 commit comments

Comments
 (0)