1010import jax .numpy as jnp
1111import numpy as np
1212from jax .scipy .optimize import minimize
13- from jax .tree_util import tree_flatten , tree_map
1413
1514import brainpy ._src .math as bm
1615from brainpy import optim , losses
@@ -265,7 +264,7 @@ def opt_losses(self, val):
265264 @property
266265 def fixed_points (self ) -> Union [np .ndarray , Dict [str , np .ndarray ]]:
267266 """The final fixed points found."""
268- return tree_map (lambda a : np .asarray (a ), self ._fixed_points )
267+ return jax . tree . map (lambda a : np .asarray (a ), self ._fixed_points )
269268
270269 @fixed_points .setter
271270 def fixed_points (self , val ):
@@ -339,11 +338,11 @@ def find_fps_with_gd_method(
339338 num_candidate = self ._check_candidates (candidates )
340339 if not (isinstance (candidates , (bm .ndarray , jnp .ndarray , np .ndarray )) or isinstance (candidates , dict )):
341340 raise ValueError ('Candidates must be instance of ArrayType or dict of ArrayType.' )
342- fixed_points = tree_map (lambda a : bm .TrainVar (a ), candidates , is_leaf = lambda x : isinstance (x , bm .BaseArray ))
341+ fixed_points = jax . tree . map (lambda a : bm .TrainVar (a ), candidates , is_leaf = lambda x : isinstance (x , bm .BaseArray ))
343342 f_eval_loss = self ._get_f_eval_loss ()
344343
345344 def f_loss ():
346- return f_eval_loss (tree_map (lambda a : bm .as_jax (a ),
345+ return f_eval_loss (jax . tree . map (lambda a : bm .as_jax (a ),
347346 fixed_points ,
348347 is_leaf = lambda x : isinstance (x , bm .BaseArray ))).mean ()
349348
@@ -387,10 +386,10 @@ def batch_train(start_i, n_batch):
387386 f'is below tolerance { tolerance :0.10f} .' )
388387
389388 self ._opt_losses = jnp .concatenate (opt_losses )
390- self ._losses = f_eval_loss (tree_map (lambda a : bm .as_jax (a ),
389+ self ._losses = f_eval_loss (jax . tree . map (lambda a : bm .as_jax (a ),
391390 fixed_points ,
392391 is_leaf = lambda x : isinstance (x , bm .BaseArray )))
393- self ._fixed_points = tree_map (lambda a : bm .as_jax (a ),
392+ self ._fixed_points = jax . tree . map (lambda a : bm .as_jax (a ),
394393 fixed_points ,
395394 is_leaf = lambda x : isinstance (x , bm .BaseArray ))
396395 self ._selected_ids = jnp .arange (num_candidate )
@@ -429,7 +428,7 @@ def find_fps_with_opt_solver(
429428 print (f"Optimizing with { opt_solver } to find fixed points:" )
430429
431430 # optimizing
432- res = f_opt (tree_map (lambda a : bm .as_jax (a ), candidates , is_leaf = lambda a : isinstance (a , bm .BaseArray )))
431+ res = f_opt (jax . tree . map (lambda a : bm .as_jax (a ), candidates , is_leaf = lambda a : isinstance (a , bm .BaseArray )))
433432
434433 # results
435434 valid_ids = jnp .where (res .success )[0 ]
@@ -467,7 +466,7 @@ def filter_loss(self, tolerance: float = 1e-5):
467466 num_fps = self ._fixed_points .shape [0 ]
468467 ids = self ._losses < tolerance
469468 keep_ids = bm .as_jax (bm .where (ids )[0 ])
470- self ._fixed_points = tree_map (lambda a : a [keep_ids ], self ._fixed_points )
469+ self ._fixed_points = jax . tree . map (lambda a : a [keep_ids ], self ._fixed_points )
471470 self ._losses = self ._losses [keep_ids ]
472471 self ._selected_ids = self ._selected_ids [keep_ids ]
473472 if self .verbose :
@@ -490,7 +489,7 @@ def keep_unique(self, tolerance: float = 2.5e-2):
490489 else :
491490 num_fps = self ._fixed_points .shape [0 ]
492491 fps , keep_ids = utils .keep_unique (self .fixed_points , tolerance = tolerance )
493- self ._fixed_points = tree_map (lambda a : jnp .asarray (a ), fps )
492+ self ._fixed_points = jax . tree . map (lambda a : jnp .asarray (a ), fps )
494493 self ._losses = self ._losses [keep_ids ]
495494 self ._selected_ids = self ._selected_ids [keep_ids ]
496495 if self .verbose :
@@ -525,7 +524,7 @@ def exclude_outliers(self, tolerance: float = 1e0):
525524
526525 # Return data with outliers removed and indices of kept datapoints.
527526 keep_ids = np .where (closest_neighbor < tolerance )[0 ]
528- self ._fixed_points = tree_map (lambda a : a [keep_ids ], self ._fixed_points )
527+ self ._fixed_points = jax . tree . map (lambda a : a [keep_ids ], self ._fixed_points )
529528 self ._selected_ids = self ._selected_ids [keep_ids ]
530529 self ._losses = self ._losses [keep_ids ]
531530
@@ -562,11 +561,11 @@ def compute_jacobians(
562561 """
563562 # check data
564563 info = np .asarray ([(l .ndim , l .shape [0 ])
565- for l in tree_flatten (points , is_leaf = lambda a : isinstance (a , bm .BaseArray ))[0 ]])
564+ for l in jax . tree . flatten (points , is_leaf = lambda a : isinstance (a , bm .BaseArray ))[0 ]])
566565 ndim = np .unique (info [:, 0 ])
567566 if len (ndim ) != 1 : raise ValueError (f'Get multiple dimension of the evaluated points. { ndim } ' )
568567 if ndim [0 ] == 1 :
569- points = tree_map (lambda a : bm .asarray ([a ]), points )
568+ points = jax . tree . map (lambda a : bm .asarray ([a ]), points )
570569 num_point = 1
571570 elif ndim [0 ] == 2 :
572571 nsize = np .unique (info [:, 1 ])
0 commit comments