-
Notifications
You must be signed in to change notification settings - Fork 107
Refactor and fix for JAX compatibility, STDP, and training workflows #772
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3500290
d35417e
0424b1d
2347423
eabca72
2759dd5
9bbc9fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,6 @@ | |
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| from jax.tree_util import tree_map | ||
|
|
||
| import brainpy.math as bm | ||
| from .function import f_without_jaxarray_return | ||
|
|
@@ -116,7 +115,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]], | |
| return candidates, keep_ids | ||
| if num_fps <= 1: | ||
| return candidates, keep_ids | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: tree_map replaced with jax.tree.map; check for compatibility with all input types. Ensure jax.tree.map correctly processes any custom containers or objects in candidates, particularly regarding the is_leaf parameter. Suggested implementation: # Define a robust is_leaf function to handle bm.BaseArray and any other custom containers
def _is_leaf(node):
# Extend this check if you have other custom array-like types
return isinstance(node, bm.BaseArray) or isinstance(node, np.ndarray) or isinstance(node, jnp.ndarray)
candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=_is_leaf) if keep_ids.shape[0] > 0:
unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates, is_leaf=_is_leaf)
else:
unique_fps = np.array([], dtype=dtype)
return unique_fps, keep_ids |
||
| candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)) | ||
| candidates = jax.tree.map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)) | ||
|
|
||
| # If point A and point B are within identical_tol of each other, and the | ||
| # A is first in the list, we keep A. | ||
|
|
@@ -129,7 +128,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]], | |
| all_drop_idxs += list(drop_idxs) | ||
| keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs)) | ||
| if keep_ids.shape[0] > 0: | ||
| unique_fps = tree_map(lambda a: a[keep_ids], candidates) | ||
| unique_fps = jax.tree.map(lambda a: a[keep_ids], candidates) | ||
| else: | ||
| unique_fps = np.array([], dtype=dtype) | ||
| return unique_fps, keep_ids | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -217,7 +217,7 @@ def stdp_update( | |||||||||
| if on_post is not None: | ||||||||||
| spike = on_post['spike'] | ||||||||||
| trace = on_post['trace'] | ||||||||||
| self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max) | ||||||||||
| self.W.value = dense_on_post(self.W.value, trace, spike, w_min, w_max) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| Linear = Dense | ||||||||||
|
|
@@ -303,18 +303,16 @@ def stdp_update( | |||||||||
| w_min: numbers.Number = None, | ||||||||||
| w_max: numbers.Number = None | ||||||||||
| ): | ||||||||||
| if isinstance(self.weight, float): | ||||||||||
| raise ValueError(f'Cannot update the weight of a constant node.') | ||||||||||
| if not isinstance(self.weight, bm.Variable): | ||||||||||
| self.tracing_variable('weight', self.weight, self.weight.shape) | ||||||||||
| raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.') | ||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (code-quality): Replace f-string with no interpolated values with string (
Suggested change
|
||||||||||
| if on_pre is not None: | ||||||||||
| spike = on_pre['spike'] | ||||||||||
| trace = on_pre['trace'] | ||||||||||
| self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) | ||||||||||
| if on_post is not None: | ||||||||||
| spike = on_post['spike'] | ||||||||||
| trace = on_post['trace'] | ||||||||||
| self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) | ||||||||||
| self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class OneToOne(Layer, SupportSTDP): | ||||||||||
|
|
@@ -442,7 +440,7 @@ def stdp_update( | |||||||||
| if on_post is not None: | ||||||||||
| spike = on_post['spike'] | ||||||||||
| trace = on_post['trace'] | ||||||||||
| self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) | ||||||||||
| self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class _CSRLayer(Layer, SupportSTDP): | ||||||||||
|
|
@@ -503,7 +501,7 @@ def stdp_update( | |||||||||
| trace = on_post['trace'] | ||||||||||
| self.weight.value = csr2csc_on_post( | ||||||||||
| self.weight.value, self._pre_ids, self._post_indptr, | ||||||||||
| self.w_indices, spike, trace, w_min, w_max, | ||||||||||
| self.w_indices, trace, spike, w_min, w_max, | ||||||||||
| shape=(trace.shape[0], spike.shape[0]), | ||||||||||
| ) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -330,7 +330,10 @@ def integral_func(*args, **kwargs): | |||||||||||
| for i, parse in enumerate(parses): | ||||||||||||
| f_integral, vars_, pars_ = parse | ||||||||||||
| vps = vars_ + pars_ + [C.DT] | ||||||||||||
| r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in}) | ||||||||||||
| r = f_integral( | ||||||||||||
| _as_value(params_in[vps[0]]), | ||||||||||||
| **{arg: _as_value(params_in[arg]) for arg in vps[1:] if arg in params_in} | ||||||||||||
| ) | ||||||||||||
| results.append(r) | ||||||||||||
| return results if len(self.variables) > 1 else results[0] | ||||||||||||
|
|
||||||||||||
|
|
@@ -370,3 +373,10 @@ def integral(*args, **kwargs): | |||||||||||
| register_ode_integrator('exp_euler', ExponentialEuler) | ||||||||||||
| register_ode_integrator('exp_euler_auto', ExponentialEuler) | ||||||||||||
| register_ode_integrator('exp_auto', ExponentialEuler) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def _as_value(x): | ||||||||||||
| if isinstance(x, bm.Variable): | ||||||||||||
| return x.value | ||||||||||||
| else: | ||||||||||||
| return x | ||||||||||||
|
Comment on lines
+379
to
+382
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (code-quality): Replace if statement with if expression (
Suggested change
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -230,6 +230,8 @@ def fit( | |
| Please set batch size in your dataset. | ||
|
|
||
| """ | ||
|
|
||
| share.save(fit=True) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. question (bug_risk): Calling share.save(fit=True) at the start of fit may have side effects. If share.save is not idempotent, repeated or concurrent calls to fit could cause issues. Please review its behavior. |
||
| if shared_args is None: | ||
| shared_args = dict() | ||
| shared_args['fit'] = shared_args.get('fit', True) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): Changed to use .value for all_vars; this may break if implicit_vars contains non-Variable objects.
Add a type check or fallback to handle cases where items in implicit_vars lack a .value attribute to prevent runtime errors.