Skip to content

Commit ac05b2c

Browse files
committed
Fix spike return values and update file paths in simulation scripts
1 parent 6c6a79e commit ac05b2c

9 files changed

Lines changed: 29 additions & 598 deletions

File tree

brainpy/_src/math/object_transform/variables.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple
22

3-
import brainstate
43
import jax
54
import numpy as np
6-
from brainstate._state import record_state_value_read, record_state_value_write
75
from jax import numpy as jnp
86
from jax.dtypes import canonicalize_dtype
97
from jax.tree_util import register_pytree_node_class
108

9+
import brainstate
1110
from brainpy._src.math.ndarray import BaseArray
1211
from brainpy._src.math.sharding import BATCH_AXIS
1312
from brainpy.errors import MathError
13+
from brainstate._state import record_state_value_read, record_state_value_write
1414

1515
__all__ = [
1616
'Variable',
@@ -313,8 +313,12 @@ def batch_size(self) -> Optional[int]:
313313
def batch_size(self, val):
314314
raise ValueError(f'Cannot set "batch_size" manually.')
315315

316+
def _ensure_value_exists(self):
317+
pass
318+
316319
@property
317320
def value(self):
321+
self._ensure_value_exists()
318322
record_state_value_read(self)
319323
return self._read_value()
320324

brainpy/_src/math/random.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -555,16 +555,6 @@ def _ensure_value_exists(self):
555555
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
556556
self._value = seed_or_key
557557

558-
@property
559-
def value(self):
560-
self._ensure_value_exists()
561-
record_state_value_read(self)
562-
return self._read_value()
563-
564-
# def check_if_deleted(self):
565-
# if self.value.is_deleted():
566-
# self.seed()
567-
568558
def split_key(self):
569559
"""Create a new seed from the current seed.
570560
"""

brainpy/_src/transform.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import functools
44
from typing import Union, Optional, Dict, Sequence
55

6+
import jax
67
import jax.numpy as jnp
7-
from jax.tree_util import tree_flatten, tree_unflatten, tree_map
88

99
from brainpy import tools, math as bm
1010
from brainpy._src.context import share
@@ -159,7 +159,7 @@ def __init__(
159159
self.no_state = no_state
160160
self.out_vars = out_vars
161161
if out_vars is not None:
162-
out_vars, _ = tree_flatten(out_vars, is_leaf=lambda s: isinstance(s, bm.Variable))
162+
out_vars, _ = jax.tree.flatten(out_vars, is_leaf=lambda s: isinstance(s, bm.Variable))
163163
for v in out_vars:
164164
if not isinstance(v, bm.Variable):
165165
raise TypeError('out_vars must be a PyTree of Variable.')
@@ -198,7 +198,7 @@ def __call__(
198198
'Input should be a Array PyTree with the shape '
199199
'of (B, T, ...) or (T, B, ...) with `data_first_axis="T"`, '
200200
'where B the batch size and T the time length.')
201-
xs, tree = tree_flatten(duration_or_xs, lambda a: isinstance(a, bm.BaseArray))
201+
xs, tree = jax.tree.flatten(duration_or_xs, lambda a: isinstance(a, bm.BaseArray))
202202
if self.target.mode.is_child_of(bm.BatchingMode):
203203
b_idx, t_idx = (1, 0) if self.data_first_axis == 'T' else (0, 1)
204204

@@ -209,26 +209,26 @@ def __call__(
209209
if len(batch) != 1:
210210
raise ValueError('\n'
211211
'Input should be a Array PyTree with the same batch dimension. '
212-
f'but we got {tree_unflatten(tree, batch)}.')
212+
f'but we got {jax.tree.unflatten(tree, batch)}.')
213213
try:
214214
length = tuple(set([x.shape[t_idx] for x in xs]))
215215
except (AttributeError, IndexError) as e:
216216
raise ValueError(inp_err_msg) from e
217217
if len(batch) != 1:
218218
raise ValueError('\n'
219219
'Input should be a Array PyTree with the same batch size. '
220-
f'but we got {tree_unflatten(tree, batch)}.')
220+
f'but we got {jax.tree.unflatten(tree, batch)}.')
221221
if len(length) != 1:
222222
raise ValueError('\n'
223223
'Input should be a Array PyTree with the same time length. '
224-
f'but we got {tree_unflatten(tree, length)}.')
224+
f'but we got {jax.tree.unflatten(tree, length)}.')
225225

226226
if self.no_state:
227227
xs = [bm.reshape(x, (length[0] * batch[0],) + x.shape[2:]) for x in xs]
228228
else:
229229
if self.data_first_axis == 'B':
230230
xs = [jnp.moveaxis(x, 0, 1) for x in xs]
231-
xs = tree_unflatten(tree, xs)
231+
xs = jax.tree.unflatten(tree, xs)
232232
origin_shape = (length[0], batch[0]) if self.data_first_axis == 'T' else (batch[0], length[0])
233233

234234
else:
@@ -240,15 +240,15 @@ def __call__(
240240
if len(length) != 1:
241241
raise ValueError('\n'
242242
'Input should be a Array PyTree with the same time length. '
243-
f'but we got {tree_unflatten(tree, length)}.')
244-
xs = tree_unflatten(tree, xs)
243+
f'but we got {jax.tree.unflatten(tree, length)}.')
244+
xs = jax.tree.unflatten(tree, xs)
245245
origin_shape = (length[0],)
246246

247247
# computation
248248
if self.no_state:
249249
share.save(**self.shared_arg)
250250
outputs = self._run(self.shared_arg, dict(), xs)
251-
results = tree_map(lambda a: jnp.reshape(a, origin_shape + a.shape[1:]), outputs)
251+
results = jax.tree.map(lambda a: jnp.reshape(a, origin_shape + a.shape[1:]), outputs)
252252
if self.i0 is not None:
253253
self.i0 += length[0]
254254
if self.t0 is not None:
@@ -263,6 +263,7 @@ def __call__(
263263
shared['i'] = jnp.arange(0, length[0]) + self.i0.value
264264

265265
assert not self.no_state
266+
xs = jax.tree.map(lambda x: x.value if isinstance(x, bm.Variable) else x, xs, is_leaf=lambda x: isinstance(x, bm.Variable))
266267
results = bm.for_loop(functools.partial(self._run, self.shared_arg),
267268
(shared, xs),
268269
jit=self.jit,
@@ -283,6 +284,6 @@ def _run(self, static_sh, dyn_sh, x):
283284
share.save(**static_sh, **dyn_sh)
284285
outs = self.target(x)
285286
if self.out_vars is not None:
286-
outs = (outs, tree_map(bm.as_jax, self.out_vars))
287+
outs = (outs, jax.tree.map(bm.as_jax, self.out_vars, is_leaf=lambda x: isinstance(x, bm.Variable)))
287288
clear_input(self.target)
288289
return outs

examples/dynamics_simulation/COBA_parallel.py

Lines changed: 0 additions & 179 deletions
This file was deleted.

examples/dynamics_simulation/Sanda_2021_hippo-tha-cortex-model.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

examples/dynamics_simulation/ei_nets.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def update(self, inp):
7272
self.I2I()
7373
self.E(inp)
7474
self.I(inp)
75-
return self.E.spike
75+
return self.E.spike.value
7676

7777
model = EINet()
7878
indices = bm.arange(1000)
@@ -150,7 +150,7 @@ def update(self, inp):
150150
self.I2I()
151151
self.E(inp)
152152
self.I(inp)
153-
return self.E.spike
153+
return self.E.spike.value
154154

155155
model = EINet()
156156
indices = bm.arange(1000)
@@ -199,7 +199,7 @@ def update(self, inp):
199199
self.I2I()
200200
self.E(inp)
201201
self.I(inp)
202-
return self.E.spike
202+
return self.E.spike.value
203203

204204
model = EINet()
205205
indices = bm.arange(1000)
@@ -241,9 +241,9 @@ def update(self, input):
241241

242242

243243
if __name__ == '__main__':
244-
model1()
245-
# model2()
246-
# model3()
247-
# model4()
248-
# model5()
249-
# vanalla_proj()
244+
# model1()
245+
model2()
246+
model3()
247+
model4()
248+
model5()
249+
vanalla_proj()

0 commit comments

Comments
 (0)