33import functools
44from typing import Union , Optional , Dict , Sequence
55
6+ import jax
67import jax .numpy as jnp
7- from jax .tree_util import tree_flatten , tree_unflatten , tree_map
88
99from brainpy import tools , math as bm
1010from brainpy ._src .context import share
@@ -159,7 +159,7 @@ def __init__(
159159 self .no_state = no_state
160160 self .out_vars = out_vars
161161 if out_vars is not None :
162- out_vars , _ = tree_flatten (out_vars , is_leaf = lambda s : isinstance (s , bm .Variable ))
162+ out_vars , _ = jax . tree . flatten (out_vars , is_leaf = lambda s : isinstance (s , bm .Variable ))
163163 for v in out_vars :
164164 if not isinstance (v , bm .Variable ):
165165 raise TypeError ('out_vars must be a PyTree of Variable.' )
@@ -198,7 +198,7 @@ def __call__(
198198 'Input should be a Array PyTree with the shape '
199199 'of (B, T, ...) or (T, B, ...) with `data_first_axis="T"`, '
200200 'where B the batch size and T the time length.' )
201- xs , tree = tree_flatten (duration_or_xs , lambda a : isinstance (a , bm .BaseArray ))
201+ xs , tree = jax . tree . flatten (duration_or_xs , lambda a : isinstance (a , bm .BaseArray ))
202202 if self .target .mode .is_child_of (bm .BatchingMode ):
203203 b_idx , t_idx = (1 , 0 ) if self .data_first_axis == 'T' else (0 , 1 )
204204
@@ -209,26 +209,26 @@ def __call__(
209209 if len (batch ) != 1 :
210210 raise ValueError ('\n '
211211 'Input should be a Array PyTree with the same batch dimension. '
212- f'but we got { tree_unflatten (tree , batch )} .' )
212+ f'but we got { jax . tree . unflatten (tree , batch )} .' )
213213 try :
214214 length = tuple (set ([x .shape [t_idx ] for x in xs ]))
215215 except (AttributeError , IndexError ) as e :
216216 raise ValueError (inp_err_msg ) from e
217217 if len (batch ) != 1 :
218218 raise ValueError ('\n '
219219 'Input should be a Array PyTree with the same batch size. '
220- f'but we got { tree_unflatten (tree , batch )} .' )
220+ f'but we got { jax . tree . unflatten (tree , batch )} .' )
221221 if len (length ) != 1 :
222222 raise ValueError ('\n '
223223 'Input should be a Array PyTree with the same time length. '
224- f'but we got { tree_unflatten (tree , length )} .' )
224+ f'but we got { jax . tree . unflatten (tree , length )} .' )
225225
226226 if self .no_state :
227227 xs = [bm .reshape (x , (length [0 ] * batch [0 ],) + x .shape [2 :]) for x in xs ]
228228 else :
229229 if self .data_first_axis == 'B' :
230230 xs = [jnp .moveaxis (x , 0 , 1 ) for x in xs ]
231- xs = tree_unflatten (tree , xs )
231+ xs = jax . tree . unflatten (tree , xs )
232232 origin_shape = (length [0 ], batch [0 ]) if self .data_first_axis == 'T' else (batch [0 ], length [0 ])
233233
234234 else :
@@ -240,15 +240,15 @@ def __call__(
240240 if len (length ) != 1 :
241241 raise ValueError ('\n '
242242 'Input should be a Array PyTree with the same time length. '
243- f'but we got { tree_unflatten (tree , length )} .' )
244- xs = tree_unflatten (tree , xs )
243+ f'but we got { jax . tree . unflatten (tree , length )} .' )
244+ xs = jax . tree . unflatten (tree , xs )
245245 origin_shape = (length [0 ],)
246246
247247 # computation
248248 if self .no_state :
249249 share .save (** self .shared_arg )
250250 outputs = self ._run (self .shared_arg , dict (), xs )
251- results = tree_map (lambda a : jnp .reshape (a , origin_shape + a .shape [1 :]), outputs )
251+ results = jax . tree . map (lambda a : jnp .reshape (a , origin_shape + a .shape [1 :]), outputs )
252252 if self .i0 is not None :
253253 self .i0 += length [0 ]
254254 if self .t0 is not None :
@@ -263,6 +263,7 @@ def __call__(
263263 shared ['i' ] = jnp .arange (0 , length [0 ]) + self .i0 .value
264264
265265 assert not self .no_state
266+ xs = jax .tree .map (lambda x : x .value if isinstance (x , bm .Variable ) else x , xs , is_leaf = lambda x : isinstance (x , bm .Variable ))
266267 results = bm .for_loop (functools .partial (self ._run , self .shared_arg ),
267268 (shared , xs ),
268269 jit = self .jit ,
@@ -283,6 +284,6 @@ def _run(self, static_sh, dyn_sh, x):
283284 share .save (** static_sh , ** dyn_sh )
284285 outs = self .target (x )
285286 if self .out_vars is not None :
286- outs = (outs , tree_map (bm .as_jax , self .out_vars ))
287+ outs = (outs , jax . tree . map (bm .as_jax , self .out_vars , is_leaf = lambda x : isinstance ( x , bm . Variable ) ))
287288 clear_input (self .target )
288289 return outs
0 commit comments