Skip to content

Commit 0424b1d

Browse files
committed
Fix STDP weight update logic and improve dataset loading paths
1 parent d35417e commit 0424b1d

9 files changed

Lines changed: 80 additions & 114 deletions

File tree

brainpy/_src/dnn/linear.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def stdp_update(
217217
if on_post is not None:
218218
spike = on_post['spike']
219219
trace = on_post['trace']
220-
self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max)
220+
self.W.value = dense_on_post(self.W.value, trace, spike, w_min, w_max)
221221

222222

223223
Linear = Dense
@@ -303,18 +303,16 @@ def stdp_update(
303303
w_min: numbers.Number = None,
304304
w_max: numbers.Number = None
305305
):
306-
if isinstance(self.weight, float):
307-
raise ValueError(f'Cannot update the weight of a constant node.')
308306
if not isinstance(self.weight, bm.Variable):
309-
self.tracing_variable('weight', self.weight, self.weight.shape)
307+
raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')
310308
if on_pre is not None:
311309
spike = on_pre['spike']
312310
trace = on_pre['trace']
313311
self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
314312
if on_post is not None:
315313
spike = on_post['spike']
316314
trace = on_post['trace']
317-
self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
315+
self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max)
318316

319317

320318
class OneToOne(Layer, SupportSTDP):
@@ -442,7 +440,7 @@ def stdp_update(
442440
if on_post is not None:
443441
spike = on_post['spike']
444442
trace = on_post['trace']
445-
self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
443+
self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max)
446444

447445

448446
class _CSRLayer(Layer, SupportSTDP):
@@ -503,7 +501,7 @@ def stdp_update(
503501
trace = on_post['trace']
504502
self.weight.value = csr2csc_on_post(
505503
self.weight.value, self._pre_ids, self._post_indptr,
506-
self.w_indices, spike, trace, w_min, w_max,
504+
self.w_indices, trace, spike, w_min, w_max,
507505
shape=(trace.shape[0], spike.shape[0]),
508506
)
509507

brainpy/_src/math/object_transform/variables.py

Lines changed: 57 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -349,63 +349,63 @@ def value(self, v):
349349
self._been_writen = True # set the flag
350350
self._write_value(v) # write the value
351351

352-
def tree_flatten(self):
353-
"""Flattens this variable.
354-
355-
Returns:
356-
A pair where the first element is a list of leaf values
357-
and the second element is a treedef representing the
358-
structure of the flattened tree.
359-
"""
360-
return (self._value,), None
361-
362-
@classmethod
363-
def tree_unflatten(cls, aux_data, flat_contents):
364-
"""Reconstructs a variable from the aux_data and the leaves.
365-
366-
Args:
367-
aux_data:
368-
flat_contents:
369-
370-
Returns:
371-
The variable.
372-
"""
373-
return cls(*flat_contents)
374-
375-
def clone(self) -> 'Variable':
376-
"""Clone the variable. """
377-
r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis)
378-
return r
379-
380-
def __eq__(self, other):
381-
"""Override State's __eq__ to use BaseArray behavior for element-wise comparison."""
382-
from brainpy._src.math.ndarray import _check_input_array, _return
383-
return _return(self.value == _check_input_array(other))
384-
385-
def __ne__(self, other):
386-
"""Override State's __ne__ to use BaseArray behavior for element-wise comparison."""
387-
from brainpy._src.math.ndarray import _check_input_array, _return
388-
return _return(self.value != _check_input_array(other))
389-
390-
def __lt__(self, other):
391-
"""Override State's __lt__ to use BaseArray behavior for element-wise comparison."""
392-
from brainpy._src.math.ndarray import _check_input_array, _return
393-
return _return(self.value < _check_input_array(other))
394-
395-
def __le__(self, other):
396-
"""Override State's __le__ to use BaseArray behavior for element-wise comparison."""
397-
from brainpy._src.math.ndarray import _check_input_array, _return
398-
return _return(self.value <= _check_input_array(other))
399-
400-
def __gt__(self, other):
401-
"""Override State's __gt__ to use BaseArray behavior for element-wise comparison."""
402-
from brainpy._src.math.ndarray import _check_input_array, _return
403-
return _return(self.value > _check_input_array(other))
404-
405-
def __ge__(self, other):
406-
"""Override State's __ge__ to use BaseArray behavior for element-wise comparison."""
407-
from brainpy._src.math.ndarray import _check_input_array, _return
408-
return _return(self.value >= _check_input_array(other))
352+
# def tree_flatten(self):
353+
# """Flattens this variable.
354+
#
355+
# Returns:
356+
# A pair where the first element is a list of leaf values
357+
# and the second element is a treedef representing the
358+
# structure of the flattened tree.
359+
# """
360+
# return (self._value,), None
361+
#
362+
# @classmethod
363+
# def tree_unflatten(cls, aux_data, flat_contents):
364+
# """Reconstructs a variable from the aux_data and the leaves.
365+
#
366+
# Args:
367+
# aux_data:
368+
# flat_contents:
369+
#
370+
# Returns:
371+
# The variable.
372+
# """
373+
# return cls(*flat_contents)
374+
375+
# def clone(self) -> 'Variable':
376+
# """Clone the variable. """
377+
# r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis)
378+
# return r
379+
380+
# def __eq__(self, other):
381+
# """Override State's __eq__ to use BaseArray behavior for element-wise comparison."""
382+
# from brainpy._src.math.ndarray import _check_input_array, _return
383+
# return _return(self.value == _check_input_array(other))
384+
#
385+
# def __ne__(self, other):
386+
# """Override State's __ne__ to use BaseArray behavior for element-wise comparison."""
387+
# from brainpy._src.math.ndarray import _check_input_array, _return
388+
# return _return(self.value != _check_input_array(other))
389+
#
390+
# def __lt__(self, other):
391+
# """Override State's __lt__ to use BaseArray behavior for element-wise comparison."""
392+
# from brainpy._src.math.ndarray import _check_input_array, _return
393+
# return _return(self.value < _check_input_array(other))
394+
#
395+
# def __le__(self, other):
396+
# """Override State's __le__ to use BaseArray behavior for element-wise comparison."""
397+
# from brainpy._src.math.ndarray import _check_input_array, _return
398+
# return _return(self.value <= _check_input_array(other))
399+
#
400+
# def __gt__(self, other):
401+
# """Override State's __gt__ to use BaseArray behavior for element-wise comparison."""
402+
# from brainpy._src.math.ndarray import _check_input_array, _return
403+
# return _return(self.value > _check_input_array(other))
404+
#
405+
# def __ge__(self, other):
406+
# """Override State's __ge__ to use BaseArray behavior for element-wise comparison."""
407+
# from brainpy._src.math.ndarray import _check_input_array, _return
408+
# return _return(self.value >= _check_input_array(other))
409409

410410

411411
def _get_dtype(v):

brainpy/_src/train/back_propagation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def fit(
230230
Please set batch size in your dataset.
231231
232232
"""
233+
234+
share.save(fit=True)
233235
if shared_args is None:
234236
shared_args = dict()
235237
shared_args['fit'] = shared_args.get('fit', True)

examples/dynamics_simulation/stdp.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,20 @@ def __init__(self, num_poisson, num_lif=1, g_max=0.01):
2020

2121
# neuron groups
2222
self.noise = bp.dyn.PoissonGroup(num_poisson, freqs=15.)
23-
self.group = bp.dyn.Lif(num_lif, V_reset=-60., V_rest=-74, V_th=-54, tau=10.,
24-
V_initializer=bp.init.Normal(-60., 1.))
23+
self.group = bp.dyn.Lif(
24+
num_lif, V_reset=-60., V_rest=-74, V_th=-54, tau=10.,
25+
V_initializer=bp.init.Normal(-60., 1.)
26+
)
2527

2628
# synapses
2729
syn = bp.dyn.Expon.desc(num_lif, tau=5.)
2830
out = bp.dyn.COBA.desc(E=0.)
29-
comm = bp.dnn.AllToAll(num_poisson, num_lif, bp.init.Uniform(0., g_max))
30-
self.syn = bp.dyn.STDP_Song2000(self.noise, None, syn, comm, out, self.group,
31-
tau_s=20, tau_t=20, W_max=g_max, W_min=0.,
32-
A1=0.01 * g_max, A2=0.0105 * g_max)
31+
comm = bp.dnn.AllToAll(num_poisson, num_lif, bp.init.Uniform(0., g_max), mode=bm.TrainingMode())
32+
self.syn = bp.dyn.STDP_Song2000(
33+
self.noise, None, syn, comm, out, self.group,
34+
tau_s=20, tau_t=20, W_max=g_max, W_min=0.,
35+
A1=0.01 * g_max, A2=0.0105 * g_max
36+
)
3337

3438
def update(self, *args, **kwargs):
3539
self.noise()

examples/dynamics_training/echo_state_network.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import brainpy.math as bm
55

66
bm.set_environment(bm.batching_mode)
7+
bp.share.save(fit=True)
78

89

910
class ESN(bp.DynamicalSystem):

examples/dynamics_training/reservoir-mnist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import brainpy as bp
99
import brainpy.math as bm
1010

11-
traindata = bd.vision.MNIST(root='D:/data', split='train')
12-
testdata = bd.vision.MNIST(root='D:/data', split='test')
11+
traindata = bd.vision.MNIST(root='./data', split='train', download=True)
12+
testdata = bd.vision.MNIST(root='./data', split='test', download=True)
1313

1414

1515
def offline_train(num_hidden=2000, num_in=28, num_out=10):
@@ -18,7 +18,7 @@ def offline_train(num_hidden=2000, num_in=28, num_out=10):
1818
x_train = x_train.reshape(-1, x_train.shape[-1])
1919
y_train = bm.one_hot(jnp.repeat(traindata.targets, x_train.shape[1]), 10, dtype=bm.float_)
2020

21-
reservoir = bp.layers.Reservoir(
21+
reservoir = bp.dyn.Reservoir(
2222
num_in,
2323
num_hidden,
2424
Win_initializer=bp.init.Uniform(-0.6, 0.6),

examples/operator_customization/event_ell.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ numpy
22
brainstate>=0.1.6
33
brainunit
44
brainevent>=0.0.4
5-
braintools>=0.0.7
5+
braintools>=0.0.9
66
jax
77
tqdm

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
packages=packages,
3434
python_requires='>=3.10',
3535
install_requires=[
36-
'numpy>=1.15', 'jax', 'tqdm', 'brainstate>=0.1.6', 'brainunit', 'brainevent>=0.0.4', 'braintools>=0.0.7'
36+
'numpy>=1.15', 'jax', 'tqdm', 'brainstate>=0.1.6', 'brainunit', 'brainevent>=0.0.4', 'braintools>=0.0.9'
3737
],
3838
url='https://github.com/brainpy/BrainPy',
3939
project_urls={

0 commit comments

Comments
 (0)