Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import jax.numpy as jnp
import numpy as np
from jax.scipy.optimize import minimize
from jax.tree_util import tree_flatten, tree_map

import brainpy._src.math as bm
from brainpy import optim, losses
Expand Down Expand Up @@ -265,7 +264,7 @@ def opt_losses(self, val):
@property
def fixed_points(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""The final fixed points found."""
return tree_map(lambda a: np.asarray(a), self._fixed_points)
return jax.tree.map(lambda a: np.asarray(a), self._fixed_points)

@fixed_points.setter
def fixed_points(self, val):
Expand Down Expand Up @@ -339,11 +338,11 @@ def find_fps_with_gd_method(
num_candidate = self._check_candidates(candidates)
if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)):
raise ValueError('Candidates must be instance of ArrayType or dict of ArrayType.')
fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.BaseArray))
fixed_points = jax.tree.map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.BaseArray))
f_eval_loss = self._get_f_eval_loss()

def f_loss():
return f_eval_loss(tree_map(lambda a: bm.as_jax(a),
return f_eval_loss(jax.tree.map(lambda a: bm.as_jax(a),
fixed_points,
is_leaf=lambda x: isinstance(x, bm.BaseArray))).mean()

Expand Down Expand Up @@ -387,10 +386,10 @@ def batch_train(start_i, n_batch):
f'is below tolerance {tolerance:0.10f}.')

self._opt_losses = jnp.concatenate(opt_losses)
self._losses = f_eval_loss(tree_map(lambda a: bm.as_jax(a),
self._losses = f_eval_loss(jax.tree.map(lambda a: bm.as_jax(a),
fixed_points,
is_leaf=lambda x: isinstance(x, bm.BaseArray)))
self._fixed_points = tree_map(lambda a: bm.as_jax(a),
self._fixed_points = jax.tree.map(lambda a: bm.as_jax(a),
fixed_points,
is_leaf=lambda x: isinstance(x, bm.BaseArray))
self._selected_ids = jnp.arange(num_candidate)
Expand Down Expand Up @@ -429,7 +428,7 @@ def find_fps_with_opt_solver(
print(f"Optimizing with {opt_solver} to find fixed points:")

# optimizing
res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)))
res = f_opt(jax.tree.map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)))

# results
valid_ids = jnp.where(res.success)[0]
Expand Down Expand Up @@ -467,7 +466,7 @@ def filter_loss(self, tolerance: float = 1e-5):
num_fps = self._fixed_points.shape[0]
ids = self._losses < tolerance
keep_ids = bm.as_jax(bm.where(ids)[0])
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
self._fixed_points = jax.tree.map(lambda a: a[keep_ids], self._fixed_points)
self._losses = self._losses[keep_ids]
self._selected_ids = self._selected_ids[keep_ids]
if self.verbose:
Expand All @@ -490,7 +489,7 @@ def keep_unique(self, tolerance: float = 2.5e-2):
else:
num_fps = self._fixed_points.shape[0]
fps, keep_ids = utils.keep_unique(self.fixed_points, tolerance=tolerance)
self._fixed_points = tree_map(lambda a: jnp.asarray(a), fps)
self._fixed_points = jax.tree.map(lambda a: jnp.asarray(a), fps)
self._losses = self._losses[keep_ids]
self._selected_ids = self._selected_ids[keep_ids]
if self.verbose:
Expand Down Expand Up @@ -525,7 +524,7 @@ def exclude_outliers(self, tolerance: float = 1e0):

# Return data with outliers removed and indices of kept datapoints.
keep_ids = np.where(closest_neighbor < tolerance)[0]
self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
self._fixed_points = jax.tree.map(lambda a: a[keep_ids], self._fixed_points)
self._selected_ids = self._selected_ids[keep_ids]
self._losses = self._losses[keep_ids]

Expand Down Expand Up @@ -562,11 +561,11 @@ def compute_jacobians(
"""
# check data
info = np.asarray([(l.ndim, l.shape[0])
for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.BaseArray))[0]])
for l in jax.tree.flatten(points, is_leaf=lambda a: isinstance(a, bm.BaseArray))[0]])
ndim = np.unique(info[:, 0])
if len(ndim) != 1: raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}')
if ndim[0] == 1:
points = tree_map(lambda a: bm.asarray([a]), points)
points = jax.tree.map(lambda a: bm.asarray([a]), points)
num_point = 1
elif ndim[0] == 2:
nsize = np.unique(info[:, 1])
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/analysis/utils/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_flatten

import brainpy._src.math as bm
from brainpy.tools import numba_jit
Expand Down Expand Up @@ -112,7 +111,7 @@ def euclidean_distance_jax(points: Union[jnp.ndarray, bm.ndarray], num_point=Non
num_point = points.shape[0]
indices = jnp.triu_indices(num_point)
dist_mat = bm.zeros((num_point, num_point))
leaves, _ = tree_flatten(points)
leaves, _ = jax.tree.flatten(points)
dist_mat[indices] = _ed(*indices, leaves)
dist_mat = jnp.maximum(dist_mat.value, dist_mat.value.T)
return dist_mat
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False)

def update(self):
all_vars = list(self.implicit_vars.values())

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Changed to use .value for all_vars; this may break if implicit_vars contains non-Variable objects.

Add a type check or fallback to handle cases where items in implicit_vars lack a .value attribute to prevent runtime errors.

all_vars = [v.value for v in self.implicit_vars.values()]
for key, intg in self.integrals.items():
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))

Expand Down
5 changes: 2 additions & 3 deletions brainpy/_src/analysis/utils/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_map

import brainpy.math as bm
from .function import f_without_jaxarray_return
Expand Down Expand Up @@ -116,7 +115,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]],
return candidates, keep_ids
if num_fps <= 1:
return candidates, keep_ids

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: tree_map replaced with jax.tree.map; check for compatibility with all input types.

Ensure jax.tree.map correctly processes any custom containers or objects in candidates, particularly regarding the is_leaf parameter.

Suggested implementation:

    # Define a robust is_leaf function to handle bm.BaseArray and any other custom containers
    def _is_leaf(node):
        # Extend this check if you have other custom array-like types
        return isinstance(node, bm.BaseArray) or isinstance(node, np.ndarray) or isinstance(node, jnp.ndarray)
    candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=_is_leaf)
    if keep_ids.shape[0] > 0:
        unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates, is_leaf=_is_leaf)
    else:
        unique_fps = np.array([], dtype=dtype)
    return unique_fps, keep_ids

candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))
candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))

# If point A and point B are within identical_tol of each other, and the
# A is first in the list, we keep A.
Expand All @@ -129,7 +128,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]],
all_drop_idxs += list(drop_idxs)
keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs))
if keep_ids.shape[0] > 0:
unique_fps = tree_map(lambda a: a[keep_ids], candidates)
unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates)
else:
unique_fps = np.array([], dtype=dtype)
return unique_fps, keep_ids
Expand Down
12 changes: 5 additions & 7 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def stdp_update(
if on_post is not None:
spike = on_post['spike']
trace = on_post['trace']
self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max)
self.W.value = dense_on_post(self.W.value, trace, spike, w_min, w_max)


Linear = Dense
Expand Down Expand Up @@ -303,18 +303,16 @@ def stdp_update(
w_min: numbers.Number = None,
w_max: numbers.Number = None
):
if isinstance(self.weight, float):
raise ValueError(f'Cannot update the weight of a constant node.')
if not isinstance(self.weight, bm.Variable):
self.tracing_variable('weight', self.weight, self.weight.shape)
raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): Replace f-string with no interpolated values with string (remove-redundant-fstring)

Suggested change
raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')
raise ValueError(
'When using STDP to update synaptic weights, the weight must be a variable.'
)

if on_pre is not None:
spike = on_pre['spike']
trace = on_pre['trace']
self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max)
if on_post is not None:
spike = on_post['spike']
trace = on_post['trace']
self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max)


class OneToOne(Layer, SupportSTDP):
Expand Down Expand Up @@ -442,7 +440,7 @@ def stdp_update(
if on_post is not None:
spike = on_post['spike']
trace = on_post['trace']
self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max)
self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max)


class _CSRLayer(Layer, SupportSTDP):
Expand Down Expand Up @@ -503,7 +501,7 @@ def stdp_update(
trace = on_post['trace']
self.weight.value = csr2csc_on_post(
self.weight.value, self._pre_ids, self._post_indptr,
self.w_indices, spike, trace, w_min, w_max,
self.w_indices, trace, spike, w_min, w_max,
shape=(trace.shape[0], spike.shape[0]),
)

Expand Down
12 changes: 11 additions & 1 deletion brainpy/_src/integrators/ode/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ def integral_func(*args, **kwargs):
for i, parse in enumerate(parses):
f_integral, vars_, pars_ = parse
vps = vars_ + pars_ + [C.DT]
r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in})
r = f_integral(
_as_value(params_in[vps[0]]),
**{arg: _as_value(params_in[arg]) for arg in vps[1:] if arg in params_in}
)
results.append(r)
return results if len(self.variables) > 1 else results[0]

Expand Down Expand Up @@ -370,3 +373,10 @@ def integral(*args, **kwargs):
register_ode_integrator('exp_euler', ExponentialEuler)
register_ode_integrator('exp_euler_auto', ExponentialEuler)
register_ode_integrator('exp_auto', ExponentialEuler)


def _as_value(x):
if isinstance(x, bm.Variable):
return x.value
else:
return x
Comment on lines +379 to +382

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): Replace if statement with if expression (assign-if-exp)

Suggested change
if isinstance(x, bm.Variable):
return x.value
else:
return x
return x.value if isinstance(x, bm.Variable) else x

4 changes: 2 additions & 2 deletions brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ def f1():
lambda: 4, lambda: 5])

self.assertTrue(f(11) == 1)
print(var_a)
self.assertTrue(bm.all(var_a == 1))
print(var_a.value)
self.assertTrue(bm.all(var_a.value == 1))
self.assertTrue(f(1) == 4)
self.assertTrue(f(-1) == 5)

Expand Down
114 changes: 57 additions & 57 deletions brainpy/_src/math/object_transform/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,63 +349,63 @@ def value(self, v):
self._been_writen = True # set the flag
self._write_value(v) # write the value

def tree_flatten(self):
"""Flattens this variable.

Returns:
A pair where the first element is a list of leaf values
and the second element is a treedef representing the
structure of the flattened tree.
"""
return (self._value,), None

@classmethod
def tree_unflatten(cls, aux_data, flat_contents):
"""Reconstructs a variable from the aux_data and the leaves.

Args:
aux_data:
flat_contents:

Returns:
The variable.
"""
return cls(*flat_contents)

def clone(self) -> 'Variable':
"""Clone the variable. """
r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis)
return r

def __eq__(self, other):
"""Override State's __eq__ to use BaseArray behavior for element-wise comparison."""
from brainpy._src.math.ndarray import _check_input_array, _return
return _return(self.value == _check_input_array(other))

def __ne__(self, other):
"""Override State's __ne__ to use BaseArray behavior for element-wise comparison."""
from brainpy._src.math.ndarray import _check_input_array, _return
return _return(self.value != _check_input_array(other))

def __lt__(self, other):
"""Override State's __lt__ to use BaseArray behavior for element-wise comparison."""
from brainpy._src.math.ndarray import _check_input_array, _return
return _return(self.value < _check_input_array(other))

def __le__(self, other):
"""Override State's __le__ to use BaseArray behavior for element-wise comparison."""
from brainpy._src.math.ndarray import _check_input_array, _return
return _return(self.value <= _check_input_array(other))

def __gt__(self, other):
"""Override State's __gt__ to use BaseArray behavior for element-wise comparison."""
from brainpy._src.math.ndarray import _check_input_array, _return
return _return(self.value > _check_input_array(other))

def __ge__(self, other):
"""Override State's __ge__ to use BaseArray behavior for element-wise comparison."""
from brainpy._src.math.ndarray import _check_input_array, _return
return _return(self.value >= _check_input_array(other))
# def tree_flatten(self):
# """Flattens this variable.
#
# Returns:
# A pair where the first element is a list of leaf values
# and the second element is a treedef representing the
# structure of the flattened tree.
# """
# return (self._value,), None
#
# @classmethod
# def tree_unflatten(cls, aux_data, flat_contents):
# """Reconstructs a variable from the aux_data and the leaves.
#
# Args:
# aux_data:
# flat_contents:
#
# Returns:
# The variable.
# """
# return cls(*flat_contents)

# def clone(self) -> 'Variable':
# """Clone the variable. """
# r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis)
# return r

# def __eq__(self, other):
# """Override State's __eq__ to use BaseArray behavior for element-wise comparison."""
# from brainpy._src.math.ndarray import _check_input_array, _return
# return _return(self.value == _check_input_array(other))
#
# def __ne__(self, other):
# """Override State's __ne__ to use BaseArray behavior for element-wise comparison."""
# from brainpy._src.math.ndarray import _check_input_array, _return
# return _return(self.value != _check_input_array(other))
#
# def __lt__(self, other):
# """Override State's __lt__ to use BaseArray behavior for element-wise comparison."""
# from brainpy._src.math.ndarray import _check_input_array, _return
# return _return(self.value < _check_input_array(other))
#
# def __le__(self, other):
# """Override State's __le__ to use BaseArray behavior for element-wise comparison."""
# from brainpy._src.math.ndarray import _check_input_array, _return
# return _return(self.value <= _check_input_array(other))
#
# def __gt__(self, other):
# """Override State's __gt__ to use BaseArray behavior for element-wise comparison."""
# from brainpy._src.math.ndarray import _check_input_array, _return
# return _return(self.value > _check_input_array(other))
#
# def __ge__(self, other):
# """Override State's __ge__ to use BaseArray behavior for element-wise comparison."""
# from brainpy._src.math.ndarray import _check_input_array, _return
# return _return(self.value >= _check_input_array(other))


def _get_dtype(v):
Expand Down
2 changes: 2 additions & 0 deletions brainpy/_src/train/back_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def fit(
Please set batch size in your dataset.

"""

share.save(fit=True)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (bug_risk): Calling share.save(fit=True) at the start of fit may have side effects.

If share.save is not idempotent, repeated or concurrent calls to fit could cause issues. Please review its behavior.

if shared_args is None:
shared_args = dict()
shared_args['fit'] = shared_args.get('fit', True)
Expand Down
Loading
Loading